Repository: aiming-lab/AutoResearchClaw Branch: main Commit: 258dae2bb28f Files: 422 Total size: 4.1 MB Directory structure: gitextract_tp1xyq09/ ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── config.researchclaw.example.yaml ├── docs/ │ ├── BUG_FIX_DOCUMENT_20260316.md │ ├── BUG_TRACKER.md │ ├── CHANGELOG_ANTHROPIC_ADAPTER.md │ ├── PIPELINE_TEST_LOG_R5.md │ ├── README_AR.md │ ├── README_CN.md │ ├── README_DE.md │ ├── README_ES.md │ ├── README_FR.md │ ├── README_JA.md │ ├── README_KO.md │ ├── README_PT.md │ ├── README_RU.md │ ├── TESTER_GUIDE.md │ ├── TESTER_GUIDE_CN.md │ ├── TESTER_GUIDE_JA.md │ ├── agent_figure_and_benchmark_plan.md │ ├── figure_prompts/ │ │ ├── case_a_meta_learning.md │ │ └── case_b_rlhf_alignment.md │ ├── integration-guide.md │ ├── issue_tracker_v9.md │ ├── iteration_plan_v8.md │ ├── iteration_showcase_narrative.md │ ├── metaclaw-integration-plan.md │ ├── next_phase_showcase_plan.md │ ├── pipeline_critical_fixes_v8.md │ ├── rate_limit_fix_plan.md │ ├── sandbox_environment_fix_plan.md │ └── showcase/ │ └── SHOWCASE.md ├── prompts.default.yaml ├── pyproject.toml ├── researchclaw/ │ ├── __init__.py │ ├── __main__.py │ ├── adapters.py │ ├── agents/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── benchmark_agent/ │ │ │ ├── __init__.py │ │ │ ├── acquirer.py │ │ │ ├── orchestrator.py │ │ │ ├── selector.py │ │ │ ├── surveyor.py │ │ │ └── validator.py │ │ ├── code_searcher/ │ │ │ ├── __init__.py │ │ │ ├── agent.py │ │ │ ├── cache.py │ │ │ ├── github_client.py │ │ │ ├── pattern_extractor.py │ │ │ └── query_gen.py │ │ └── figure_agent/ │ │ ├── __init__.py │ │ ├── codegen.py │ │ ├── critic.py │ │ ├── decision.py │ │ ├── integrator.py │ │ ├── nano_banana.py │ │ ├── orchestrator.py │ │ ├── planner.py │ │ ├── renderer.py │ │ └── style_config.py │ ├── assessor/ │ │ ├── __init__.py │ │ ├── comparator.py │ │ ├── rubrics.py │ │ ├── scorer.py │ │ └── venue_recommender.py │ ├── calendar/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ └── conferences.yaml │ │ 
├── deadlines.py │ │ ├── planner.py │ │ └── reminder.py │ ├── cli.py │ ├── collaboration/ │ │ ├── __init__.py │ │ ├── dedup.py │ │ ├── publisher.py │ │ ├── repository.py │ │ └── subscriber.py │ ├── config.py │ ├── copilot/ │ │ ├── __init__.py │ │ ├── branching.py │ │ ├── controller.py │ │ ├── feedback.py │ │ └── modes.py │ ├── dashboard/ │ │ ├── __init__.py │ │ ├── broadcaster.py │ │ ├── collector.py │ │ └── metrics.py │ ├── data/ │ │ ├── __init__.py │ │ ├── benchmark_knowledge.yaml │ │ ├── dataset_registry.yaml │ │ ├── docker_profiles.yaml │ │ ├── framework_docs/ │ │ │ ├── axolotl.md │ │ │ ├── llamafactory.md │ │ │ ├── peft.md │ │ │ ├── transformers_training.md │ │ │ └── trl.md │ │ └── seminal_papers.yaml │ ├── docker/ │ │ ├── Dockerfile │ │ ├── Dockerfile.biology │ │ ├── Dockerfile.chemistry │ │ ├── Dockerfile.economics │ │ ├── Dockerfile.generic │ │ ├── Dockerfile.math │ │ ├── Dockerfile.physics │ │ └── entrypoint.sh │ ├── domains/ │ │ ├── __init__.py │ │ ├── adapters/ │ │ │ ├── __init__.py │ │ │ ├── biology.py │ │ │ ├── chemistry.py │ │ │ ├── economics.py │ │ │ ├── generic.py │ │ │ ├── math.py │ │ │ ├── ml.py │ │ │ ├── neuroscience.py │ │ │ ├── physics.py │ │ │ ├── robotics.py │ │ │ └── security.py │ │ ├── detector.py │ │ ├── experiment_schema.py │ │ ├── profiles/ │ │ │ ├── _generic.yaml │ │ │ ├── biology_genomics.yaml │ │ │ ├── biology_protein.yaml │ │ │ ├── biology_singlecell.yaml │ │ │ ├── chemistry_molprop.yaml │ │ │ ├── chemistry_qm.yaml │ │ │ ├── economics_empirical.yaml │ │ │ ├── mathematics_numerical.yaml │ │ │ ├── mathematics_optimization.yaml │ │ │ ├── ml_compression.yaml │ │ │ ├── ml_generative.yaml │ │ │ ├── ml_generic.yaml │ │ │ ├── ml_graph.yaml │ │ │ ├── ml_nlp.yaml │ │ │ ├── ml_rl.yaml │ │ │ ├── ml_tabular.yaml │ │ │ ├── ml_vision.yaml │ │ │ ├── neuroscience_computational.yaml │ │ │ ├── neuroscience_imaging.yaml │ │ │ ├── physics_pde.yaml │ │ │ ├── physics_quantum.yaml │ │ │ ├── physics_simulation.yaml │ │ │ ├── robotics_control.yaml │ │ │ └── 
security_detection.yaml │ │ └── prompt_adapter.py │ ├── evolution.py │ ├── experiment/ │ │ ├── __init__.py │ │ ├── agentic_sandbox.py │ │ ├── code_agent.py │ │ ├── colab_sandbox.py │ │ ├── docker_sandbox.py │ │ ├── evaluators/ │ │ │ ├── __init__.py │ │ │ └── convergence.py │ │ ├── factory.py │ │ ├── git_manager.py │ │ ├── harness_template.py │ │ ├── metrics.py │ │ ├── runner.py │ │ ├── sandbox.py │ │ ├── ssh_sandbox.py │ │ ├── validator.py │ │ └── visualize.py │ ├── feedback/ │ │ └── FEEDBACK_ANALYSIS_PROMPT.md │ ├── hardware.py │ ├── health.py │ ├── knowledge/ │ │ ├── __init__.py │ │ ├── base.py │ │ └── graph/ │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── entities.py │ │ ├── query.py │ │ ├── relations.py │ │ └── visualizer.py │ ├── literature/ │ │ ├── __init__.py │ │ ├── arxiv_client.py │ │ ├── cache.py │ │ ├── models.py │ │ ├── novelty.py │ │ ├── openalex_client.py │ │ ├── search.py │ │ ├── semantic_scholar.py │ │ ├── trends.py │ │ └── verify.py │ ├── llm/ │ │ ├── __init__.py │ │ ├── acp_client.py │ │ ├── anthropic_adapter.py │ │ └── client.py │ ├── mcp/ │ │ ├── __init__.py │ │ ├── client.py │ │ ├── registry.py │ │ ├── server.py │ │ ├── tools.py │ │ └── transport.py │ ├── memory/ │ │ ├── __init__.py │ │ ├── decay.py │ │ ├── embeddings.py │ │ ├── experiment_memory.py │ │ ├── ideation_memory.py │ │ ├── retriever.py │ │ ├── store.py │ │ └── writing_memory.py │ ├── metaclaw_bridge/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── lesson_to_skill.py │ │ ├── prm_gate.py │ │ ├── session.py │ │ ├── skill_feedback.py │ │ └── stage_skill_map.py │ ├── overleaf/ │ │ ├── __init__.py │ │ ├── conflict.py │ │ ├── formatter.py │ │ ├── sync.py │ │ └── watcher.py │ ├── pipeline/ │ │ ├── __init__.py │ │ ├── _domain.py │ │ ├── _helpers.py │ │ ├── code_agent.py │ │ ├── contracts.py │ │ ├── executor.py │ │ ├── experiment_diagnosis.py │ │ ├── experiment_repair.py │ │ ├── opencode_bridge.py │ │ ├── paper_verifier.py │ │ ├── runner.py │ │ ├── stage_impls/ │ │ │ ├── __init__.py │ │ │ ├── 
_analysis.py │ │ │ ├── _code_generation.py │ │ │ ├── _execution.py │ │ │ ├── _experiment_design.py │ │ │ ├── _literature.py │ │ │ ├── _paper_writing.py │ │ │ ├── _review_publish.py │ │ │ ├── _synthesis.py │ │ │ └── _topic.py │ │ ├── stages.py │ │ └── verified_registry.py │ ├── project/ │ │ ├── __init__.py │ │ ├── idea_pool.py │ │ ├── manager.py │ │ ├── models.py │ │ └── scheduler.py │ ├── prompts.py │ ├── quality.py │ ├── report.py │ ├── server/ │ │ ├── __init__.py │ │ ├── app.py │ │ ├── dialog/ │ │ │ ├── __init__.py │ │ │ ├── intents.py │ │ │ ├── router.py │ │ │ └── session.py │ │ ├── middleware/ │ │ │ ├── __init__.py │ │ │ └── auth.py │ │ ├── routes/ │ │ │ ├── __init__.py │ │ │ ├── chat.py │ │ │ ├── pipeline.py │ │ │ ├── projects.py │ │ │ └── voice.py │ │ └── websocket/ │ │ ├── __init__.py │ │ ├── events.py │ │ └── manager.py │ ├── servers/ │ │ ├── __init__.py │ │ ├── cloud_executor.py │ │ ├── dispatcher.py │ │ ├── monitor.py │ │ ├── registry.py │ │ ├── slurm_executor.py │ │ └── ssh_executor.py │ ├── skills/ │ │ ├── __init__.py │ │ ├── builtin/ │ │ │ ├── __init__.py │ │ │ ├── domain/ │ │ │ │ ├── cv-classification/ │ │ │ │ │ └── SKILL.md │ │ │ │ ├── cv-detection/ │ │ │ │ │ └── SKILL.md │ │ │ │ ├── nlp-alignment/ │ │ │ │ │ └── SKILL.md │ │ │ │ ├── nlp-pretraining/ │ │ │ │ │ └── SKILL.md │ │ │ │ └── rl-policy-optimization/ │ │ │ │ └── SKILL.md │ │ │ ├── experiment/ │ │ │ │ ├── experimental-design/ │ │ │ │ │ └── SKILL.md │ │ │ │ ├── meta-analysis/ │ │ │ │ │ └── SKILL.md │ │ │ │ └── systematic-review/ │ │ │ │ └── SKILL.md │ │ │ └── tooling/ │ │ │ ├── data-loading/ │ │ │ │ └── SKILL.md │ │ │ ├── distributed-training/ │ │ │ │ └── SKILL.md │ │ │ ├── mixed-precision/ │ │ │ │ └── SKILL.md │ │ │ └── pytorch-training/ │ │ │ └── SKILL.md │ │ ├── loader.py │ │ ├── matcher.py │ │ ├── registry.py │ │ └── schema.py │ ├── templates/ │ │ ├── __init__.py │ │ ├── compiler.py │ │ ├── conference.py │ │ ├── converter.py │ │ ├── results_table_builder.py │ │ └── styles/ │ │ ├── iclr_2025/ 
│ │ │ ├── iclr2025_conference.bst │ │ │ └── iclr2025_conference.sty │ │ ├── iclr_2026/ │ │ │ ├── iclr2026_conference.bst │ │ │ └── iclr2026_conference.sty │ │ ├── icml_2025/ │ │ │ ├── icml2025.bst │ │ │ └── icml2025.sty │ │ ├── icml_2026/ │ │ │ ├── icml2026.bst │ │ │ └── icml2026.sty │ │ ├── neurips_2024/ │ │ │ └── neurips_2024.sty │ │ └── neurips_2025/ │ │ └── neurips_2025.sty │ ├── trends/ │ │ ├── __init__.py │ │ ├── auto_topic.py │ │ ├── daily_digest.py │ │ ├── feeds.py │ │ ├── opportunity_finder.py │ │ └── trend_analyzer.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── sanitize.py │ │ └── thinking_tags.py │ ├── voice/ │ │ ├── __init__.py │ │ ├── commands.py │ │ ├── synthesizer.py │ │ └── transcriber.py │ ├── web/ │ │ ├── __init__.py │ │ ├── _ssrf.py │ │ ├── agent.py │ │ ├── crawler.py │ │ ├── pdf_extractor.py │ │ ├── scholar.py │ │ └── search.py │ ├── wizard/ │ │ ├── __init__.py │ │ ├── quickstart.py │ │ ├── templates.py │ │ └── validator.py │ └── writing_guide.py ├── scripts/ │ ├── metaclaw_start.sh │ ├── plot_iteration_showcase.py │ ├── test_beast_mode_e2e.py │ ├── test_code_agent_live.py │ ├── test_code_agent_sandbox.py │ └── test_codegen_v2.py ├── sentinel.sh ├── tests/ │ ├── __init__.py │ ├── conftest.py │ ├── e2e_docker_sandbox.py │ ├── e2e_real_llm.py │ ├── test_anthropic.py │ ├── test_assessor.py │ ├── test_benchmark_agent.py │ ├── test_calendar.py │ ├── test_cli.py │ ├── test_code_agent.py │ ├── test_code_searcher.py │ ├── test_collaboration.py │ ├── test_compiler.py │ ├── test_convergence_evaluator.py │ ├── test_copilot.py │ ├── test_decision_agent.py │ ├── test_domain_detector.py │ ├── test_entry_point_validation.py │ ├── test_experiment_diagnosis.py │ ├── test_experiment_repair.py │ ├── test_experiment_schema.py │ ├── test_figure_agent.py │ ├── test_knowledge_graph.py │ ├── test_mcp.py │ ├── test_memory_system.py │ ├── test_metaclaw_bridge/ │ │ ├── __init__.py │ │ ├── test_config.py │ │ ├── test_lesson_to_skill.py │ │ ├── test_prm_gate.py │ │ ├── 
test_session.py │ │ ├── test_skill_feedback.py │ │ └── test_stage_skill_map.py │ ├── test_metric_parser.py │ ├── test_minimax_provider.py │ ├── test_neuroscience_domain.py │ ├── test_opencode_bridge.py │ ├── test_overleaf.py │ ├── test_paper_verifier.py │ ├── test_project_manager.py │ ├── test_prompt_adapter.py │ ├── test_rc_adapters.py │ ├── test_rc_cache.py │ ├── test_rc_checkpoint.py │ ├── test_rc_citation_resolve.py │ ├── test_rc_citation_verify.py │ ├── test_rc_cli.py │ ├── test_rc_config.py │ ├── test_rc_contracts.py │ ├── test_rc_docker_sandbox.py │ ├── test_rc_e2e_regression.py │ ├── test_rc_evolution.py │ ├── test_rc_executor.py │ ├── test_rc_hardware.py │ ├── test_rc_health.py │ ├── test_rc_kb.py │ ├── test_rc_literature.py │ ├── test_rc_llm.py │ ├── test_rc_novelty.py │ ├── test_rc_preflight.py │ ├── test_rc_prompts.py │ ├── test_rc_quality.py │ ├── test_rc_report.py │ ├── test_rc_runner.py │ ├── test_rc_sanitization.py │ ├── test_rc_sentinel.py │ ├── test_rc_stages.py │ ├── test_rc_templates.py │ ├── test_rc_validator.py │ ├── test_results_table_builder.py │ ├── test_robotics_adapter.py │ ├── test_servers.py │ ├── test_skills_library.py │ ├── test_ssh_and_colab_sandbox.py │ ├── test_trends.py │ ├── test_universal_codegen_integration.py │ ├── test_v6_improvements.py │ ├── test_verified_registry.py │ ├── test_web_crawler.py │ ├── test_web_integration.py │ ├── test_web_pdf_extractor.py │ ├── test_web_platform.py │ ├── test_web_scholar.py │ └── test_web_search.py └── website/ ├── features.html ├── getting-started.html ├── index.html ├── papers.html ├── pipeline.html └── style.css ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ HANDOFF_METACLAW_SKILL_LOOP.md .venv/ __pycache__/ *.pyc *.egg-info/ dist/ build/ workspaces/ .claude/* !.claude/agents/ !.claude/agents/*.md 
!.claude/skills/ !.claude/skills/**/SKILL.md .claude/settings.local.json # Experiment run artifacts (local only) artifacts/ output/ experiment_metaclaw/ promotional/ # Legacy experiment artifacts (pre-v5) exp/ logs/ writing/ # Root-level config (local overrides, not committed) /config.yaml # Sensitive / credentials user_token_cache.json *.secret .env .env.* config_run*.yaml # Literature search cache .researchclaw_cache/ # Playwright MCP logs .playwright-mcp/ # Internal dev/debug docs (not for public) docs/internal/ docs/kb/ docs/plans/ docs/BUGFIX_TRACKER*.md docs/IMPROVEMENT_PLAN*.md docs/IMPROVEMENT_*_EXECUTION.md docs/OPTIMIZATION_PLAN*.md docs/MULTI_CASE_EVALUATION*.md docs/pipeline_quality_issues*.md docs/autobench-loop.md RESEARCHCLAW_AGENTS.md RESEARCHCLAW_CLAUDE.md # Task-specific config files (keep example template only) config_agent_*.yaml config_case*.yaml config_v8_case*.yaml pipeline_run_*.log benchmarks/ # Logo generation prompts image/logo_prompt.md # macOS .DS_Store run.log # Misc temp files .history/ .serena/ cli_pause 暂停 进入 连续失败 重试一次 .venv_arc/ /config.arc.yaml config_*.yaml # Frontend (local dev only) frontend/ # Test outputs and run logs (local only) test_outputs*/ records/ run*_full_log.txt mdpdf.log scripts/md2pdf.py # Local docs (not for public) docs/tasks/ docs/feature_expansion_analysis.* docs/tester_guide_cn.* ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to AutoResearchClaw ## Setup 1. Fork and clone the repo 2. Create a venv and install with dev extras: ``` python3 -m venv .venv && source .venv/bin/activate pip install -e ".[dev]" ``` 3. Generate your local config: ``` researchclaw init ``` 4. 
Edit `config.arc.yaml` with your LLM settings ## Config Convention - `config.researchclaw.example.yaml` — tracked template (do not add secrets) - `config.arc.yaml` — your local config (gitignored, created by `researchclaw init`) - `config.yaml` — also gitignored, supported as fallback ## Running Tests ``` pytest tests/ ``` ## Checking Your Environment ``` researchclaw doctor ``` ## PR Guidelines - Branch from main - One concern per PR - Ensure `pytest tests/` passes - Include tests for new functionality ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2026 Aiming Lab Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================
Just chat with OpenClaw: "Research X" → done.
🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 Paper Showcase · 📖 Integration Guide · 💬 Discord Community
---|
|
🏆 Generated Paper Showcase 8 papers across 8 domains — math, statistics, biology, computing, NLP, RL, vision, robustness — generated fully autonomously with zero human intervention. |
| 📄 | paper_draft.md | Full academic paper (Introduction, Related Work, Method, Experiments, Results, Conclusion) |
| 📐 | paper.tex | Conference-ready LaTeX (NeurIPS / ICLR / ICML templates) |
| 📚 | references.bib | Real BibTeX references from OpenAlex, Semantic Scholar and arXiv — auto-pruned to match inline citations |
| 🔍 | verification_report.json | 4-layer citation integrity + relevance verification (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | Generated code + sandbox results + structured JSON metrics |
| 📊 | charts/ | Auto-generated condition comparison charts with error bars and confidence intervals |
| 📝 | reviews.md | Multi-agent peer review with methodology-evidence consistency checks |
| 🧬 | evolution/ | Self-learning lessons extracted from each run |
| 📦 | deliverables/ | All final outputs in one folder — compile-ready for Overleaf |
Built with 🦞 by the AutoResearchClaw team
================================================ FILE: config.researchclaw.example.yaml ================================================ project: name: "my-research" mode: "full-auto" research: topic: "Your research topic here" domains: - "machine-learning" daily_paper_count: 10 quality_threshold: 4.0 runtime: timezone: "America/New_York" max_parallel_tasks: 3 approval_timeout_hours: 12 retry_limit: 2 notifications: channel: "console" target: "" on_stage_start: true on_stage_fail: true on_gate_required: true knowledge_base: backend: "markdown" root: "docs/kb" openclaw_bridge: use_cron: false use_message: false use_memory: false use_sessions_spawn: false use_web_fetch: false use_browser: false llm: provider: "openai-compatible" base_url: "https://api.openai.com/v1" api_key_env: "OPENAI_API_KEY" api_key: "" primary_model: "gpt-4o" fallback_models: - "gpt-4.1" - "gpt-4o-mini" # --- MiniMax provider example --- # provider: "minimax" # api_key_env: "MINIMAX_API_KEY" # primary_model: "MiniMax-M2.5" # fallback_models: # - "MiniMax-M2.5-highspeed" security: hitl_required_stages: [5, 9, 20] allow_publish_without_approval: false redact_sensitive_logs: true experiment: # ★ mode 决定实验结果的真实性 # "sandbox" — 在本地沙盒中实际执行生成的 Python 代码,产出真实实验数据 # "docker" — 在 Docker 容器中执行,支持 GPU 直通、依赖自动安装、内存隔离 # "simulated" — 不执行代码,使用公式生成假数据(仅用于框架开发调试,不应用于论文生成) mode: "sandbox" time_budget_sec: 300 max_iterations: 10 metric_key: "primary_metric" metric_direction: "minimize" sandbox: # Use ".venv/Scripts/python.exe" on Windows python_path: ".venv/bin/python3" gpu_required: false max_memory_mb: 4096 # Docker sandbox settings (only used when mode: "docker") # Build image first: docker build -t researchclaw/experiment:latest researchclaw/docker/ docker: image: "researchclaw/experiment:latest" gpu_enabled: true # gpu_device_ids: [0] # empty = all GPUs memory_limit_mb: 8192 network_policy: "setup_only" # none | setup_only | pip_only | full # pip_pre_install: ["torchdiffeq", "einops"] auto_install_deps: true 
shm_size_mb: 2048 keep_containers: false ssh_remote: host: "" # SSH hostname or IP user: "" # SSH username (default: current user) port: 22 # SSH port key_path: "" # Path to private key (default: ~/.ssh/id_rsa) gpu_ids: [] # e.g. [0, 1] for CUDA_VISIBLE_DEVICES remote_workdir: "/tmp/researchclaw_experiments" remote_python: "python3" setup_commands: [] # e.g. ["source ~/venv/bin/activate", "pip install torch"] # Docker-over-SSH (most secure remote execution) use_docker: false # Set true to run experiments inside Docker on remote host docker_image: "researchclaw/experiment:latest" docker_network_policy: "none" # none | full docker_memory_limit_mb: 8192 docker_shm_size_mb: 2048 # OpenCode Beast Mode — external AI coding agent for complex experiments # Install: npm i -g opencode-ai@latest (or use `researchclaw setup`) opencode: enabled: true # Master switch (default: true) auto: true # Auto-trigger without confirmation (default: true) complexity_threshold: 0.2 # 0.0-1.0 — higher = only trigger on complex experiments model: "" # Override model (empty = use llm.primary_model) timeout_sec: 600 # Max seconds for OpenCode generation max_retries: 1 # Retry count on failure workspace_cleanup: true # Remove temp workspace after collection # ============================================================================ # SSH Remote Examples # ============================================================================ # # 1. Lab server (bare Python, basic sandboxing): # experiment: # mode: "ssh_remote" # ssh_remote: # host: "gpu-server.lab.edu" # user: "researcher" # key_path: "~/.ssh/id_rsa" # gpu_ids: [0] # remote_python: "python3" # # 2. Lab server (Docker — most secure): # experiment: # mode: "ssh_remote" # ssh_remote: # host: "gpu-server.lab.edu" # user: "researcher" # key_path: "~/.ssh/id_rsa" # gpu_ids: [0] # use_docker: true # docker_image: "researchclaw/experiment:latest" # docker_network_policy: "none" # # 3. 
Colab via SSH tunnel: # experiment: # mode: "ssh_remote" # ssh_remote: # host: "localhost" # port: 12345 # user: "root" # remote_python: "python3" # setup_commands: # - "pip install torch torchvision -q" # # 4. Colab via Google Drive (most robust, no SSH needed): # experiment: # mode: "colab_drive" # colab_drive: # drive_root: "~/Library/CloudStorage/GoogleDrive-you@gmail.com/My Drive/researchclaw" # poll_interval_sec: 30 # timeout_sec: 3600 # setup_script: "pip install torch torchvision -q" # # Then in Colab: run the colab_worker.py that appears in your Drive colab_drive: drive_root: "" # Local path to Google Drive mount poll_interval_sec: 30 # How often to check for results timeout_sec: 3600 # Max wait per experiment (1 hour) setup_script: "" # Shell commands to run before each experiment # Scientific Visualization Agent (Code-to-Viz + Nano Banana) # Uses a Decision Agent to analyze paper content and determine: # - Code figures (bar charts, line plots) → Matplotlib/TikZ # - Image figures (architecture, flowcharts) → Gemini Nano Banana figure_agent: enabled: true min_figures: 3 max_figures: 10 max_iterations: 3 render_timeout_sec: 30 # Security: Docker sandbox for visualization code execution # use_docker: null # null = auto-detect, true = force, false = disable docker_image: "researchclaw/experiment:latest" # Output format: "python" (Matplotlib/Seaborn) or "latex" (TikZ/PGFPlots) output_format: "python" # Nano Banana (Gemini native image generation) nano_banana_enabled: true # gemini_api_key: "" # or set GEMINI_API_KEY env var gemini_model: "gemini-2.5-flash-image" strict_mode: false dpi: 300 # === Prompts === # Customize LLM prompts by pointing to your own YAML file. # Copy prompts.default.yaml, edit the prompts you want, and set the path here. prompts: custom_file: "" # e.g. 
"my_prompts.yaml" (empty = use built-in defaults) # === MetaClaw Integration === # Enable the MetaClaw bridge to get skill injection, PRM quality gates, # and continuous learning from research pipeline failures. # Requires MetaClaw to be running: metaclaw start --mode skills_only metaclaw_bridge: enabled: false proxy_url: "http://localhost:30000" # MetaClaw proxy endpoint skills_dir: "~/.metaclaw/skills" # MetaClaw skills directory fallback_url: "" # Direct LLM URL if proxy is down fallback_api_key: "" # PRM quality gate: LLM-as-judge scoring at gate stages prm: enabled: false api_base: "" # OpenAI-compatible API for PRM judge api_key_env: "PRM_API_KEY" api_key: "" model: "gpt-5.4" votes: 3 # Majority vote count gate_stages: [5, 9, 15, 20] # Stages to apply PRM gating # Lesson-to-skill: auto-convert pipeline failures into MetaClaw skills lesson_to_skill: enabled: true min_severity: "error" # Only convert error-level lessons max_skills_per_run: 3 ================================================ FILE: docs/BUG_FIX_DOCUMENT_20260316.md ================================================ # Bug Fix Document — AutoResearchClaw Pipeline > 生成日期:2026-03-16 > 反馈来源:2 位测试者(user1: CV 方向 / GPU 环境, user2: Windows 环境) > 总计问题:9 个 ## 📊 总览 | 分类 | 数量 | |------|------| | 🔴 确认的 Bug(需修复) | **4** | | 🟠 架构改进(强烈建议) | **2** | | 🔵 功能需求 | **3** | ## 🔥 修复优先级 | 优先级 | ID | 问题 | 阶段 | 涉及文件 | |--------|----|------|------|----------| | 🔴 CRITICAL | BUG-001 | 论文硬件信息与实际不一致 | PAPER_DRAFT (17) | `executor.py`, `prompts.py` | | 🔴 CRITICAL | BUG-002 | Windows 环境 Docker 不可用导致实验链式失败 | EXPERIMENT_RUN (12) | `factory.py`, `docker_sandbox.py` | | 🔴 HIGH | BUG-003 | 论文内容自相矛盾(承诺评测但未执行) | PAPER_DRAFT (17), PEER_REVIEW (18) | `executor.py`, `prompts.py` | | 🔴 HIGH | BUG-004 | 生成代码缺少数值稳定性防护(NaN/Inf) | CODE_GENERATION (10) | `code_agent.py`, `prompts.py` | | 🟠 HIGH | ARCH-001 | Stage 17 过于严格的 hard block 策略 | PAPER_DRAFT (17) | `executor.py` | | 🟠 HIGH | ARCH-002 | Idea 降级时不询问用户确认 | EXPERIMENT_DESIGN (9), 
RESEARCH_DECISION (15) | `executor.py`, `stages.py` | --- ## 确认的 Bug — 详细修复方案 ### 🔴 `BUG-001` — 论文硬件信息与实际机器不一致 | 字段 | 内容 | |------|------| | **严重程度** | CRITICAL | | **所属阶段** | PAPER_DRAFT (Stage 17) | | **报告者** | user1 | **问题描述:** 论文中声称使用 A100 GPU 训练,但测试者实际机器上是 A5000。Pipeline 在 Stage 1 检测了硬件并保存到 `hardware_profile.json`,但在论文生成阶段完全没有利用这个信息来约束 LLM 输出。 **根因分析:** - `executor.py` 第 1226-1233 行:Stage 1 (TOPIC_INIT) 检测硬件,保存 `hardware_profile.json`,包含 `gpu_name`、`vram_gb` 等 - `executor.py` 第 2352-2391 行:硬件信息 **仅** 用于 CODE_GENERATION 阶段的代码生成 hints - `executor.py` 第 5776-5848 行:PAPER_DRAFT 阶段构建 prompt 时,**没有注入硬件 profile 信息** - LLM 在缺少约束的情况下会「幻觉」出常见的高端硬件名称(如 A100) **涉及文件:** - `researchclaw/pipeline/executor.py`(PAPER_DRAFT 阶段的 prompt 构建部分,约第 5776-5960 行) - `researchclaw/prompts.py`(paper writing prompt 模板) **修复方案:** 1. 在 PAPER_DRAFT 阶段的 prompt 构建中,读取 `stage-01/hardware_profile.json` 2. 将实际硬件信息(GPU 型号、VRAM、CPU 等)作为 **硬性约束** 注入 prompt,例如: ``` HARDWARE CONSTRAINT: The experiments were run on the following hardware: - GPU: {gpu_name} ({vram_gb} GB VRAM) - CPU: {cpu_info} You MUST use this exact hardware specification in the paper. Do NOT substitute with other GPU models. ``` 3. 在 PEER_REVIEW (Stage 18) 的 prompt 中增加一条审核规则:验证 paper 中提到的硬件是否与 `hardware_profile.json` 一致 **修复后预期行为:** 论文中的硬件描述必须与实际运行环境一致。
تحدث مع OpenClaw: «ابحث عن X» → تمّ.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 معرض الأوراق · 📖 دليل التكامل · 💬 مجتمع Discord
---|
|
🏆 معرض الأوراق المُولّدة 8 أوراق في 8 مجالات — الرياضيات، الإحصاء، الأحياء، الحوسبة، NLP، RL، الرؤية الحاسوبية، المتانة — مُولّدة بشكل مستقل تماماً بدون تدخل بشري. |
| 📄 | paper_draft.md | ورقة أكاديمية كاملة (مقدمة، أعمال سابقة، المنهجية، التجارب، النتائج، الخاتمة) |
| 📐 | paper.tex | LaTeX جاهز للمؤتمرات (قوالب NeurIPS / ICLR / ICML) |
| 📚 | references.bib | مراجع BibTeX حقيقية من OpenAlex و Semantic Scholar و arXiv — مُنقّحة تلقائياً لمطابقة الاستشهادات المضمّنة |
| 🔍 | verification_report.json | تحقق من سلامة الاستشهادات على 4 طبقات + التحقق من الصلة (arXiv، CrossRef، DataCite، LLM) |
| 🧪 | experiment runs/ | كود مُولّد + نتائج البيئة المعزولة + مقاييس JSON منظمة |
| 📊 | charts/ | رسوم بيانية مُولّدة تلقائياً لمقارنة الظروف مع أشرطة الخطأ وفترات الثقة |
| 📝 | reviews.md | مراجعة أقران متعددة الوكلاء مع فحص اتساق المنهجية والأدلة |
| 🧬 | evolution/ | دروس تعلّم ذاتي مستخلصة من كل تشغيل |
| 📦 | deliverables/ | جميع المخرجات النهائية في مجلد واحد — جاهزة للترجمة على Overleaf |
بُني بـ 🦞 بواسطة فريق AutoResearchClaw
================================================ FILE: docs/README_CN.md ================================================
直接与 OpenClaw 对话:"研究 X" → 搞定。
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 论文展示 · 📖 集成指南 · 💬 Discord 社区
---|
|
🏆 生成论文展示 8 篇论文覆盖 8 个领域 — 数学、统计、生物、计算、NLP、RL、视觉、鲁棒性 — 完全自主生成,零人工干预。 |
| 📄 | paper_draft.md | 完整学术论文(引言、相关工作、方法、实验、结果、结论) |
| 📐 | paper.tex | 适配顶会模板的 LaTeX 文件(NeurIPS / ICLR / ICML) |
| 📚 | references.bib | 来自 OpenAlex、Semantic Scholar 和 arXiv 的真实 BibTeX 引用——自动精简至与正文引用一致 |
| 🔍 | verification_report.json | 四层引用完整性 + 相关性核查(arXiv、CrossRef、DataCite、LLM) |
| 🧪 | experiment runs/ | 生成的代码 + 沙箱结果 + 结构化 JSON 指标 |
| 📊 | charts/ | 自动生成的条件对比图(含误差线和置信区间) |
| 📝 | reviews.md | 多 Agent 同行评审(含方法论-证据一致性检查) |
| 🧬 | evolution/ | 从每次运行中提取的自学习教训 |
| 📦 | deliverables/ | 所有最终产出集中在一个文件夹——可直接上传 Overleaf 编译 |
Built with 🦞 by the AutoResearchClaw team
================================================ FILE: docs/README_DE.md ================================================
Einfach mit OpenClaw chatten: "Research X" → erledigt.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 Paper-Showcase · 📖 Integrationsanleitung · 💬 Discord-Community
---|
|
🏆 Showcase generierter Paper 8 Paper aus 8 Disziplinen — Mathematik, Statistik, Biologie, Informatik, NLP, RL, Vision, Robustheit — vollständig autonom generiert ohne menschliches Eingreifen. |
| 📄 | paper_draft.md | Vollständiges wissenschaftliches Paper (Einleitung, Verwandte Arbeiten, Methode, Experimente, Ergebnisse, Fazit) |
| 📐 | paper.tex | Konferenzfertiges LaTeX (NeurIPS / ICLR / ICML Templates) |
| 📚 | references.bib | Echte BibTeX-Referenzen von OpenAlex, Semantic Scholar und arXiv — automatisch bereinigt, um Inline-Zitationen zu entsprechen |
| 🔍 | verification_report.json | 4-Schicht-Zitationsintegritäts- und Relevanzprüfung (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | Generierter Code + Sandbox-Ergebnisse + strukturierte JSON-Metriken |
| 📊 | charts/ | Automatisch generierte Vergleichsdiagramme mit Fehlerbalken und Konfidenzintervallen |
| 📝 | reviews.md | Multi-Agenten-Peer-Review mit Methodik-Evidenz-Konsistenzprüfungen |
| 🧬 | evolution/ | Selbstlernende Erkenntnisse aus jedem Durchlauf |
| 📦 | deliverables/ | Alle finalen Ergebnisse in einem Ordner — kompilierbereit für Overleaf |
Gebaut mit 🦞 vom AutoResearchClaw-Team
================================================ FILE: docs/README_ES.md ================================================
Chatea con OpenClaw: "Investiga X" → hecho.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 Galería de artículos · 📖 Guía de integración · 💬 Comunidad Discord
---| 📄 | paper_draft.md | Artículo académico completo (Introducción, Trabajo relacionado, Método, Experimentos, Resultados, Conclusión) |
| 📐 | paper.tex | LaTeX listo para conferencia (plantillas NeurIPS / ICLR / ICML) |
| 📚 | references.bib | Referencias BibTeX reales de OpenAlex, Semantic Scholar y arXiv — auto-depuradas para coincidir con las citas en línea |
| 🔍 | verification_report.json | Verificación de integridad + relevancia de citas en 4 capas (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | Código generado + resultados en sandbox + métricas JSON estructuradas |
| 📊 | charts/ | Gráficos de comparación de condiciones auto-generados con barras de error e intervalos de confianza |
| 📝 | reviews.md | Revisión por pares multi-agente con verificación de consistencia metodología-evidencia |
| 🧬 | evolution/ | Lecciones de auto-aprendizaje extraídas de cada ejecución |
| 📦 | deliverables/ | Todos los entregables finales en una sola carpeta — listos para compilar en Overleaf |
Construido con 🦞 por el equipo de AutoResearchClaw
================================================ FILE: docs/README_FR.md ================================================
Discutez avec OpenClaw : "Recherche X" → terminé.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 Vitrine des articles · 📖 Guide d'intégration · 💬 Communauté Discord
---| 📄 | paper_draft.md | Article académique complet (Introduction, Travaux connexes, Méthode, Expériences, Résultats, Conclusion) |
| 📐 | paper.tex | LaTeX prêt pour les conférences (templates NeurIPS / ICLR / ICML) |
| 📚 | references.bib | Références BibTeX réelles provenant d'OpenAlex, Semantic Scholar et arXiv — auto-élaguées pour correspondre aux citations dans le texte |
| 🔍 | verification_report.json | Vérification d'intégrité et de pertinence des citations sur 4 couches (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | Code généré + résultats sandbox + métriques JSON structurées |
| 📊 | charts/ | Graphiques de comparaison de conditions auto-générés avec barres d'erreur et intervalles de confiance |
| 📝 | reviews.md | Relecture multi-agents avec verification de coherence methodologie-preuves |
| 🧬 | evolution/ | Lecons d'auto-apprentissage extraites de chaque execution |
| 📦 | deliverables/ | Tous les livrables finaux dans un seul dossier — pret a compiler pour Overleaf |
Construit avec 🦞 par l'equipe AutoResearchClaw
================================================ FILE: docs/README_JA.md ================================================
OpenClaw にチャットするだけ:「Xを研究して」→ 完了。
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 論文ショーケース · 📖 統合ガイド · 💬 Discordコミュニティ
---
🏆 生成論文ショーケース 8つの分野にわたる8本の論文 — 数学、統計、生物学、コンピューティング、NLP、RL、ビジョン、ロバスト性 — 人間の介入なしに完全自律生成。 |
| 📄 | paper_draft.md | 完全な学術論文(序論、関連研究、手法、実験、結果、結論) |
| 📐 | paper.tex | 学会対応LaTeX(NeurIPS / ICLR / ICMLテンプレート) |
| 📚 | references.bib | OpenAlex、Semantic Scholar、arXivからの実際のBibTeX参考文献 — 本文中の引用に合わせて自動整理 |
| 🔍 | verification_report.json | 4層の引用整合性 + 関連性検証(arXiv、CrossRef、DataCite、LLM) |
| 🧪 | experiment runs/ | 生成されたコード + サンドボックス実行結果 + 構造化JSONメトリクス |
| 📊 | charts/ | 誤差棒と信頼区間付きの条件比較チャートを自動生成 |
| 📝 | reviews.md | 手法-証拠の一貫性チェック付きマルチエージェント査読 |
| 🧬 | evolution/ | 各実行から抽出された自己学習の教訓 |
| 📦 | deliverables/ | すべての最終成果物を1フォルダに集約 — Overleafですぐにコンパイル可能 |
Built with 🦞 by the AutoResearchClaw team
================================================ FILE: docs/README_KO.md ================================================
OpenClaw에 채팅하세요: "X 연구해줘" → 완료.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 논문 쇼케이스 · 📖 통합 가이드 · 💬 Discord 커뮤니티
---
🏆 생성된 논문 쇼케이스 8개 분야에 걸친 8편의 논문 — 수학, 통계, 생물학, 컴퓨팅, NLP, RL, 비전, 견고성 — 인간 개입 없이 완전 자율 생성. |
| 📄 | paper_draft.md | 완성된 학술 논문 (서론, 관련 연구, 방법론, 실험, 결과, 결론) |
| 📐 | paper.tex | 학회 제출용 LaTeX (NeurIPS / ICLR / ICML 템플릿) |
| 📚 | references.bib | OpenAlex, Semantic Scholar, arXiv에서 가져온 실제 BibTeX 참고문헌 — 인라인 인용과 일치하도록 자동 정리 |
| 🔍 | verification_report.json | 4계층 인용 무결성 + 관련성 검증 (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | 생성된 코드 + 샌드박스 결과 + 구조화된 JSON 메트릭 |
| 📊 | charts/ | 오차 막대와 신뢰 구간이 포함된 자동 생성 조건 비교 차트 |
| 📝 | reviews.md | 방법론-증거 일관성 검사를 포함한 멀티 에이전트 피어 리뷰 |
| 🧬 | evolution/ | 각 실행에서 추출된 자기 학습 교훈 |
| 📦 | deliverables/ | 모든 최종 산출물을 하나의 폴더에 — Overleaf에 바로 컴파일 가능 |
Built with 🦞 by the AutoResearchClaw team
================================================ FILE: docs/README_PT.md ================================================
Converse com o OpenClaw: "Pesquise X" → pronto.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 Galeria de Artigos · 📖 Guia de Integração · 💬 Comunidade Discord
---
🏆 Galeria de Artigos Gerados 8 artigos em 8 domínios — matemática, estatística, biologia, computação, NLP, RL, visão, robustez — gerados de forma totalmente autônoma sem intervenção humana. |
| 📄 | paper_draft.md | Artigo acadêmico completo (Introdução, Trabalhos Relacionados, Método, Experimentos, Resultados, Conclusão) |
| 📐 | paper.tex | LaTeX pronto para conferência (templates NeurIPS / ICLR / ICML) |
| 📚 | references.bib | Referências BibTeX reais do OpenAlex, Semantic Scholar e arXiv — auto-podadas para corresponder às citações inline |
| 🔍 | verification_report.json | Verificação de integridade + relevância de citações em 4 camadas (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | Código gerado + resultados do sandbox + métricas JSON estruturadas |
| 📊 | charts/ | Gráficos de comparação de condições gerados automaticamente com barras de erro e intervalos de confiança |
| 📝 | reviews.md | Revisão por pares multi-agente com verificações de consistência metodologia-evidência |
| 🧬 | evolution/ | Lições de autoaprendizagem extraídas de cada execução |
| 📦 | deliverables/ | Todas as saídas finais em uma pasta — pronto para compilar no Overleaf |
Construído com 🦞 pela equipe AutoResearchClaw
================================================ FILE: docs/README_RU.md ================================================
Просто напишите OpenClaw: «Исследуй X» → готово.
🇺🇸 English · 🇨🇳 中文 · 🇯🇵 日本語 · 🇰🇷 한국어 · 🇫🇷 Français · 🇩🇪 Deutsch · 🇪🇸 Español · 🇧🇷 Português · 🇷🇺 Русский · 🇸🇦 العربية
🏆 Галерея статей · 📖 Руководство по интеграции · 💬 Сообщество в Discord
---
| 📄 | paper_draft.md | Полная академическая статья (Введение, Обзор литературы, Метод, Эксперименты, Результаты, Заключение) |
| 📐 | paper.tex | Готовый LaTeX-код (шаблоны NeurIPS / ICLR / ICML) |
| 📚 | references.bib | Реальные BibTeX-ссылки из OpenAlex, Semantic Scholar и arXiv — автоматически отфильтрованные под цитаты в тексте |
| 🔍 | verification_report.json | 4-уровневая проверка целостности и релевантности цитирования (arXiv, CrossRef, DataCite, LLM) |
| 🧪 | experiment runs/ | Сгенерированный код + результаты из песочницы + структурированные JSON-метрики |
| 📊 | charts/ | Автоматически сгенерированные графики сравнения с планками погрешностей и доверительными интервалами |
| 📝 | reviews.md | Мультиагентное рецензирование с проверкой согласованности методологии и результатов |
| 🧬 | evolution/ | Уроки для самообучения, извлеченные из каждого запуска |
| 📦 | deliverables/ | Все итоговые материалы в одной папке — готовы к загрузке в Overleaf |
Создано с 🦞 командой AutoResearchClaw
================================================ FILE: docs/TESTER_GUIDE.md ================================================
Help us stress-test the world's first fully autonomous research pipeline — across every domain.
⭐ Star the Repo · 🚀 Quick Start · 📋 Feedback Template · 🇨🇳 中文测试指南 · 🇯🇵 日本語テストガイド
--- ## 👋 Welcome, Tester! **AutoResearchClaw** is a fully autonomous academic paper generation pipeline. You give it a research idea — it handles everything else: literature search, experiment design, code generation, experiment execution, paper writing, peer review, and final delivery. **23 stages, zero human intervention.** We're looking for testers from **all disciplines and backgrounds** — machine learning, NLP, computer vision, reinforcement learning, bioinformatics, physics, social sciences, and beyond. The more diverse the testing, the better the pipeline becomes. **Your mission:** Run the pipeline with your own research idea, inspect the output, and submit a detailed feedback report. That's it. Every piece of feedback directly shapes the next version. --- ## 📋 Table of Contents 1. [Prerequisites](#-prerequisites) 2. [Installation & Setup](#-installation--setup) 3. [Running the Pipeline](#-running-the-pipeline) 4. [Inspecting the Output](#-inspecting-the-output) 5. [Feedback Report Requirements](#-feedback-report-requirements) 6. [Feedback Template](#-feedback-template) 7. [FAQ](#-faq) --- ## 📦 Prerequisites | Item | Minimum | Recommended | |------|---------|-------------| | OS | macOS / Linux / WSL2 | Linux (Ubuntu 22.04+) | | Python | 3.11+ | 3.11 or 3.12 | | Disk | 500 MB | 2 GB+ | | RAM | 8 GB | 16 GB+ | | GPU | Not required (sandbox mode) | NVIDIA GPU + CUDA 12.x (docker mode) | | Network | Required (LLM API + literature search) | Stable connection | | LLM API Key | **Required** | OpenAI or Anthropic | ### 🔑 About API Keys The pipeline calls a large language model (LLM) at every stage — writing, coding, reviewing, and more. You'll need an API key from **OpenAI** or **Anthropic**. 
> **We strongly recommend using the most capable models available for the best results:** > > | Provider | Recommended Model | Fallback | > |----------|------------------|----------| > | **OpenAI** | **GPT-5.4** (best) | GPT-5.1 or GPT-4.1 | > | **Anthropic** | **Claude Opus 4.6** (best) | Claude Sonnet 4.6 | > > Using a top-tier model significantly improves paper quality, code correctness, and experiment design. Older models (e.g., GPT-4o) may produce noticeably weaker output. --- ## 🛠 Installation & Setup ### ⚠️ Always Use the Latest Version > **This project is under active development.** The codebase is updated frequently, and different versions can produce significantly different results. > > **Before every test run, always pull the latest code:** > > ```bash > cd AutoResearchClaw > git pull origin main > pip install -e . # Re-install to pick up changes > ``` > > Record your version for the feedback report: > ```bash > git log --oneline -1 > ``` --- ### Option A: Claude Code (Fastest — Recommended ⚡) If you have [Claude Code](https://claude.ai/claude-code) (Anthropic's CLI tool), just paste this: ``` Please clone and install AutoResearchClaw: https://github.com/aiming-lab/AutoResearchClaw.git If already cloned, run git pull origin main to update to the latest version first. Then create a config file with: - LLM: OpenAI with gpt-5.4 (or Anthropic Claude Opus 4.6) - Experiment mode: sandbox (local execution) - Research topic: "⭐ If you find this project interesting, please give us a star on GitHub!
================================================ FILE: docs/TESTER_GUIDE_CN.md ================================================
欢迎来自各个领域的你,一起测试全球首个全自动学术论文生成 Pipeline。
⭐ Star 项目 · 🚀 快速开始 · 📋 反馈模板 · 🇬🇧 English · 🇯🇵 日本語テストガイド
--- ## 👋 你好,测试者! **AutoResearchClaw** 是一个全自动学术论文生成 Pipeline。你只需提供一个研究 idea,系统就会自动完成文献检索、实验设计、代码生成、实验执行、论文撰写、同行评审到最终交付的全部 **23 个阶段**——无需任何人工干预。 我们正在寻找来自**各个学科和领域**的测试者——机器学习、NLP、计算机视觉、强化学习、生物信息学、物理学、社会科学……领域越多样,Pipeline 就能变得越好。 **你的任务:** 用你自己的研究 idea 运行一次完整的 Pipeline,检查输出质量,然后向我们提交一份详细的反馈报告。就这么简单——你的每一条反馈都会直接推动下一个版本的改进。 --- ## 📋 目录 1. [环境要求](#-环境要求) 2. [安装与配置](#-安装与配置) 3. [运行测试](#-运行测试) 4. [查看交付结果](#-查看交付结果) 5. [反馈报告要求](#-反馈报告要求) 6. [反馈报告模板](#-反馈报告模板) 7. [常见问题](#-常见问题) --- ## 📦 环境要求 | 项目 | 最低要求 | 推荐配置 | |------|---------|---------| | 操作系统 | macOS / Linux / WSL2 | Linux (Ubuntu 22.04+) | | Python | 3.11+ | 3.11 或 3.12 | | 磁盘空间 | 500 MB | 2 GB+ | | 内存 | 8 GB | 16 GB+ | | GPU | 非必须(sandbox 模式) | NVIDIA GPU + CUDA 12.x(docker 模式) | | 网络 | 需要(调用 LLM API + 文献检索) | 稳定的网络连接 | | LLM API Key | **必须** | OpenAI 或 Anthropic | ### 🔑 关于 API Key Pipeline 在每个阶段都会调用大语言模型(LLM)来完成写作、编码、评审等任务。你需要准备一个 **OpenAI** 或 **Anthropic** 的 API Key。 > **强烈建议使用最新、最强的模型以获得最佳效果:** > > | 提供商 | 推荐模型 | 备选 | > |--------|---------|------| > | **OpenAI** | **GPT-5.4**(首选) | GPT-5.1 或 GPT-4.1 | > | **Anthropic** | **Claude Opus 4.6**(首选) | Claude Sonnet 4.6 | > > 使用顶级模型会显著提升论文写作质量、代码生成准确性和实验设计合理性。较低版本的模型(如 gpt-4o)可能导致输出质量明显下降。 --- ## 🛠 安装与配置 ### ⚠️ 请务必使用最新版本 > **本项目处于快速迭代阶段,** 代码更新频繁,不同版本之间的生成效果可能存在较大差异。 > > **每次测试前,请务必拉取最新代码:** > > ```bash > cd AutoResearchClaw > git pull origin main > pip install -e . # 重新安装以确保更新生效 > ``` > > 记录你的版本号,方便填写反馈报告: > ```bash > git log --oneline -1 > ``` --- ### 方式 A:使用 Claude Code(最快 ⚡ 推荐) 如果你正在使用 [Claude Code](https://claude.ai/claude-code)(Anthropic 的 CLI 工具),直接粘贴以下内容即可: ``` 请帮我克隆并安装 AutoResearchClaw 项目: https://github.com/aiming-lab/AutoResearchClaw.git 如果已经克隆过,请先 git pull origin main 更新到最新版本。 安装完成后,帮我创建一个配置文件,使用以下参数: - LLM: OpenAI,模型选择 gpt-5.4(或 Anthropic Claude Opus 4.6) - 实验模式: sandbox(本地沙盒执行) - 研究主题: "<在这里填入你的研究 idea>" - 自动审批所有 gate stage 我的 API Key 是: sk-xxxx(请设为环境变量,不要写在配置文件里) ``` Claude Code 会自动完成克隆、安装依赖、创建配置文件、运行 Pipeline 的全部步骤。 ### 方式 B:手动安装 ```bash # 1. 
克隆项目 git clone https://github.com/aiming-lab/AutoResearchClaw.git cd AutoResearchClaw # ⚠️ 如果已经克隆过,务必先更新! # git pull origin main # 2. 创建 Python 虚拟环境 python3 -m venv .venv source .venv/bin/activate # macOS / Linux # .venv\Scripts\activate # Windows(推荐使用 WSL2) # 3. 安装项目 pip install -e . # 4. 验证安装成功 researchclaw --help ``` ### ⚙️ 配置文件 ```bash cp config.researchclaw.example.yaml config.yaml ``` 编辑 `config.yaml`,修改以下关键字段: ```yaml # === 项目设置 === project: name: "my-test" mode: "full-auto" # === 研究主题——用英文描述你的 idea === research: topic: "你的研究 idea,用英文描述,一两句话即可" domains: - "machine-learning" # 可选: nlp, cv, rl, graph-learning, etc. # === LLM 配置——请使用最强模型! === # # 方案一:OpenAI(推荐 GPT-5.4) llm: provider: "openai-compatible" base_url: "https://api.openai.com/v1" api_key_env: "OPENAI_API_KEY" primary_model: "gpt-5.4" # 首选最强模型 fallback_models: - "gpt-5.1" - "gpt-4.1" # 方案二:Anthropic Claude(推荐 Claude Opus 4.6) # llm: # provider: "openai-compatible" # base_url: "https://api.anthropic.com/v1" # api_key_env: "ANTHROPIC_API_KEY" # primary_model: "claude-opus-4-6" # fallback_models: # - "claude-sonnet-4-6" # === 实验模式 === experiment: mode: "sandbox" # sandbox = 本地执行(推荐) time_budget_sec: 600 # 每次实验最长运行时间(秒) max_iterations: 10 metric_key: "primary_metric" metric_direction: "minimize" # 或 "maximize" ``` ### 🔐 设置 API Key ```bash # OpenAI 用户: export OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxx" # Anthropic 用户: export ANTHROPIC_API_KEY="sk-ant-xxxxxxxxxxxxxxxxxxxxxxxx" # 可选:Semantic Scholar API Key(可加快文献检索) export S2_API_KEY="your-s2-key" ``` > **🔒 安全提醒:** 请勿将 API Key 硬编码在任何文件中。使用 `api_key_env` 指定环境变量名即可。 --- ## 🚀 运行测试 ### 快速开始 ```bash source .venv/bin/activate export OPENAI_API_KEY="sk-xxxx" # 或 ANTHROPIC_API_KEY researchclaw run --config config.yaml --auto-approve ``` ### 指定研究主题运行 ```bash researchclaw run \ --config config.yaml \ --topic "Investigating the effect of curriculum learning on image classification with adaptive difficulty scheduling" \ --auto-approve ``` ### ⏱ 预估运行时间 | 实验模式 | 预估时间 | 
说明 | |---------|---------|------| | sandbox | 30 分钟 ~ 2 小时 | 取决于实验复杂度和 API 响应速度 | | docker (GPU) | 1 ~ 4 小时 | 可运行更复杂的深度学习实验 | 运行过程中终端会实时显示当前阶段和进度。**无需任何手动操作**,安心等待即可。 ### ✅ 如何知道运行结束 当看到类似以下输出时,表示 Pipeline 已成功完成: ``` [Stage 23/23] ✓ Deliverables packaged Pipeline complete — deliverables at: artifacts/rc-20260315-XXXXXX-YYYY/deliverables/ ``` ### 🔄 如果运行中断 Pipeline 支持断点续跑: ```bash researchclaw run --config config.yaml --resume ``` --- ## 🔍 查看交付结果 运行结束后,输出文件位于 `artifacts/rc-YYYYMMDD-HHMMSS-⭐ 如果你觉得这个项目有趣,请在 GitHub 上给我们一颗 Star!
================================================ FILE: docs/TESTER_GUIDE_JA.md ================================================
世界初の完全自律型研究パイプラインを、あらゆる分野でストレステストするためにご協力ください。
⭐ リポジトリにスターを付ける · 🚀 クイックスタート · 📋 フィードバックテンプレート · 🇺🇸 English Testing Guide · 🇨🇳 中文测试指南
--- ## 👋 テスターの皆さんへ **AutoResearchClaw** は、完全自律型の学術論文生成パイプラインです。研究アイデアを入力するだけで、文献検索、実験設計、コード生成、実験実行、論文執筆、査読、最終成果物の作成まで、すべてを自動で処理します。**23ステージ、人手介入ゼロ。** **あらゆる分野・バックグラウンド**のテスターを募集しています — 機械学習、NLP、コンピュータビジョン、強化学習、バイオインフォマティクス、物理学、社会科学など。テストが多様であるほど、パイプラインの改善に繋がります。 **あなたのミッション:** 自分の研究アイデアでパイプラインを実行し、出力を検査して、詳細なフィードバックレポートを提出してください。それだけです。すべてのフィードバックが次のバージョンに直接反映されます。 --- ## 📋 目次 1. [前提条件](#-前提条件) 2. [インストールとセットアップ](#-インストールとセットアップ) 3. [パイプラインの実行](#-パイプラインの実行) 4. [出力の確認](#-出力の確認) 5. [フィードバックレポートの要件](#-フィードバックレポートの要件) 6. [フィードバックテンプレート](#-フィードバックテンプレート) 7. [FAQ](#-faq) --- ## 📦 前提条件 | 項目 | 最小要件 | 推奨 | |------|---------|------| | OS | macOS / Linux / WSL2 | Linux (Ubuntu 22.04+) | | Python | 3.11+ | 3.11 または 3.12 | | ディスク | 500 MB | 2 GB+ | | RAM | 8 GB | 16 GB+ | | GPU | 不要(sandboxモード) | NVIDIA GPU + CUDA 12.x(dockerモード) | | ネットワーク | 必要(LLM API + 文献検索) | 安定した接続 | | LLM APIキー | **必須** | OpenAI または Anthropic | ### 🔑 APIキーについて パイプラインは、執筆、コーディング、レビューなど、すべてのステージで大規模言語モデル(LLM)を呼び出します。**OpenAI** または **Anthropic** のAPIキーが必要です。 > **最良の結果を得るために、利用可能な最も高性能なモデルの使用を強く推奨します:** > > | プロバイダー | 推奨モデル | フォールバック | > |-------------|-----------|--------------| > | **OpenAI** | **GPT-5.4**(最良) | GPT-5.1 または GPT-4.1 | > | **Anthropic** | **Claude Opus 4.6**(最良) | Claude Sonnet 4.6 | > > トップティアのモデルを使用することで、論文の品質、コードの正確性、実験設計が大幅に向上します。古いモデル(例:GPT-4o)では、出力品質が著しく低下する可能性があります。 --- ## 🛠 インストールとセットアップ ### ⚠️ 常に最新バージョンを使用してください > **このプロジェクトは活発に開発中です。** コードベースは頻繁に更新され、バージョンによって結果が大きく異なる場合があります。 > > **テスト実行の前に、必ず最新のコードをプルしてください:** > > ```bash > cd AutoResearchClaw > git pull origin main > pip install -e . # 変更を反映するために再インストール > ``` > > フィードバックレポート用にバージョンを記録してください: > ```bash > git log --oneline -1 > ``` --- ### オプションA:Claude Code(最速 — 推奨 ⚡) [Claude Code](https://claude.ai/claude-code)(AnthropicのCLIツール)をお持ちの場合、以下を貼り付けるだけです: ``` Please clone and install AutoResearchClaw: https://github.com/aiming-lab/AutoResearchClaw.git If already cloned, run git pull origin main to update to the latest version first. 
Then create a config file with: - LLM: OpenAI with gpt-5.4 (or Anthropic Claude Opus 4.6) - Experiment mode: sandbox (local execution) - Research topic: "<ここに研究アイデアを入力>" - Auto-approve all gate stages My API key is: sk-xxxx (set it as an environment variable, don't hardcode it) ``` Claude Codeがクローン、依存関係、設定、実行をすべて自動で処理します。 ### オプションB:手動インストール ```bash # 1. リポジトリをクローン git clone https://github.com/aiming-lab/AutoResearchClaw.git cd AutoResearchClaw # 2. 仮想環境を作成 python3 -m venv .venv source .venv/bin/activate # macOS / Linux # .venv\Scripts\activate # Windows(WSL2推奨) # 3. インストール pip install -e . # 4. 動作確認 researchclaw --help ``` ### ⚙️ 設定 ```bash cp config.researchclaw.example.yaml config.yaml ``` `config.yaml` を編集してください — 主要なフィールドは以下の通りです: ```yaml # === プロジェクト === project: name: "my-test" mode: "full-auto" # === 研究トピック — アイデアを英語で記述してください === research: topic: "Your research idea in 1-2 sentences" domains: - "machine-learning" # 選択肢: nlp, cv, rl, graph-learning など # === LLM — 利用可能な最も高性能なモデルを使用してください! 
=== # # オプション1: OpenAI(GPT-5.4推奨) llm: provider: "openai-compatible" base_url: "https://api.openai.com/v1" api_key_env: "OPENAI_API_KEY" primary_model: "gpt-5.4" # 最良のモデル fallback_models: - "gpt-5.1" - "gpt-4.1" # オプション2: Anthropic Claude(Claude Opus 4.6推奨) # llm: # provider: "openai-compatible" # base_url: "https://api.anthropic.com/v1" # api_key_env: "ANTHROPIC_API_KEY" # primary_model: "claude-opus-4-6" # fallback_models: # - "claude-sonnet-4-6" # === 実験 === experiment: mode: "sandbox" # sandbox = ローカル実行(推奨) time_budget_sec: 600 # 実験実行あたりの最大秒数 max_iterations: 10 metric_key: "primary_metric" metric_direction: "minimize" # または "maximize" ``` ### 🔐 APIキーの設定 ```bash # OpenAIユーザー: export OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxx" # Anthropicユーザー: export ANTHROPIC_API_KEY="sk-ant-xxxxxxxxxxxxxxxxxxxxxxxx" # オプション:Semantic Scholar APIキー(文献検索を高速化) export S2_API_KEY="your-s2-key" ``` > **🔒 セキュリティ:** APIキーをファイルにハードコードしないでください。設定ファイルの `api_key_env` を使用して環境変数を参照してください。 --- ## 🚀 パイプラインの実行 ### クイックスタート ```bash source .venv/bin/activate export OPENAI_API_KEY="sk-xxxx" # または ANTHROPIC_API_KEY researchclaw run --config config.yaml --auto-approve ``` ### 特定のトピックを指定する場合 ```bash researchclaw run \ --config config.yaml \ --topic "Investigating the effect of curriculum learning on image classification with adaptive difficulty scheduling" \ --auto-approve ``` ### ⏱ 想定実行時間 | モード | 推定時間 | 備考 | |--------|---------|------| | sandbox | 30分 〜 2時間 | 実験の複雑さとAPIの速度に依存 | | docker (GPU) | 1 〜 4時間 | より大規模なディープラーニング実験向け | ターミナルにリアルタイムで進捗が表示されます。**手動介入は不要です** — あとは実行完了を待つだけです。 ### ✅ 完了の確認方法 以下のような出力が表示されます: ``` [Stage 23/23] ✓ Deliverables packaged Pipeline complete — deliverables at: artifacts/rc-20260315-XXXXXX-YYYY/deliverables/ ``` ### 🔄 中断された場合 パイプラインはチェックポイントをサポートしています — 再開するだけです: ```bash researchclaw run --config config.yaml --resume ``` --- ## 🔍 出力の確認 完了後、結果は `artifacts/rc-YYYYMMDD-HHMMSS-⭐ このプロジェクトに興味を持たれたら、GitHubでスターをお願いします!
================================================ FILE: docs/agent_figure_and_benchmark_plan.md ================================================ # Multi-Agent Figure Generation & Benchmark Selection — Task Requirements > **Created**: 2026-03-15 > **Updated**: 2026-03-15 > **Status**: BenchmarkAgent IMPLEMENTED, FigureAgent IMPLEMENTED > **Scope**: Two new multi-agent subsystems for AutoResearchClaw pipeline > > **Implementation Progress**: > - [x] Part B: BenchmarkAgent — fully implemented (4 agents + orchestrator + config + pipeline integration + 43 tests) > - [x] Part A: FigureAgent — fully implemented (5 agents + orchestrator + config + pipeline integration + 45 tests) > > **Key Research Findings (supplemental)**: > - Papers With Code was shut down by Meta in July 2025; HuggingFace Hub API is now the primary dataset discovery source > - AI Scientist v2 and MLR-Copilot both use pure LLM-driven dataset selection (no API search) — our API-based approach is more structured > - MLE-bench (OpenAI) validates the pre-download + container-mount pattern (matches our `setup_only` network policy) > - CodeSOTA (codesota.com) provides a lighter-weight benchmark database as an alternative to Papers With Code --- ## Executive Summary 当前 Pipeline 的图表生成和数据集/基准选择存在根本性缺陷: **图表问题**(实测产出): - 每次固定只生成 2 张图(`method_comparison.png` + `experiment_comparison.png`) - 图表类型单一:只有柱状图,无训练曲线、热力图、消融分析图等 - 数据无差异化:所有方法都显示 1.000,完全无信息量 - 样式简陋:默认 matplotlib 风格,远低于 AI 顶会标准 - 不适应实验内容:无论做什么研究都画一样的图 - DPI=150,不满足出版要求(300+ DPI) **数据集/基准问题**: - 当前仅通过 `dataset_guidance` 提示词列出预缓存数据集 - 无法根据研究领域动态搜索和选择最合适的 benchmark - 无法自动下载非预缓存数据集 - 缺乏 baseline 方法的自动复现能力 **解决方案**:设计两个独立的多 Agent 子系统: 1. **FigureAgent** — 智能图表生成系统(6 个子 Agent 协作) 2. 
**BenchmarkAgent** — 数据集与基准选择系统(4 个子 Agent 协作) --- ## Part A: FigureAgent — 多 Agent 图表生成系统 ### A.1 问题分析 #### 当前架构缺陷 ``` 现状:Stage 14 → visualize.py (5 个硬编码函数) → 固定 2 张图 → Stage 17/22 嵌入论文 ``` | 问题 | 严重程度 | 说明 | |------|---------|------| | 图表类型固定 | Critical | 只有 bar chart 和 line chart,缺少 heatmap、scatter、violin、architecture diagram 等 | | 不适应实验内容 | Critical | 知识蒸馏实验和 RL 实验画的图完全一样 | | 无智能决策 | Critical | 不分析"应该画什么",直接调用固定函数 | | 数据正确性无验证 | High | 不验证图中数据是否与实验结果一致 | | 样式不达标 | High | 默认 matplotlib,不符合学术论文视觉标准 | | 无架构图能力 | High | 不能生成方法流程图 / 模型架构图(顶会 Figure 1 必备) | | DPI 不足 | Medium | 150 DPI,出版要求 300+ | | 无 VLM 审查 | Medium | 生成后不检查质量,直接用 | #### 业界参考方案 | 项目 | 图表策略 | 核心创新 | |------|---------|---------| | AI Scientist v1 (Sakana) | 人工编写 `plot.py` 模板,LLM 不参与 | 可靠但不灵活 | | AI Scientist v2 (Sakana) | LLM 自主生成画图代码 + VLM 审查反馈循环 | **VLM-as-critic**,首篇通过 ICLR workshop 审稿 | | PlotGen (Adobe) | 三模态反馈:数值准确性 + 文本正确性 + 视觉质量 | **Tri-modal feedback**,MatPlotBench 最优 | | PaperBanana (Google) | 3 阶段 pipeline:Caption 精炼 → 参考检索 → 迭代渲染 | **Caption sharpening** + 参考图库 | ### A.2 目标架构 ``` ┌─────────────────────┐ │ FigureAgent │ │ (Orchestrator) │ └──────────┬──────────┘ │ ┌──────────┬───────────┼───────────┬──────────┐ ▼ ▼ ▼ ▼ ▼ ┌──────────┐┌──────────┐┌──────────┐┌──────────┐┌──────────┐ │ Planner ││ CodeGen ││ Renderer ││ Critic ││ Integra- │ │ Agent ││ Agent ││ Agent ││ Agent ││ tor Agent│ └──────────┘└──────────┘└──────────┘└──────────┘└──────────┘ │ │ │ │ │ ▼ ▼ ▼ ▼ ▼ 图表规划 代码生成 执行渲染 质量审查 论文嵌入 ``` #### Agent 职责定义 **1. Orchestrator(编排器)** - 接收:实验结果 JSON、论文草稿 markdown、研究主题描述 - 协调所有子 Agent 的执行顺序 - 管理迭代循环(Critic 不满意时回到 CodeGen) - 输出:最终图表集合 + 嵌入指令 **2. 
Planner Agent(图表规划)** - 输入:实验结果数据结构、论文 idea、研究领域 - 职责: - 分析实验数据,确定需要哪些图、每张图展示什么 - 为每张图生成精确的 caption specification(非模糊描述) - 确定图表类型(bar / line / heatmap / scatter / architecture / ablation 等) - 确定布局(single / subplot / multi-panel) - 输出图表规划清单(JSON 格式) - 关键规则: - 至少规划 4 张图:1 架构图 + 1 主结果图 + 1 消融图 + 1 分析图 - 根据研究领域自动选择合适的图表类型 - Caption sharpening:将模糊描述转化为精确视觉规范 **3. CodeGen Agent(代码生成)** - 输入:Planner 输出的图表规划 + 实验数据 - 职责: - 为每张图生成独立的 Python 绘图脚本 - 使用 SciencePlots 学术样式 (`plt.style.use(['science', 'ieee'])`) - 确保 colorblind-safe 配色 - 300+ DPI 输出 - 代码保存到 `charts/scripts/` 供复现 - 代码模板库: - 内置常用学术图表模板(training curve, bar comparison, heatmap, confusion matrix 等) - 新图表可基于模板扩展 **4. Renderer Agent(渲染执行)** - 输入:CodeGen 生成的 Python 脚本 - 职责: - 在 Docker sandbox 中执行绘图脚本 - 捕获执行错误并反馈给 CodeGen 修复 - 验证输出文件存在且可读 - 检查图像尺寸和分辨率 **5. Critic Agent(质量审查 — 三模态反馈)** - 输入:渲染后的图像 + 源数据 + caption 规范 - 职责(三维度审查,参考 PlotGen): - **数值准确性**:验证图中呈现的数值与源数据一致(读取 JSON → 对比图中数据点) - **文本正确性**:检查标题、坐标轴标签、图例是否准确完整 - **视觉质量**:通过 VLM(如 GPT-4o vision)审查整体美观度、可读性、学术规范 - 输出:pass / fail + 具体修改建议 - 如果 fail:将反馈发回 CodeGen Agent,最多迭代 3 次 **6. 
Integrator Agent(论文嵌入)** - 输入:通过审查的图表集合 + 论文草稿 - 职责: - 确定每张图在论文中的最佳位置 - 生成 LaTeX figure 环境代码(支持 subfigure 多面板) - 生成交叉引用(`\ref{fig:xxx}`) - 确保图表在正确的 section(架构图在 Method,结果图在 Results) - 更新论文文本中的图表引用语句 ### A.3 图表类型矩阵 根据研究领域和实验类型,Planner Agent 应遵循以下决策矩阵: | 实验类型 | 必须包含的图表 | 可选图表 | |---------|--------------|---------| | **分类任务** | 精度对比 bar chart、confusion matrix | ROC 曲线、t-SNE 可视化 | | **生成模型** | 生成样本 grid、FID/IS 曲线 | 插值可视化、attention map | | **强化学习** | reward curve (mean±std shading)、episode length | 策略可视化、环境截图 | | **知识蒸馏** | teacher-student 精度对比、知识迁移效率曲线 | 特征对齐热力图 | | **NLP** | BLEU/ROUGE 对比表、attention heatmap | 样本输出对比 | | **图神经网络** | 节点分类精度、图可视化 | 消息传递可视化 | | **元学习** | few-shot 精度 vs shot 数曲线 | 任务适应速度 | | **持续学习** | 遗忘率曲线、任务精度矩阵 | 表征漂移可视化 | | **所有类型** | 消融分析 (grouped bar)、训练 loss 曲线 | 超参敏感性热力图 | ### A.4 样式规范 所有图表必须遵循以下学术出版标准: ```python # 全局样式配置 (charts/style_config.py) STYLE_CONFIG = { "matplotlib_style": ["science", "ieee"], # SciencePlots "dpi": 300, # 出版级 "font_size": {"title": 12, "axis": 10, "tick": 8, "legend": 9}, "figure_width": { "single_column": 3.5, # IEEE single column (inches) "double_column": 7.0, # IEEE double column "full_page": 7.0, # Full width }, "colors": "bright", # colorblind-safe (Paul Tol) "line_styles": ["-", "--", "-.", ":"], # 配合 B&W 打印 "marker_styles": ["o", "s", "^", "D", "v", "P"], "error_bar_style": "shading", # mean ± std 用阴影而非 error bar "format": "pdf", # 矢量格式优先 "fallback_format": "png", # PNG 备用 } ``` ### A.5 实现计划 #### 文件结构 ``` researchclaw/ ├── agents/ │ └── figure_agent/ │ ├── __init__.py │ ├── orchestrator.py # FigureAgent 主编排器 │ ├── planner.py # Planner Agent │ ├── codegen.py # CodeGen Agent │ ├── renderer.py # Renderer Agent │ ├── critic.py # Critic Agent (三模态审查) │ ├── integrator.py # Integrator Agent │ ├── templates/ # 图表代码模板库 │ │ ├── bar_comparison.py │ │ ├── training_curve.py │ │ ├── heatmap.py │ │ ├── confusion_matrix.py │ │ ├── scatter_plot.py │ │ ├── ablation_grouped.py │ │ ├── violin_box.py │ │ └── multi_panel.py │ └── 
style_config.py # 全局样式配置 ``` #### 开发任务清单 | ID | 任务 | 依赖 | 估计改动量 | |----|------|------|-----------| | FA-01 | 创建 `agents/figure_agent/` 目录结构和基础类 | 无 | 新建 | | FA-02 | 实现 Planner Agent:图表规划逻辑 + 类型决策矩阵 | FA-01 | ~300 行 | | FA-03 | 实现 CodeGen Agent:代码生成 + 模板库 | FA-01 | ~500 行 | | FA-04 | 实现 Renderer Agent:sandbox 执行 + 错误处理 | FA-01, FA-03 | ~200 行 | | FA-05 | 实现 Critic Agent:三模态审查(数值 / 文本 / VLM) | FA-01, FA-04 | ~400 行 | | FA-06 | 实现 Integrator Agent:论文嵌入 + LaTeX subfigure 支持 | FA-01 | ~250 行 | | FA-07 | 实现 Orchestrator:编排循环 + 最大迭代控制 | FA-02 ~ FA-06 | ~300 行 | | FA-08 | 添加 SciencePlots 到 Docker 镜像 + 样式配置 | 无 | ~50 行 | | FA-09 | 修改 executor.py:Stage 14 调用 FigureAgent 替代 `visualize.py` | FA-07 | ~100 行 | | FA-10 | 修改 executor.py:Stage 17/22 使用 Integrator 输出 | FA-07 | ~100 行 | | FA-11 | 修改 converter.py:支持 subfigure 和 PDF 格式 | FA-06 | ~80 行 | | FA-12 | 添加图表代码模板库(8+ 模板) | FA-03 | ~600 行 | | FA-13 | 测试:单元测试 + 集成测试 | FA-01 ~ FA-12 | ~400 行 | | FA-14 | 向后兼容:保留 `visualize.py` 作为 fallback | FA-09 | ~30 行 | #### Pipeline 集成点 ``` Stage 12-13: 实验执行完成,生成 results.json │ ▼ Stage 14: Result Analysis │── 调用 FigureAgent.orchestrate() │ ├── Planner: 分析 results.json → 图表规划 │ ├── CodeGen: 生成绘图脚本 → charts/scripts/ │ ├── Renderer: 执行脚本 → charts/*.pdf + charts/*.png │ ├── Critic: 审查图表质量 (max 3 iterations) │ └── 输出: charts/ 目录 + figure_manifest.json │ ▼ Stage 17: Paper Draft │── Integrator: 读取 figure_manifest.json │ ├── 确定每张图的论文位置 │ ├── 注入 markdown 图片引用 + caption │ └── 更新交叉引用文本 │ ▼ Stage 22: Paper Export │── 复制 charts/ 到 submission/ │── converter.py 处理 subfigure 环境 └── 最终 LaTeX 编译验证 ``` --- ## Part B: BenchmarkAgent — 多 Agent 数据集与基准选择系统 ### B.1 问题分析 #### 当前架构缺陷 ``` 现状:dataset_guidance 提示词 (硬编码列表) + dataset_registry.yaml (静态清单) → LLM 自行选择 ``` | 问题 | 严重程度 | 说明 | |------|---------|------| | 数据集选择不智能 | Critical | 仅列出预缓存数据集,LLM 可能选择不合适的 benchmark | | 无领域适配 | Critical | 不根据研究领域搜索该领域的标准 benchmark | | 无最新性保证 | High | 不检查是否有更新、更好的 benchmark 可用 | | baseline 无法复现 | High | 不提供已有方法的参考实现 / 预训练权重 | | 下载路径硬编码 | 
Medium | 非预缓存数据集无法自动获取 | | 无数据集验证 | Medium | 不验证下载的数据集是否完整、格式正确 | #### 理想工作流 一个好的数据集/基准选择流程应该: 1. **理解研究问题** → 确定评估维度(分类精度?生成质量?推理速度?) 2. **搜索领域标准** → 查找该领域顶会论文常用的 benchmark 3. **评估适用性** → 数据集大小、难度、License、可获取性 4. **获取数据** → 自动下载或生成下载脚本 5. **获取 baseline** → 找到对比方法的开源实现或预训练权重 6. **验证完整性** → 确认数据集可正常加载和使用 ### B.2 目标架构 ``` ┌─────────────────────┐ │ BenchmarkAgent │ │ (Orchestrator) │ └──────────┬──────────┘ │ ┌──────────┬───────────┼───────────┐ ▼ ▼ ▼ ▼ ┌──────────┐┌──────────┐┌──────────┐┌──────────┐ │ Surveyor ││ Selector ││ Acquirer ││ Validator│ │ Agent ││ Agent ││ Agent ││ Agent │ └──────────┘└──────────┘└──────────┘└──────────┘ │ │ │ │ ▼ ▼ ▼ ▼ 领域调研 选择决策 数据获取 验证确认 ``` #### Agent 职责定义 **1. Orchestrator(编排器)** - 接收:研究主题、假设、实验设计方案 - 协调 4 个子 Agent 的执行 - 输出:`benchmark_plan.json`(包含数据集列表、下载脚本、baseline 方案) **2. Surveyor Agent(领域调研)** - 输入:研究主题关键词、相关文献列表 - 职责: - 搜索 Papers With Code 的领域 benchmark 排行榜 - 搜索 HuggingFace Datasets 的相关数据集 - 搜索 OpenML、Kaggle 的相关 benchmark - 分析近 2 年顶会论文(ICML、NeurIPS、ICLR)使用的数据集 - 汇总领域标准 benchmark 清单(含引用频次、数据规模、难度级别) - 输出:`survey_results.json` — 候选 benchmark 列表(按推荐度排序) - 数据源优先级: 1. Papers With Code (Benchmarks API) 2. HuggingFace Datasets Hub 3. torchvision / torchaudio / torchtext 内置 4. 顶会论文附录中的数据集描述 **3. Selector Agent(选择决策)** - 输入:survey_results.json + 实验约束(GPU 内存、时间预算、网络可用性) - 职责: - 根据约束过滤不可行的数据集(太大 / 需要申请 / License 不兼容) - 考虑 Docker sandbox 已缓存的数据集(优先使用) - 选择 primary benchmark(必须是领域标准)+ secondary benchmarks(补充验证) - 选择 baseline 方法(至少 2 个有开源实现的对比方法) - 生成选择理由文档(供论文 Experimental Setup section 使用) - 约束规则: - Tier 1(已缓存):无网络需求,最优先 - Tier 2(torchvision/HF datasets 可直接下载):需 setup 阶段网络 - Tier 3(需自定义下载脚本):仅在 `network_policy: full` 时可用 - 输出:`selected_benchmarks.json` + `baseline_methods.json` **4. 
Acquirer Agent(数据获取)** - 输入:selected_benchmarks.json - 职责: - 生成 `setup.py` 中的数据集下载代码 - 为每个数据集生成加载 boilerplate 代码 - 为 baseline 方法生成安装和调用代码 - 处理 HuggingFace `datasets.load_dataset()` / `torchvision.datasets` 等接口 - 生成 `requirements.txt` 中需要额外安装的包 - 输出: - `data_loading_snippets.py` — 数据加载代码片段(注入 CodeAgent) - `baseline_snippets.py` — baseline 调用代码片段 - `setup.py` 追加内容 — 下载脚本 **5. Validator Agent(验证确认)** - 输入:Acquirer 生成的下载/加载代码 - 职责: - 验证数据集 API 调用语法正确 - 验证数据集分割(train/val/test)存在 - 验证数据格式与实验代码兼容 - 验证 baseline 方法可运行 - 如果验证失败,反馈给 Acquirer 修复 - 输出:validation_report.json ### B.3 知识库设计 BenchmarkAgent 需要一个结构化知识库来支持决策: ```yaml # researchclaw/data/benchmark_knowledge.yaml domains: image_classification: standard_benchmarks: - name: CIFAR-10/100 source: torchvision tier: 1 # 已缓存 difficulty: easy/medium use_when: "小规模验证、快速原型" - name: ImageNet-1K source: torchvision tier: 3 # 需要下载 ~150GB difficulty: hard use_when: "大规模验证、与 SOTA 对比" common_baselines: - name: ResNet-50 source: "torchvision.models.resnet50(pretrained=True)" paper: "He et al., 2016" - name: ViT-B/16 source: "timm.create_model('vit_base_patch16_224', pretrained=True)" paper: "Dosovitskiy et al., 2021" reinforcement_learning: standard_benchmarks: - name: Gymnasium (MuJoCo) source: "gymnasium[mujoco]" tier: 2 - name: Atari source: "gymnasium[atari]" tier: 2 common_baselines: - name: PPO source: "stable-baselines3" paper: "Schulman et al., 2017" # ... 
更多领域 ``` ### B.4 实现计划 #### 文件结构 ``` researchclaw/ ├── agents/ │ └── benchmark_agent/ │ ├── __init__.py │ ├── orchestrator.py # BenchmarkAgent 主编排器 │ ├── surveyor.py # Surveyor Agent (领域调研) │ ├── selector.py # Selector Agent (选择决策) │ ├── acquirer.py # Acquirer Agent (数据获取) │ ├── validator.py # Validator Agent (验证确认) │ └── knowledge_base.py # 知识库加载和查询 ├── data/ │ ├── benchmark_knowledge.yaml # 领域 benchmark 知识库 │ └── dataset_registry.yaml # 已有数据集注册表 (保留) ``` #### 开发任务清单 | ID | 任务 | 依赖 | 估计改动量 | |----|------|------|-----------| | BA-01 | 创建 `agents/benchmark_agent/` 目录结构和基础类 | 无 | 新建 | | BA-02 | 编写 `benchmark_knowledge.yaml` 知识库(覆盖 10+ 领域) | 无 | ~500 行 YAML | | BA-03 | 实现 Surveyor Agent:Papers With Code API + HF Datasets 搜索 | BA-01 | ~350 行 | | BA-04 | 实现 Selector Agent:约束过滤 + Tier 匹配 + 选择逻辑 | BA-01, BA-02 | ~300 行 | | BA-05 | 实现 Acquirer Agent:代码生成 + 下载脚本 | BA-01, BA-04 | ~350 行 | | BA-06 | 实现 Validator Agent:语法/可用性验证 | BA-01, BA-05 | ~250 行 | | BA-07 | 实现 Orchestrator:编排 + 迭代修复 | BA-02 ~ BA-06 | ~250 行 | | BA-08 | 修改 executor.py:Stage 6/7 调用 BenchmarkAgent | BA-07 | ~150 行 | | BA-09 | 修改 executor.py:将 benchmark_plan 注入 CodeAgent | BA-07 | ~100 行 | | BA-10 | 更新 prompts.py:基于 BenchmarkAgent 输出动态构建提示词 | BA-07 | ~100 行 | | BA-11 | 测试:单元测试 + 集成测试 | BA-01 ~ BA-10 | ~300 行 | | BA-12 | 向后兼容:保留 `dataset_registry.yaml` 作为 fallback | BA-08 | ~30 行 | #### Pipeline 集成点 ``` Stage 3: Topic Initialization │── 研究主题确定 ▼ Stage 4-5: Literature Collection & Screening │── 文献列表生成 ▼ Stage 6: Hypothesis Generation │── 调用 BenchmarkAgent.orchestrate() │ ├── Surveyor: 搜索领域标准 benchmark │ ├── Selector: 根据约束选择最优 benchmark + baseline │ ├── Acquirer: 生成下载/加载代码 │ └── Validator: 验证代码可执行 │── 输出: benchmark_plan.json ▼ Stage 7: Experiment Design │── benchmark_plan.json 注入实验设计 │── 实验方案明确使用哪些数据集和 baseline ▼ Stage 8-9: Code Generation (CodeAgent) │── data_loading_snippets 注入生成代码 │── baseline_snippets 注入对比方法 ▼ Stage 10-11: Experiment Execution │── setup.py 执行数据集下载 │── main.py 使用生成的数据加载代码 ▼ Stage 14: Result 
Analysis │── 对比结果基于真实 baseline,可信度高 ``` --- ## Part C: 共同基础设施 ### C.1 Agent 基类 两个多 Agent 系统共享同一套基础设施: ```python # researchclaw/agents/base.py class BaseAgent: """所有子 Agent 的基类""" def __init__(self, llm_client, config): self.llm = llm_client self.config = config self.logger = logging.getLogger(self.__class__.__name__) async def execute(self, context: dict) -> dict: """执行 Agent 任务,返回结果""" raise NotImplementedError def _call_llm(self, system_prompt, user_prompt, **kwargs): """统一 LLM 调用接口""" return self.llm.chat(system_prompt, user_prompt, **kwargs) class AgentOrchestrator: """Agent 编排器基类""" def __init__(self, agents: list[BaseAgent], max_iterations=3): self.agents = agents self.max_iterations = max_iterations async def orchestrate(self, context: dict) -> dict: """执行多 Agent 编排流程""" raise NotImplementedError ``` ### C.2 与现有 LLM Client 的集成 两个系统都通过现有的 `researchclaw/llm/client.py` 调用 LLM: - Planner / Selector / Critic 等决策类 Agent → 使用 `gpt-4.1` 或 `gpt-4o` - CodeGen 类 Agent → 使用 `gpt-4.1`(代码生成能力最强) - VLM Critic → 使用 `gpt-4o`(支持 vision) ### C.3 配置扩展 ```yaml # config.yaml 新增配置 agents: figure_agent: enabled: true max_iterations: 3 # Critic 反馈最大迭代次数 min_figures: 4 # 最少图表数 style: "science+ieee" # matplotlib 样式 dpi: 300 format: "pdf" # 优先格式 vlm_review: true # 是否启用 VLM 视觉审查 benchmark_agent: enabled: true max_search_results: 20 # Papers With Code 最大搜索结果 prefer_cached: true # 优先使用已缓存数据集 tier_limit: 2 # 最高允许的 Tier 级别 (1=缓存, 2=可下载, 3=大型) min_baselines: 2 # 最少 baseline 方法数 ``` --- ## Part D: 风险与兜底 ### D.1 向后兼容 | 组件 | 兜底策略 | |------|---------| | FigureAgent 失败 | 回退到现有 `visualize.py` 生成基础图表 | | BenchmarkAgent 失败 | 回退到 `dataset_registry.yaml` + `dataset_guidance` 提示词 | | VLM 审查不可用 | 跳过视觉审查,仅做数值 + 文本验证 | | SciencePlots 未安装 | 使用 `seaborn-v0_8-whitegrid` 样式 | | 网络不可用 | Surveyor 使用本地 `benchmark_knowledge.yaml` | ### D.2 Token 成本控制 | 操作 | 预估 Token 消耗 | 控制策略 | |------|----------------|---------| | Planner (1 次) | ~2K input + ~1K output | 固定 | | CodeGen (4 图 × 最多 3 次迭代) | ~3K × 12 = ~36K | 迭代次数上限 
| | Critic (4 图 × 最多 3 次) | ~2K × 12 = ~24K | 迭代次数上限 | | VLM 审查 (4 图) | ~4K × 4 = ~16K | 仅终轮审查 | | Surveyor (1 次) | ~2K input + ~2K output | API 调用为主 | | Selector (1 次) | ~3K input + ~1K output | 固定 | | **总增量** | **~80K tokens** | 约增加 $0.30-0.50/run | ### D.3 测试策略 1. **单元测试**:每个 Agent 独立测试(mock LLM 响应) 2. **集成测试**:使用固定 results.json 测试 FigureAgent 完整流程 3. **回归测试**:确认 fallback 到旧系统仍可正常工作 4. **端到端测试**:Run 14+ 完整 Pipeline 运行,对比图表质量 --- ## Part E: 执行优先级 建议按以下顺序实施: ### Phase 1: FigureAgent 核心(优先级最高) 1. FA-01 ~ FA-03: 基础类 + Planner + CodeGen 2. FA-04 ~ FA-05: Renderer + Critic 3. FA-08: SciencePlots 集成 4. FA-12: 模板库 ### Phase 2: FigureAgent 集成 5. FA-06 ~ FA-07: Integrator + Orchestrator 6. FA-09 ~ FA-11: Pipeline 集成 7. FA-13 ~ FA-14: 测试 + 向后兼容 ### Phase 3: BenchmarkAgent 核心 8. BA-01 ~ BA-02: 基础类 + 知识库 9. BA-03 ~ BA-06: 4 个子 Agent 10. BA-07: Orchestrator ### Phase 4: BenchmarkAgent 集成 11. BA-08 ~ BA-10: Pipeline 集成 12. BA-11 ~ BA-12: 测试 + 向后兼容 ### Phase 5: 端到端验证 13. 完整 Pipeline 运行(Run 14+) 14. 对比图表质量和数据集选择质量 15. 
根据结果调优 --- ## Appendix: 参考资料 | 来源 | 关键收获 | |------|---------| | [AI Scientist v2](https://github.com/SakanaAI/AI-Scientist-v2) | VLM-as-critic, 首篇通过 ICLR workshop 审稿 | | [PlotGen (Adobe)](https://arxiv.org/abs/2502.00988) | 三模态反馈:数值 + 文本 + 视觉 | | [PaperBanana (Google)](https://github.com/llmsresearch/paperbanana) | Caption sharpening + 参考图库检索 | | [SciencePlots](https://github.com/garrettj403/SciencePlots) | 学术论文 matplotlib 样式库 | | [VLM-Enhanced Discovery](https://arxiv.org/html/2511.14631) | Correction mode + Discovery mode | | [Papers With Code API](https://paperswithcode.com/api/v1/) | 领域 benchmark 排行榜搜索 | | [HuggingFace Datasets](https://huggingface.co/docs/datasets/) | 数据集搜索和加载 API | ================================================ FILE: docs/figure_prompts/case_a_meta_learning.md ================================================ # Case A: Continual Meta-Learning — Image Generation Prompt ## Prompt A premium, modern data visualization infographic on a clean white background with subtle light-gray grid lines. The chart is a **line chart** showing progressive performance improvement across 5 data points on the X-axis (labeled "Self-Iteration Round"). **Overall title** at the top in bold dark navy sans-serif font: "Case A: Continual Meta-Learning for Few-Shot Adaptation" **Y-axis:** "Few-Shot Accuracy (%)" ranging from 15% to 105%. **X-axis:** "Self-Iteration Round" with 5 labeled tick marks. **Data points and line:** - Point 0 (Baseline): 25.9% — large circle marker, colored **slate gray** (#757575). X-label below: "Baseline" with a small gray beaker/flask icon, subtitle "(Initial Code)". - Point 1 (Iter 1): 81.2% — large circle marker, colored **emerald green** (#2E7D32). X-label: "Iter 1" with a small green brain/neural-network icon, subtitle "(Deep Encoder + Meta-SGD)". - Point 2 (Iter 2): 77.5% — large circle marker, colored **crimson red** (#C62828). X-label: "Iter 2" with a small red warning-triangle icon, subtitle "(Prototype Net — Regression)". 
- Point 3 (Iter 3): 93.4% — large circle marker, colored **emerald green** (#2E7D32). X-label: "Iter 3" with a small green rocket icon, subtitle "(Linear Clf + L2 Anchor)". - Point 4 (Iter 4): 93.4% — large circle marker, colored **slate gray** (#757575). X-label: "Iter 4" with a small gray checkmark-circle icon, subtitle "(Converged)". **Connecting line:** Thick (3px) solid line in **royal blue** (#1565C0) connecting all 5 points in order. The area below the line (above the baseline value of 25.9%) is filled with a very light semi-transparent blue wash (#1565C0 at 8% opacity). **Annotations with callout arrows:** - Near Point 1: A green callout box with text "+55.3 pts" in bold green, below it "Deep encoder + context-gated replay" in smaller green text. A thin green arrow points from the callout to Point 1. Include a small upward-arrow icon. - Near Point 2: A red italic callout "Prototype net too simple" with a thin red arrow pointing to Point 2. Include a small X-mark icon. - Near Point 3: A green callout box with text "+15.9 pts" in bold green, below it "Linear clf + L2 anchor + cosine gating" in smaller green text. A thin green arrow points from the callout to Point 3. Include a small upward-arrow icon. **Reference line:** A horizontal **dashed orange line** (#E65100) at y=100% with a small label "Oracle (100%)" at the right end in italic orange text. Include a small trophy/target icon next to the label. **Summary stats box:** Upper-left corner, a rounded rectangle with light blue background (#E3F2FD) and blue border (#1565C0), containing monospace text: ``` Baseline: 25.9% → Best: 93.4% Improvement: +67.5 pts (261% rel.) ``` **Legend** at the bottom center with three items, each with a colored square swatch: - Green square: "Improved" - Red square: "Regressed (auto-recovered)" - Gray square: "No change / Baseline" **Style:** Clean, professional, tech-forward aesthetic. Use a modern sans-serif font (like Inter, SF Pro, or Helvetica Neue). 
Subtle drop shadows on the summary box and annotation callouts. Smooth anti-aliased lines. The overall feel should be suitable for a top-tier AI research company's product page or investor deck — polished, data-rich, and visually compelling. High contrast text. No 3D effects. Flat design with depth through subtle shadows and layering. **Dimensions:** 1200 x 900 pixels, 2x retina resolution. ================================================ FILE: docs/figure_prompts/case_b_rlhf_alignment.md ================================================ # Case B: RLHF with Curriculum Reward Shaping — Image Generation Prompt ## Prompt A premium, modern data visualization infographic on a clean white background with subtle light-gray grid lines. The chart is a **line chart with square markers** showing progressive performance improvement across 5 data points on the X-axis (labeled "Self-Iteration Round"). **Overall title** at the top in bold dark navy sans-serif font: "Case B: RLHF with Curriculum-Based Reward Shaping for LLM Alignment" **Y-axis:** "LLM Alignment Score (%)" ranging from 15% to 105%. **X-axis:** "Self-Iteration Round" with 5 labeled tick marks. **Data points and line:** - Point 0 (Baseline): 35.6% — large square marker, colored **slate gray** (#757575). X-label below: "Baseline" with a small gray play-button icon, subtitle "(Vanilla PPO)". - Point 1 (Iter 1): 35.6% — large square marker, colored **slate gray** (#757575). X-label: "Iter 1" with a small gray pause icon, subtitle "(No Change)". - Point 2 (Iter 2): 61.6% — large square marker, colored **emerald green** (#2E7D32). X-label: "Iter 2" with a small green sparkle/star icon, subtitle "(+Reward Model +Curriculum)". - Point 3 (Iter 3): 63.0% — large square marker, colored **emerald green** (#2E7D32). X-label: "Iter 3" with a small green chart-trending-up icon, subtitle "(+Rank-Norm +Policy EMA)". - Point 4 (Iter 4): 66.6% — large square marker, colored **emerald green** (#2E7D32). 
X-label: "Iter 4" with a small green shield-check icon, subtitle "(+Confidence Gating)". **Connecting line:** Thick (3px) solid line in **deep purple** (#6A1B9A) connecting all 5 points in order. The area below the line (above the baseline value of 35.6%) is filled with a very light semi-transparent purple wash (#6A1B9A at 8% opacity). **Annotations with callout arrows:** - Near Point 1: A gray italic callout "No improvement (minor code fix)" with a thin gray arrow pointing down to Point 1. Include a small minus-circle icon. - Near Point 2: A green callout box with text "+26.0 pts" in bold green, below it "+Learned reward model" and "+Curriculum scheduling" in smaller green text. A thin green arrow points from the callout to Point 2. Include a small upward-arrow icon and a tiny brain icon. - Near Point 3: A smaller green callout with text "+1.4 pts" in green, below it "+Rank-norm +Policy EMA" in smaller text. A thin green arrow points to Point 3. Include a small upward-arrow icon. - Near Point 4: A green callout box with text "+3.6 pts" in bold green, below it "+Confidence gating" and "+Mini-batch RM" in smaller green text. A thin green arrow points to Point 4. Include a small upward-arrow icon and a tiny lock/shield icon. **Summary stats box:** Upper-left corner, a rounded rectangle with light purple background (#F3E5F5) and purple border (#6A1B9A), containing monospace text: ``` Baseline: 35.6% → Best: 66.6% Improvement: +31.0 pts (87% rel.) ``` **Legend** at the bottom center with three items, each with a colored square swatch: - Green square: "Improved" - Red square: "Regressed (auto-recovered)" - Gray square: "No change / Baseline" **Style:** Clean, professional, tech-forward aesthetic. Use a modern sans-serif font (like Inter, SF Pro, or Helvetica Neue). Subtle drop shadows on the summary box and annotation callouts. Smooth anti-aliased lines. 
The overall feel should be suitable for a top-tier AI research company's product page or investor deck — polished, data-rich, and visually compelling. High contrast text. No 3D effects. Flat design with depth through subtle shadows and layering. **Dimensions:** 1200 x 900 pixels, 2x retina resolution. ================================================ FILE: docs/integration-guide.md ================================================ # AutoResearchClaw Integration Guide > **The simplest way to use AutoResearchClaw**: give the repo URL to [OpenClaw](https://github.com/openclaw/openclaw) and say *"Research [your topic]."* That's it — OpenClaw handles cloning, installing, configuring, and running the entire 23-stage pipeline for you. This guide is for humans who want to understand what's happening under the hood, or who prefer to set things up manually. --- ## Table of Contents 1. [The Easy Way: OpenClaw](#1-the-easy-way-openclaw) 2. [Manual Setup](#2-manual-setup) 3. [Configuration Walkthrough](#3-configuration-walkthrough) 4. [Running the Pipeline](#4-running-the-pipeline) 5. [Understanding the 23 Stages](#5-understanding-the-23-stages) 6. [Output Artifacts](#6-output-artifacts) 7. [Experiment Modes](#7-experiment-modes) 8. [Conference Templates](#8-conference-templates) 9. [OpenClaw Bridge (Advanced)](#9-openclaw-bridge-advanced) 10. [MetaClaw Integration (Cross-Run Learning)](#10-metaclaw-integration-cross-run-learning) 11. [Other AI Platforms](#11-other-ai-platforms) 12. [Python API](#12-python-api) 13. [Troubleshooting](#13-troubleshooting) 14. [FAQ](#14-faq) --- ## 1. The Easy Way: OpenClaw If you use [OpenClaw](https://github.com/openclaw/openclaw) as your AI assistant, you don't need to read the rest of this guide. ### Steps 1. Share the GitHub repo URL with OpenClaw: ``` https://github.com/aiming-lab/AutoResearchClaw ``` 2. OpenClaw reads `RESEARCHCLAW_AGENTS.md` and `README.md` — it now understands the entire system. 
> **Note:** `RESEARCHCLAW_AGENTS.md` is generated locally and listed in `.gitignore`. If it doesn't exist, OpenClaw can bootstrap from `README.md` and the project structure. 3. Say something like: ``` Research the application of graph neural networks in drug discovery ``` 4. OpenClaw will: - Clone the repo - Create a virtual environment and install dependencies (`pip install -e .`) - Copy `config.researchclaw.example.yaml` → `config.yaml` - Ask you for an OpenAI API key (or use your environment variable) - Run the full 23-stage pipeline - Return the paper, experiment code, charts, and citations **That's the whole process.** OpenClaw is designed to read agent definition files and bootstrap itself. AutoResearchClaw ships with these files specifically so that any OpenClaw-compatible AI assistant can pick it up and run. ### What if I want to tweak settings? Tell OpenClaw in natural language: - *"Use GPT-5.2 instead of GPT-4o"* - *"Run experiments in sandbox mode, not simulated"* - *"Target ICLR 2025 format instead of NeurIPS"* - *"Skip the quality gate, just auto-approve everything"* OpenClaw will modify `config.yaml` accordingly before running the pipeline. --- ## 2. Manual Setup ### Prerequisites | Requirement | Details | |-------------|---------| | **Python** | 3.11 or newer | | **LLM API** | Any OpenAI-compatible endpoint (OpenAI, Azure, local proxy, etc.) | | **Disk space** | ~100 MB for the repo + artifacts per run | | **Network** | Required for LLM API calls and literature search (Semantic Scholar, arXiv) | ### Installation ```bash # Clone the repository git clone https://github.com/aiming-lab/AutoResearchClaw.git cd AutoResearchClaw # Create a virtual environment (recommended) python3 -m venv .venv source .venv/bin/activate # macOS/Linux # .venv\Scripts\activate # Windows # Install pip install -e . 
``` ### Verify Installation ```bash # Check the CLI is available researchclaw --help # Validate your configuration researchclaw validate --config config.yaml ``` --- ## 3. Configuration Walkthrough Start from the provided template: ```bash cp config.researchclaw.example.yaml config.yaml ``` Open `config.yaml` in your editor. Here's what each section does: ### LLM Settings (Required) This is the only section you **must** configure. Everything else has sensible defaults. ```yaml llm: base_url: "https://api.openai.com/v1" # Your LLM API endpoint api_key_env: "OPENAI_API_KEY" # Environment variable name... api_key: "" # ...or paste the key directly here primary_model: "gpt-4o" # Model to use (gpt-4o, gpt-5.2, etc.) fallback_models: # Tried in order if primary fails - "gpt-4.1" - "gpt-4o-mini" s2_api_key: "" # Optional: Semantic Scholar API key for higher rate limits ``` **Using an environment variable** (recommended for security): ```bash export OPENAI_API_KEY="sk-..." ``` **Using a direct key** (simpler, less secure): ```yaml llm: api_key: "sk-your-key-here" ``` **Using a proxy or alternative provider**: ```yaml llm: base_url: "https://your-proxy.example.com/v1" api_key: "your-proxy-key" primary_model: "gpt-4o" # Must be supported by your endpoint ``` ### Research Settings ```yaml research: topic: "Your research topic here" # Can also be set via CLI --topic flag domains: - "machine-learning" # Guides literature search scope daily_paper_count: 10 # Target papers to collect quality_threshold: 4.0 # Minimum paper quality score (1-5) ``` ### Experiment Settings ```yaml experiment: mode: "sandbox" # How experiments run (see Section 7) time_budget_sec: 300 # Max seconds per experiment run max_iterations: 10 # Max refinement loops in Stage 13 metric_key: "primary_metric" # What metric to optimize metric_direction: "minimize" # "minimize" or "maximize" sandbox: python_path: ".venv/bin/python3" # Python binary for sandbox execution gpu_required: false max_memory_mb: 4096 
code_agent: # CodeAgent v2 (multi-phase code generation) enabled: true # Architecture planning + sequential file gen + hard validation benchmark_agent: # Automated dataset & baseline selection enabled: true # 4-agent pipeline: Surveyor→Selector→Acquirer→Validator figure_agent: # Academic figure generation enabled: true # 5-agent pipeline: Planner→CodeGen→Renderer→Critic→Integrator repair: # Anti-fabrication experiment repair enabled: true # Diagnose and fix failed experiments before paper writing max_cycles: 3 # Repair retry loops opencode: # OpenCode Beast Mode (see README for details) enabled: true ``` ### Export Settings ```yaml export: target_conference: "neurips_2025" # See Section 8 for all available templates authors: "Anonymous" # Author line in the paper bib_file: "references" # BibTeX file name (without .bib) ``` ### Everything Else (Optional) These have reasonable defaults. Change them only if you need to: ```yaml project: name: "my-research" # Just an identifier for your run mode: "full-auto" # "docs-first", "semi-auto", or "full-auto" runtime: timezone: "America/New_York" max_parallel_tasks: 3 approval_timeout_hours: 12 retry_limit: 2 security: hitl_required_stages: [5, 9, 20] # Stages that pause for human approval allow_publish_without_approval: false notifications: channel: "console" # "console", "discord", or "slack" knowledge_base: backend: "markdown" root: "docs/kb" ``` --- ## 4. 
Running the Pipeline ### Basic Run ```bash # Run with topic from config.yaml researchclaw run --config config.yaml --auto-approve # Override topic from command line researchclaw run --config config.yaml --topic "Transformer attention for time series" --auto-approve ``` ### CLI Commands | Command | What It Does | |---------|-------------| | `researchclaw setup` | Interactive first-time setup (installs OpenCode Beast Mode, checks Docker/LaTeX) | | `researchclaw init` | Interactive config creation (choose LLM provider, creates `config.arc.yaml`) | | `researchclaw run` | Run the full 23-stage pipeline | | `researchclaw validate` | Check your config file for errors | | `researchclaw doctor` | Diagnose environment issues (Python, dependencies, API connectivity) | | `researchclaw report --run-dir <path>` | Generate a summary report from a completed run directory | --- From a one-line idea to a conference-ready paper — fully autonomous, zero human intervention.
|
**💡** **Idea** |
➜ |
**📚** **Literature** 300–470 papers |
➜ |
**🧪** **Hypothesis** experiment design |
➜ |
**💻** **Code** 2K–15K lines |
➜ |
**🔬** **Execute** sandbox + refine |
➜ |
**📝** **Write** review & audit |
➜ |
**📄** **Paper** NeurIPS PDF |
Each run traverses 23 autonomous stages with iterative self-healing, multi-agent peer review, and citation verification — no human in the loop.
--- Generated on Machine A · 4 papers across 4 non-ML domains
--- ### 📄 Paper I · Random Matrix Theory
Finite-dimensional correction pipeline: Wishart matrix generation → empirical spectral density estimation → MP baseline comparison → bulk/edge error decomposition → correction model fitting. Entirely auto-generated by the FigureAgent subsystem.
Monte Carlo IV evaluation pipeline: DGP specification → estimator suite (2SLS, LIML, Fuller-k, JIVE) → finite-sample risk surfaces → phase diagram construction. Entirely auto-generated by the FigureAgent subsystem.
PRIM benchmark workflow: synthetic outbreak generation (SIR/SEIR) → parameter estimation → profile likelihood vs. FIM diagnostics → identifiability regime mapping. Entirely auto-generated by the FigureAgent subsystem.
Feature-conditioned preconditioner evaluation: sparse matrix collection → structural descriptor extraction → solver–preconditioner grid (CG/GMRES/BiCGSTAB × ILU/Jacobi/SSOR/AMG) → setup-vs-solve tradeoff analysis → decision map. Entirely auto-generated by the FigureAgent subsystem.
Generated on Machine B · NVIDIA RTX 6000 Ada (48 GB) · 4 papers across 4 ML sub-fields
--- ### 📄 Paper V · Parameter-Efficient Fine-Tuning
Gradient spectral analysis → layer-wise rank scoring → dynamic rank allocation under budget constraint. Entirely auto-generated by the FigureAgent subsystem.
Learned state abstraction module integrated with count-based exploration in the DQN agent loop. Entirely auto-generated by the FigureAgent subsystem.
Frequency-aware token merging applied progressively across ViT layers with DCT-based spectral filtering. Entirely auto-generated by the FigureAgent subsystem.
Reliability-aware contrastive feature alignment between teacher and student across clean and corrupted views, with de-alignment on fragile teacher directions. Entirely auto-generated by the FigureAgent subsystem.
| 📋 Metric | I | II | III | IV | V | VI | VII | VIII | 🏆 Total |
|---|---|---|---|---|---|---|---|---|---|
| 🏷️ Domain | Math | Stats | Bio | NumLA | NLP | RL | CV | KD | 8 fields |
| 💻 Code (LOC) | 10,290 | 10,062 | 9,374 | 14,557 | 2,894 | 2,067 | 2,873 | 2,231 | 54,348 |
| ⏱️ Pipeline Time | 2h25m | 2h56m | 2h23m | 2h30m | 50m | 6h48m | 3h18m | 5h48m | ~27 hrs |
| 🔗 References | 26 | 41 | 29 | 33 | 60 | 25 | 40 | 37 | 291 cited |
| 📊 Figures | 5 | 6 | 6 | 4 | 7 | 6 | 7 | 9 | 50 figs |
| 📑 Pages | 16 | 14 | 18 | 16 | 17 | 11 | 10 | 19 | 121 pages |
Every paper above was generated by a single command:
```bash researchclaw run --topic "Your research idea here" --auto-approve ``` ================================================ FILE: prompts.default.yaml ================================================ # ============================================================================= # AutoResearchClaw — Default Prompt Templates # ============================================================================= # # Copy this file, edit any prompt you want to customize, and point your config # to the copy: # # prompts: # custom_file: "my_prompts.yaml" # # Template variables use {var_name} syntax — see docs/integration-guide.md # for a list of available variables per stage. # # Stages without an entry here (experiment_run, citation_verify) do not call # the LLM and therefore have no prompts to customize. # ============================================================================= blocks: compute_budget: | ## Compute Budget Constraint - Total execution time limit: {time_budget_sec} seconds - You MUST design experiments that complete within this budget - Estimate: a simple numpy loop runs ~10M iterations/sec; a nested loop over conditions runs proportionally slower - SCALING RULES (mandatory): - If total conditions > 100: reduce seeds to 3-5 (not 20) - If total conditions > 500: reduce to 2-3 representative conditions per factor - If time_budget < 300s: limit total optimization steps to ≤5,000 per run - If time_budget < 120s: limit total optimization steps to ≤1,000 per run - Always print intermediate results so partial data is captured on timeout - MANDATORY: print a "TIME_ESTIMATE: Xs" line before the main loop, estimating total runtime based on a small pilot (run 1 condition, extrapolate) - MANDATORY: implement a time guard — check elapsed time periodically and stop gracefully if approaching 80% of budget, saving all results collected so far pkg_hint_sandbox: ' AVAILABLE PACKAGES (sandbox mode): Python stdlib, numpy, math, random, statistics, json. 
Do NOT use: torch, tensorflow, jax, sklearn, pandas, scipy, matplotlib, or any deep learning framework. Write the experiment using ONLY numpy and stdlib. ' topic_constraint: ' === HARD TOPIC CONSTRAINT === The paper MUST be about: {topic} PROHIBITED content (unless user explicitly specifies case-study mode): - Do NOT treat environment setup, dependency installation, or infrastructure failures as a research contribution. - Do NOT present debugging logs, system errors, or configuration issues as experimental findings. - Do NOT drift to tangential topics not directly related to the stated topic. - Every section MUST connect back to the core research question. - The Abstract and Introduction MUST clearly state the research problem derived from: {topic} - The Method section MUST describe a technical approach, not a workflow. - The Results section MUST report quantitative outcomes of experiments, not environment status. === END CONSTRAINT === ' stages: code_generation: max_tokens: 8192 system: You are a computational scientist who writes real, runnable experiments. Your code implements actual algorithms with real mathematical operations. You NEVER fake results with random number generators. Always use the ```filename:xxx.py format for each file. Use numpy for numerical computation. Keep code self-contained and deterministic. user: "Generate a Python experiment project for the following research topic:\nTOPIC: {topic}\n\nCRITICAL REQUIREMENTS\ \ — your code MUST satisfy ALL of these:\n1. Implement REAL algorithms (e.g., gradient descent, Adam, SGD, etc.)\n \ \ using numpy arrays — NOT random.uniform() loops that fake results.\n2. Define REAL objective/loss functions (e.g.,\ \ Rosenbrock, quadratic,\n cross-entropy on synthetic data) with proper mathematical formulas.\n3. Run REAL optimization\ \ loops that compute gradients and update parameters.\n4. Collect REAL metrics (loss values, convergence rates) from\ \ the optimization.\n5. 
The code must be scientifically meaningful — a reviewer should see\n actual algorithm implementations,\ \ not random number generators.\n\nOUTPUT FORMAT — return multiple files using this exact format:\n```filename:main.py\n\ # entry point code\n```\n\n```filename:optimizers.py\n# optimizer implementations\n```\n\nCODE STRUCTURE:\n- main.py:\ \ entry point that runs experiments and prints metrics\n- Additional modules for algorithms, objective functions, utilities\n\ - Primary metric key: {metric}\n- main.py must print metric lines as `name: value` (one per line)\n- main.py must ALSO\ \ write a `results.json` file with structured experiment results\n (e.g. per-algorithm, per-function, per-dimension metrics\ \ as nested dicts/lists)\n- Use deterministic seeds (numpy.random.seed or random.seed)\n- No external data files, no\ \ network calls, no GPU required\n- FORBIDDEN: subprocess, os.system, eval, exec, shutil, socket\n- MUST implement convergence\ \ stopping criteria (e.g. stop when objective change < 1e-8 for\n N consecutive iterations) — do NOT just run a fixed\ \ number of iterations\n{pkg_hint}\nANTI-PATTERNS (do NOT do these):\n- Do NOT generate random numbers and pretend they\ \ are experiment results\n- Do NOT use `random.uniform()` to simulate a decreasing loss curve\n- Do NOT hardcode metric\ \ values or use trivial arithmetic as metrics\n- Do NOT run a fixed number of iterations without any convergence check\n- Do NOT implement convergence_rate or similar metrics as dummy return values\n (e.g. 
returning 1.0 or a constant) — measure actual iterations to convergence\n- If you report convergence_rate, define it as iterations_to_convergence / max_iterations\n or similar — it MUST differ between algorithms\n\nNUMPY 2.x COMPATIBILITY (CRITICAL):\n- np.trapz is REMOVED → use np.trapezoid\n- np.erfinv does NOT exist → use scipy.special.erfinv\n- np.bool, np.int, np.float, np.complex are REMOVED → use Python builtins\n- np.str, np.object are REMOVED → use str, object\n- np.math is REMOVED → use math module\n\nExperiment plan:\n{exp_plan}" experiment_design: system: You are a principal investigator designing ML experiments. user: '{preamble} Design an experiment plan as YAML. Required keys: objectives,datasets,baselines,proposed_methods,ablations,metrics,risks,compute_budget. Hypotheses: {hypotheses}' export_publish: max_tokens: 16384 system: You are a publication formatting editor. user: 'Format revised paper into clean final markdown for publication export. Preserve content quality and readability. Input paper: {revised}' hypothesis_gen: system: You formulate testable scientific hypotheses. user: 'Generate at least 2 falsifiable hypotheses from synthesis. Output markdown and for each hypothesis provide rationale, measurable prediction, failure condition. Synthesis: {synthesis}' knowledge_archive: system: You produce reproducibility-focused research retrospectives. user: '{preamble} Write retrospective archive markdown with lessons, reproducibility notes, and future work. Decision: {decision} Analysis: {analysis} Revised paper: {revised}' knowledge_extract: json_mode: true system: You extract high-signal evidence cards from papers. user: 'Extract structured knowledge cards from shortlist. Return JSON: {cards:[{card_id,title,cite_key,problem,method,data,metrics,findings,limitations,citation}]}. IMPORTANT: If the input contains cite_key fields, preserve them exactly in the output. 
Shortlist: {shortlist}' literature_collect: json_mode: true system: You are a literature mining assistant. user: 'Generate candidate papers from the search plan. Return JSON: {candidates:[...]} with >=20 rows. Each candidate must include id,title,source,url,year,abstract,collected_at. Topic: {topic} Search plan: {plan_text}' literature_screen: json_mode: true system: You are a strict domain-aware reviewer. Reject off-topic papers aggressively. user: 'Perform merged relevance+quality screening and return shortlist. Return JSON: {shortlist:[...]} each with title, cite_key (if present), relevance_score (0-1), quality_score (0-1), keep_reason. Preserve all original fields (paper_id, doi, arxiv_id, cite_key, etc.) from the input. Topic: {topic} Domains: {domains} Threshold: {quality_threshold} IMPORTANT: Only keep papers genuinely relevant to the topic above. Reject papers about unrelated domains even if they are high quality. Candidates JSONL: {candidates_text}' paper_draft: max_tokens: 32768 system: "You are a top-tier ML paper author writing for NeurIPS/ICML/ICLR.\n\n\ KEY PRINCIPLES (from accepted paper analyses):\n\ 1. NOVELTY: A good paper has 1-2 key ideas and keeps the rest simple. Think sushi, not curry.\n\ 2. NARRATIVE: The paper is a short, rigorous, evidence-based technical story with a takeaway readers care about.\n\ 3. FIGURE 1: The most important figure. It should convey whatever is most important — many readers go straight to Figure 1.\n\ 4. STRONG BASELINES: Invest real effort in making baselines competitive. Reviewers catch weak baselines.\n\ 5. ABLATIONS: Remove one component at a time and measure the effect. Without ablations, reviewers cannot tell which parts matter.\n\ 6. HONESTY: Acknowledge limitations explicitly. Papers that don't are substantially weaker.\n\ 7. CONTRIBUTIONS: State contributions clearly in Abstract AND Introduction. Many reviewers stop reading carefully after the intro.\n\ 8. 
REPRODUCIBILITY: Include all details needed to reproduce: hyperparameters, data processing, random seeds, hardware specs.\n\n\ COMMON REJECTION REASONS (avoid these):\n\ - Overclaiming: match claims to evidence\n\ - Missing ablations: systematically demonstrate each component's contribution\n\ - Weak baselines: tune baselines with the same effort as your method\n\ - Poor reproducibility: include every detail needed to replicate\n\n\ You ONLY use real experimental data — never fabricate or approximate numbers. Every metric value must exactly match the provided experiment output.\n\ You write at the depth and length expected for a 9-page conference paper (approximately 5000-6500 words in the main body, excluding references)." user: '{preamble} Write a FULL-LENGTH paper draft section by section in markdown. This paper must be suitable for submission to a top-tier ML conference (NeurIPS, ICML, ICLR). CRITICAL LENGTH REQUIREMENTS — each section MUST meet its minimum word count: 1. **Title**: Concise, informative (10-15 words) 2. **Abstract** (150-250 words): Problem, method, key results with numbers, conclusion 3. **Introduction** (800-1000 words): Motivation with real-world context, problem statement, research gap analysis, brief method overview, contribution list (3-4 bullet points), paper organization 4. **Related Work** (600-800 words): Organized by 3-4 thematic groups, each with 4-5 citations. Compare and contrast approaches, identify limitations of prior work, position this work clearly 5. **Method** (1000-1500 words): Formal problem definition with mathematical notation, detailed algorithm description with equations, complexity analysis, design rationale for key choices 6. **Experiments** (800-1200 words): Detailed experimental setup (datasets, preprocessing, data splits), baselines and their implementations, hyperparameter settings (in a table), evaluation metrics with justification, hardware and runtime information 7. 
**Results** (600-800 words): Main results table(s) with ALL metrics, per-condition analysis, statistical significance discussion, ablation studies, qualitative analysis where relevant 8. **Discussion** (400-600 words): Interpretation of key findings, unexpected results analysis, comparison with prior work, practical implications 9. **Limitations** (200-300 words): Honest assessment of scope, dataset, methodology, and generalizability limitations 10. **Conclusion** (200-300 words): Summary of contributions, main findings, and concrete future work directions TOTAL TARGET: 5000-6500 words in the main body. If any section is shorter than its minimum, EXPAND it with substantive technical content — NOT filler. QUALITY STANDARDS: - Use formal academic language throughout - Include mathematical notation where appropriate (use LaTeX-style $...$ for inline math) - Every claim must be supported by either a citation or experimental evidence - Results tables should use markdown table format with proper column headers - Provide algorithm pseudocode in the Method section when applicable Required sections: Title, Abstract, Introduction, Related Work, Method, Experiments, Results, Discussion, Limitations, Conclusion. Do NOT include a References section — it will be auto-generated. {topic_constraint}{exp_metrics_instruction}{citation_instruction}Outline: {outline}' paper_outline: max_tokens: 8192 system: You are an academic writing planner. user: '{preamble} Create a detailed paper outline in markdown. Include per-section goals and evidence links. {topic_constraint}{feedback}Analysis: {analysis} Decision: {decision}' paper_revision: max_tokens: 32768 system: You are a paper revision expert for NeurIPS/ICML/ICLR submissions. When revising, NEVER shorten existing sections — only expand, improve, and add content. The final paper must be at least as long as the draft. user: 'Revise the paper draft to address all review comments. CRITICAL: Maintain or INCREASE the paper length. 
Each section must meet its minimum word count: Abstract (150-250), Introduction (800-1000), Related Work (600-800), Method (1000-1500), Experiments (800-1200), Results (600-800), Discussion (400-600), Limitations (200-300), Conclusion (200-300). Return revised markdown only. {topic_constraint}Draft: {draft} Reviews: {reviews}' peer_review: max_tokens: 8192 system: You are a balanced conference reviewer who is rigorous about methodology-evidence consistency. user: 'Simulate peer review from at least 2 reviewer perspectives. Output markdown with Reviewer A and Reviewer B, each including strengths, weaknesses, and actionable revisions. Check specifically: 1. Does the paper stay on topic ({topic})? Flag any sections where the paper drifts to unrelated topics or presents environment issues as contributions. 2. METHODOLOGY-EVIDENCE CONSISTENCY: Compare the paper''s claims about experimental setup (number of trials, statistical tests, hyperparameters, baselines) against the actual experiment evidence provided below. Flag any discrepancies where the paper claims something that is NOT supported by the actual code or results. For example: - Paper claims N trials but code shows a different number - Paper claims statistical tests (ANOVA, t-test) but code has none - Paper reports metrics not present in actual results - Paper describes methods not implemented in code 3. TRIAL COUNT: The actual number of experiment runs is stated in the evidence below. If the paper claims a DIFFERENT number of trials (e.g., "100 independent trials" when only 1 was run), flag this as a CRITICAL fabrication that MUST be corrected. 4. PAPER LENGTH: This paper targets NeurIPS/ICML submission (9 pages). Check that each section has adequate depth. Flag sections that are too short: Abstract (<150 words), Introduction (<700 words), Related Work (<500 words), Method (<800 words), Experiments (<600 words), Results (<500 words). A paper with fewer than 4000 total words is CRITICALLY under-length. 5. 
REVIEW LIKE A TOP-CONFERENCE REVIEWER: - Is the contribution novel, or is it incremental over well-known work? - Are baselines properly tuned and competitive? - Are ablation studies present and meaningful? - Is every claim supported by evidence from the experiments? - Does the paper acknowledge its limitations honestly? - Would you recommend this paper be presented at NeurIPS/ICML? Why or why not? - Score the paper 1-10 following this rubric: 1-3 Reject (fundamental flaws), 4-5 Borderline (significant weaknesses), 6-7 Weak Accept (solid but not exciting), 8-9 Accept (strong contribution), 10 Strong Accept (exceptional). Paper draft: {draft} {experiment_evidence}' problem_decompose: system: You are a senior research strategist. user: 'Decompose this research problem into at least 4 prioritized sub-questions. Topic: {topic} Output markdown with sections: Source, Sub-questions, Priority Ranking, Risks. Goal context: {goal_text}' quality_gate: json_mode: true system: You are a final quality gate evaluator. user: 'Evaluate revised paper quality and return JSON. Schema: {score_1_to_10:number, verdict:string, strengths:[...], weaknesses:[...], required_actions:[...]}. Threshold: {quality_threshold} Paper: {revised}' research_decision: system: You are a research program lead making go/no-go decisions. user: 'Make a PROCEED or PIVOT decision from analysis. Output markdown with: Decision, Justification, Evidence, Next Actions. Analysis: {analysis}' resource_planning: json_mode: true system: You are an experiment scheduler. user: 'Create schedule JSON with GPU/time estimates. Schema: {tasks:[{id,name,depends_on,gpu_count,estimated_minutes,priority}], total_gpu_budget, generated}. Experiment plan: {exp_plan}' result_analysis: system: You are a quantitative ML analyst. Always cite exact numbers from the provided data. user: '{preamble} {data_context} Analyze run metrics and produce markdown report with statistical interpretation. 
Use the ACTUAL quantitative values provided above — do NOT invent numbers. Required sections: Metrics Summary (with real values), Comparative Findings, Statistical Checks, Limitations, Conclusion. Run context: {context}' search_strategy: json_mode: true system: You design literature retrieval strategies and source verification plans. You aim for COMPREHENSIVE coverage — a good research paper needs 30-60 references. user: 'Create a merged search strategy package. Return a JSON object with keys: search_plan_yaml, sources. search_plan_yaml must be valid YAML text with search_strategies containing at least 3 strategies, each with 3-5 diverse keyword queries (short, 3-6 words each). Generate at least 8 total queries. Cover: core topic, related methods, benchmarks/datasets, theoretical foundations, applications. sources must include id,name,type,url,status,query,verified_at. Topic: {topic} Problem tree: {problem_tree}' synthesis: system: You are a synthesis specialist for literature reviews. user: 'Produce merged synthesis output (topic clusters + research gaps). Output markdown with sections: Cluster Overview, Cluster 1..N, Gap 1..N, Prioritized Opportunities. Topic: {topic} Cards context: {cards_context}' topic_init: system: You are a rigorous research planner. user: 'Create a SMART research goal in markdown. Topic: {topic} Domains: {domains} Project: {project_name} Quality threshold: {quality_threshold} Required sections: Topic, Scope, SMART Goal, Constraints, Success Criteria, Generated.' sub_prompts: code_repair: system: You fix Python code validation errors while preserving functionality. user: 'The file `{fname}` in the experiment project has validation errors. Fix ALL issues and return ONLY the corrected file. ## Validation Issues in {fname} {issues_text} ## All Project Files {all_files_ctx} IMPORTANT: Do NOT use subprocess, os.system, eval, exec, or any network/shell calls. Return ONLY the corrected code for `{fname}`.' 
iterative_improve: max_tokens: 8192 system: You improve experiment projects and return valid executable Python code. Use ```filename:xxx.py format for each file. user: 'Improve the experiment code based on prior run results. Return the improved files using ```filename:xxx.py format for each file. Primary metric key: {metric_key} Metric direction: {metric_direction} Do not use subprocess, os.system, eval, exec, or any network/shell calls. Current project files: {files_context} Run summaries (JSON): {run_summaries}' iterative_repair: system: You fix Python code issues — both static validation errors and runtime bugs (NaN, Inf, division by zero, overflow). Diagnose the ROOT CAUSE from warnings and error messages. Do not add unsafe behavior. user: 'Fix all issues in the experiment code and return corrected Python code using ```filename:xxx.py format for each file. IMPORTANT: If you see NaN/Inf or RuntimeWarning about division or invalid values, trace the bug to its source (e.g. division by zero, uninitialized array, missing convergence check) and fix the actual code logic — do NOT just add try/except to suppress the error. ## Issues Found {issue_text} ## All Project Files {all_files_ctx}' version: '1.0' ================================================ FILE: pyproject.toml ================================================ [project] name = "researchclaw" version = "0.3.1" description = "ResearchClaw — Autonomous Research Pipeline. Turn any research idea into a paper." 
requires-python = ">=3.11"
# Core runtime dependencies; heavier integrations live in the optional extras below.
dependencies = [
    "pyyaml>=6.0",
    "rich>=13.0",
    "arxiv>=2.1",
    "numpy>=1.24",
]
readme = "README.md"
license = {text = "MIT"}

# Optional extras: each enables one integration; "all" is the union plus
# figure/statistics tooling (huggingface-hub, matplotlib, scipy).
[project.optional-dependencies]
anthropic = ["httpx>=0.24"]
web = ["scholarly>=1.7", "crawl4ai>=0.2", "tavily-python>=0.3"]
pdf = ["PyMuPDF>=1.23"]
all = [
    "httpx>=0.24",
    "scholarly>=1.7",
    "crawl4ai>=0.2",
    "tavily-python>=0.3",
    "PyMuPDF>=1.23",
    "huggingface-hub>=0.20",
    "matplotlib>=3.7",
    "scipy>=1.10",
]
dev = ["pytest>=7.0", "httpx>=0.24"]

# Console entry point: `researchclaw` runs researchclaw.cli:main.
[project.scripts]
researchclaw = "researchclaw.cli:main"

[tool.hatch.build.targets.wheel]
packages = ["researchclaw", "sibyl", "arc"]

# Ship the bundled LaTeX/figure style assets inside the wheel.
[tool.hatch.build.targets.wheel.force-include]
"researchclaw/templates/styles" = "researchclaw/templates/styles"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

================================================
FILE: researchclaw/__init__.py
================================================
"""ResearchClaw — Autonomous Research Pipeline."""

# Keep in sync with the version declared in pyproject.toml.
__version__ = "0.3.1"

================================================
FILE: researchclaw/__main__.py
================================================
"""Allow running as `python -m researchclaw`."""

import sys

from researchclaw.cli import main

# Propagate the CLI's return code as the process exit status.
sys.exit(main())

================================================
FILE: researchclaw/adapters.py
================================================
"""Typed adapter interfaces and deterministic recording stubs."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Protocol


@dataclass(frozen=True)
class FetchResponse:
    """Immutable result of a web fetch: URL, HTTP status, and body text."""

    url: str
    status_code: int
    text: str


@dataclass(frozen=True)
class BrowserPage:
    """Immutable snapshot of an opened browser page (URL and title only)."""

    url: str
    title: str


class CronAdapter(Protocol):
    """Structural interface for scheduling a pipeline resume at a later time."""

    def schedule_resume(self, run_id: str, stage_id: int, reason: str) -> str:
        """Schedule a resume of `run_id` at `stage_id`; returns a job identifier."""
        ...


class MessageAdapter(Protocol):
    """Structural interface for sending a notification message."""

    def notify(self, channel: str, subject: str, body: str) -> str:
        """Send a message to `channel`; returns a delivery identifier."""
        ...


class MemoryAdapter(Protocol):
    """Structural interface for appending content to a persistent memory store."""

    def append(self, namespace: str, content: str) -> str:
        """Append `content` under `namespace`; returns an entry identifier."""
        ...
class SessionsAdapter(Protocol):
    """Spawns a named session running a command; returns a session identifier."""

    def spawn(self, name: str, command: tuple[str, ...]) -> str: ...


class WebFetchAdapter(Protocol):
    """Fetches a URL and returns a :class:`FetchResponse`."""

    def fetch(self, url: str) -> FetchResponse: ...


class BrowserAdapter(Protocol):
    """Opens a URL in a browser and returns a :class:`BrowserPage`."""

    def open(self, url: str) -> BrowserPage: ...


@dataclass
class RecordingCronAdapter:
    """Deterministic CronAdapter stub: records calls, returns sequential ids."""

    # Each entry is (run_id, stage_id, reason) in call order.
    calls: list[tuple[str, int, str]] = field(default_factory=list)

    def schedule_resume(self, run_id: str, stage_id: int, reason: str) -> str:
        self.calls.append((run_id, stage_id, reason))
        return f"cron-{len(self.calls)}"


@dataclass
class RecordingMessageAdapter:
    """Deterministic MessageAdapter stub: records calls, returns sequential ids."""

    # Each entry is (channel, subject, body) in call order.
    calls: list[tuple[str, str, str]] = field(default_factory=list)

    def notify(self, channel: str, subject: str, body: str) -> str:
        self.calls.append((channel, subject, body))
        return f"message-{len(self.calls)}"


@dataclass
class RecordingMemoryAdapter:
    """Deterministic MemoryAdapter stub: records entries, returns sequential ids."""

    # Each entry is (namespace, content) in call order.
    entries: list[tuple[str, str]] = field(default_factory=list)

    def append(self, namespace: str, content: str) -> str:
        self.entries.append((namespace, content))
        return f"memory-{len(self.entries)}"


@dataclass
class RecordingSessionsAdapter:
    """Deterministic SessionsAdapter stub: records calls, returns sequential ids."""

    # Each entry is (name, command) in call order.
    calls: list[tuple[str, tuple[str, ...]]] = field(default_factory=list)

    def spawn(self, name: str, command: tuple[str, ...]) -> str:
        self.calls.append((name, command))
        return f"session-{len(self.calls)}"


@dataclass
class RecordingWebFetchAdapter:
    """Deterministic WebFetchAdapter stub: records URLs, returns canned 200 responses."""

    calls: list[str] = field(default_factory=list)

    def fetch(self, url: str) -> FetchResponse:
        self.calls.append(url)
        return FetchResponse(url=url, status_code=200, text=f"stub fetch for {url}")


@dataclass
class RecordingBrowserAdapter:
    """Deterministic BrowserAdapter stub: records URLs, returns canned pages."""

    calls: list[str] = field(default_factory=list)

    def open(self, url: str) -> BrowserPage:
        self.calls.append(url)
        return BrowserPage(url=url, title=f"Stub browser page for {url}")


@dataclass
class MCPMessageAdapter:
    """MessageAdapter backed by an MCP tool call."""

    # NOTE(review): notify() currently returns a synthetic id without
    # contacting server_uri — looks like a placeholder; confirm intended.
    server_uri: str = "http://localhost:3000"

    def notify(self, channel: str, subject: str, body: str) -> str:
        return f"mcp-notify-{channel}"


@dataclass
class MCPWebFetchAdapter:
    """WebFetchAdapter backed by an MCP tool call."""

    # NOTE(review): fetch() currently returns canned text without contacting
    # server_uri — looks like a placeholder; confirm intended.
    server_uri: str = "http://localhost:3000"

    def fetch(self, url: str) -> FetchResponse:
        return FetchResponse(url=url, status_code=200, text=f"mcp fetch for {url}")


@dataclass
class AdapterBundle:
    """Bundle of all adapter slots, defaulting to the recording stubs."""

    cron: CronAdapter = field(default_factory=RecordingCronAdapter)
    message: MessageAdapter = field(default_factory=RecordingMessageAdapter)
    memory: MemoryAdapter = field(default_factory=RecordingMemoryAdapter)
    sessions: SessionsAdapter = field(default_factory=RecordingSessionsAdapter)
    web_fetch: WebFetchAdapter = field(default_factory=RecordingWebFetchAdapter)
    browser: BrowserAdapter = field(default_factory=RecordingBrowserAdapter)

    @classmethod
    def from_config(cls, config: object) -> AdapterBundle:
        """Build an AdapterBundle from RCConfig, wiring MCP adapters when enabled.

        Only the message and web_fetch slots are swapped to MCP-backed
        adapters; all other slots keep their recording defaults.
        """
        bundle = cls()
        # Attribute access is duck-typed so any config object (or None) works.
        mcp_cfg = getattr(config, "mcp", None)
        if mcp_cfg and getattr(mcp_cfg, "server_enabled", False):
            uri = f"http://localhost:{getattr(mcp_cfg, 'server_port', 3000)}"
            bundle.message = MCPMessageAdapter(server_uri=uri)
            bundle.web_fetch = MCPWebFetchAdapter(server_uri=uri)
        return bundle


================================================
FILE: researchclaw/agents/__init__.py
================================================
"""Multi-agent subsystems for AutoResearchClaw pipeline."""


================================================
FILE: researchclaw/agents/base.py
================================================
"""Base classes for multi-agent subsystems.

Provides ``BaseAgent`` (individual agent) and ``AgentOrchestrator``
(coordinator for multi-agent workflows). Both use the existing
``LLMClient`` for model calls and follow the same structural-typing
conventions as ``CodeAgent``.
""" from __future__ import annotations import json import logging import re from dataclasses import dataclass, field from typing import Any, Protocol logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # LLM protocol (structural typing — no import dependency on llm.client) # --------------------------------------------------------------------------- class _LLMResponseLike(Protocol): # pragma: no cover content: str model: str prompt_tokens: int completion_tokens: int class _LLMClientLike(Protocol): # pragma: no cover def chat( self, messages: list[dict[str, str]], *, system: str | None = None, max_tokens: int | None = None, temperature: float | None = None, json_mode: bool = False, ) -> Any: ... # --------------------------------------------------------------------------- # Agent result # --------------------------------------------------------------------------- @dataclass class AgentStepResult: """Output from a single agent step.""" success: bool data: dict[str, Any] = field(default_factory=dict) error: str = "" llm_calls: int = 0 token_usage: int = 0 # --------------------------------------------------------------------------- # Base agent # --------------------------------------------------------------------------- class BaseAgent: """Base class for all sub-agents in a multi-agent system. Subclasses must implement ``execute(context) -> AgentStepResult``. 
""" name: str = "base" def __init__(self, llm: _LLMClientLike) -> None: self._llm = llm self._calls = 0 self._tokens = 0 self.logger = logging.getLogger(f"{__name__}.{self.name}") # -- LLM helpers ------------------------------------------------------- def _chat( self, system: str, user: str, *, max_tokens: int = 4096, temperature: float = 0.4, json_mode: bool = False, ) -> str: """Send a chat message and return the content string.""" self._calls += 1 resp = self._llm.chat( [{"role": "user", "content": user}], system=system, max_tokens=max_tokens, temperature=temperature, json_mode=json_mode, ) self._tokens += getattr(resp, "total_tokens", 0) return resp.content def _chat_json( self, system: str, user: str, *, max_tokens: int = 4096, temperature: float = 0.3, ) -> dict[str, Any]: """Send a chat message expecting JSON output. Falls back to regex extraction.""" raw = self._chat( system, user, max_tokens=max_tokens, temperature=temperature, json_mode=True, ) return self._parse_json(raw) or {} # -- JSON parsing (3-tier, matching CodeAgent convention) --------------- @staticmethod def _parse_json(text: str) -> dict[str, Any] | None: """Try to extract JSON from text using three strategies. Always returns a ``dict`` or ``None`` — lists and other JSON primitives are discarded so callers can safely use ``.get()``. """ def _as_dict(val: Any) -> dict[str, Any] | None: return val if isinstance(val, dict) else None # 1. Direct parse try: return _as_dict(json.loads(text)) except (json.JSONDecodeError, ValueError): pass # 2. Fenced code block m = re.search(r"```(?:json)?\s*\n(.*?)```", text, re.DOTALL) if m: try: return _as_dict(json.loads(m.group(1))) except (json.JSONDecodeError, ValueError): pass # 3. First balanced { ... 
} block (BUG-DA6-07: use non-greedy brace matching) depth = 0 start_idx = -1 for i, ch in enumerate(text): if ch == "{": if depth == 0: start_idx = i depth += 1 elif ch == "}": depth -= 1 if depth == 0 and start_idx >= 0: candidate = text[start_idx : i + 1] try: return _as_dict(json.loads(candidate)) except (json.JSONDecodeError, ValueError): start_idx = -1 # try next top-level block return None # -- Subclass API ------------------------------------------------------ def execute(self, context: dict[str, Any]) -> AgentStepResult: """Execute the agent's task. Must be overridden.""" raise NotImplementedError def _make_result( self, success: bool, data: dict[str, Any] | None = None, error: str = "", ) -> AgentStepResult: # BUG-DA6-01: Return per-call delta, then reset counters to avoid # double-counting when the same agent instance is reused across retries. calls, tokens = self._calls, self._tokens self._calls = 0 self._tokens = 0 return AgentStepResult( success=success, data=data or {}, error=error, llm_calls=calls, token_usage=tokens, ) # --------------------------------------------------------------------------- # Orchestrator # --------------------------------------------------------------------------- class AgentOrchestrator: """Coordinates a sequence of agents with optional retry loops. Subclasses implement ``orchestrate(context) -> dict`` which defines the specific workflow (sequential, branching, iterative, etc.). """ def __init__(self, llm: _LLMClientLike, *, max_iterations: int = 3) -> None: self._llm = llm self.max_iterations = max_iterations self.logger = logging.getLogger(f"{__name__}.orchestrator") self.total_llm_calls = 0 self.total_tokens = 0 def _accumulate(self, result: AgentStepResult) -> None: """Track cumulative LLM usage.""" self.total_llm_calls += result.llm_calls self.total_tokens += result.token_usage def orchestrate(self, context: dict[str, Any]) -> dict[str, Any]: """Run the multi-agent workflow. 
Must be overridden.""" raise NotImplementedError ================================================ FILE: researchclaw/agents/benchmark_agent/__init__.py ================================================ """BenchmarkAgent — multi-agent benchmark, dataset, and baseline selection. Architecture ------------ 1. **Surveyor** — searches HuggingFace Hub + local knowledge base for domain-relevant benchmarks, datasets, and baseline methods. 2. **Selector** — filters and ranks candidates based on hardware constraints, time budget, network policy, and tier availability. 3. **Acquirer** — generates data-loading code snippets, ``setup.py`` download scripts, baseline boilerplate, and ``requirements.txt`` entries. 4. **Validator** — validates generated code for syntax correctness and API compatibility. The ``BenchmarkOrchestrator`` coordinates the four agents and produces a ``BenchmarkPlan`` consumed by downstream pipeline stages (experiment design, code generation). """ from researchclaw.agents.benchmark_agent.orchestrator import ( BenchmarkOrchestrator, BenchmarkPlan, ) __all__ = ["BenchmarkOrchestrator", "BenchmarkPlan"] ================================================ FILE: researchclaw/agents/benchmark_agent/acquirer.py ================================================ """Acquirer Agent — generates data loading code and download scripts. Produces three outputs consumed by the code generation stage: 1. Data loading snippets (``get_datasets()`` function) 2. Baseline method snippets (model instantiation code) 3. 
``setup.py`` additions for dataset downloading """ from __future__ import annotations import logging from typing import Any from researchclaw.agents.base import AgentStepResult, BaseAgent logger = logging.getLogger(__name__) class AcquirerAgent(BaseAgent): """Generates data loading, baseline, and download code.""" name = "acquirer" def _generate_data_loader( self, benchmarks: list[dict[str, Any]], topic: str, ) -> str: """Ask LLM to generate a robust data loading function.""" bench_specs = [] for b in benchmarks: spec = ( f"- {b.get('name', 'Unknown')} (tier {b.get('tier', '?')}, " f"role: {b.get('role', 'secondary')})\n" f" API: {b.get('api', 'N/A')}\n" f" Metrics: {b.get('metrics', [])}\n" f" Note: {b.get('note', '')}" ) bench_specs.append(spec) system = ( "You are an expert ML engineer. Generate a Python function that loads " "and prepares datasets for an ML experiment.\n\n" "REQUIREMENTS:\n" "- Function signature: def get_datasets(data_root='/workspace/data') -> dict\n" "- Returns dict with keys: 'train', 'val', 'test' (each a Dataset or DataLoader)\n" "- Include appropriate transforms (normalization, augmentation for training)\n" "- Handle both torchvision and HuggingFace datasets APIs\n" "- Include proper train/val/test splits\n" "- Add error handling with informative messages\n" "- For pre-cached datasets (tier 1), use download=False\n" "- For downloadable datasets (tier 2), use download=True in setup.py\n" "- Include a DATA_CONFIG dict with dataset metadata (num_classes, input_shape, etc.)\n\n" "Return ONLY the Python code, no explanation." ) user = ( f"Research Topic: {topic}\n\n" f"Datasets to load:\n" + "\n".join(bench_specs) + "\n\n" "Generate the data loading code." 
) return self._chat(system, user, max_tokens=4096, temperature=0.2) def _generate_baseline_code( self, baselines: list[dict[str, Any]], benchmarks: list[dict[str, Any]], topic: str, ) -> str: """Ask LLM to generate baseline method instantiation code.""" base_specs = [] for bl in baselines: spec = ( f"- {bl.get('name', 'Unknown')}\n" f" Source: {bl.get('source', 'N/A')}\n" f" Paper: {bl.get('paper', 'N/A')}" ) base_specs.append(spec) primary_bench = next( (b for b in benchmarks if b.get("role") == "primary"), benchmarks[0] if benchmarks else {}, ) system = ( "You are an expert ML engineer. Generate Python code that instantiates " "baseline methods for comparison in an ML experiment.\n\n" "REQUIREMENTS:\n" "- Function signature: def get_baselines(num_classes, device='cuda') -> dict\n" "- Returns dict mapping method_name -> model (nn.Module)\n" "- Each model must be ready for training (correct output dimensions)\n" "- Use pretrained weights where available (for feature extractors)\n" "- Adapt final layer to match num_classes of the target dataset\n" "- Include a BASELINES_CONFIG dict with metadata (param_count, paper, etc.)\n" "- Handle missing optional packages gracefully\n\n" "Return ONLY the Python code, no explanation." ) user = ( f"Research Topic: {topic}\n" f"Primary Dataset: {primary_bench.get('name', 'N/A')} " f"({primary_bench.get('classes', '?')} classes)\n\n" f"Baseline Methods:\n" + "\n".join(base_specs) + "\n\n" "Generate the baseline instantiation code." 
) return self._chat(system, user, max_tokens=4096, temperature=0.2) def _generate_setup_script( self, benchmarks: list[dict[str, Any]], required_pip: list[str], ) -> str: """Generate setup.py content for dataset downloading.""" # Tier 2 datasets need download scripts tier2 = [b for b in benchmarks if b.get("tier", 1) >= 2] if not tier2 and not required_pip: return "" lines = [ '"""Setup script for dataset downloading and environment preparation.', '', 'This script runs during Phase 1 (setup) of the Docker sandbox,', 'when network access is available. It downloads datasets and installs', 'any additional dependencies.', '"""', '', 'import os', 'import sys', '', 'DATA_ROOT = "/workspace/data"', 'HF_CACHE = os.path.join(DATA_ROOT, "hf")', '', '', 'def download_datasets():', ' """Download all required datasets."""', ' os.makedirs(DATA_ROOT, exist_ok=True)', ' os.makedirs(HF_CACHE, exist_ok=True)', '', ] for b in tier2: api = b.get("api", "") name = b.get("name", "unknown") if "torchvision" in api: # Convert download=False to download=True for setup dl_api = api.replace("download=False", "download=True") lines.extend([ f' # Download {name}', ' try:', f' import torchvision', f' {dl_api}', f' print(f"Downloaded {name}")', f' except Exception as e:', f' print(f"Warning: Failed to download {name}: {{e}}")', '', ]) elif "datasets.load_dataset" in api or "load_dataset" in api: # Rewrite qualified `datasets.load_dataset(...)` to # `load_dataset(...)` so it matches the `from datasets import` _dl_api = api.replace("datasets.load_dataset", "load_dataset") lines.extend([ f' # Download {name}', ' try:', f' from datasets import load_dataset', f' {_dl_api}', f' print(f"Downloaded {name}")', f' except Exception as e:', f' print(f"Warning: Failed to download {name}: {{e}}")', '', ]) elif "PygNodePropPredDataset" in api or "PygGraphPropPredDataset" in api: lines.extend([ f' # Download {name}', ' try:', f' from ogb.nodeproppred import PygNodePropPredDataset' if 'Node' in api else f' from 
ogb.graphproppred import PygGraphPropPredDataset', f' {api}', f' print(f"Downloaded {name}")', f' except Exception as e:', f' print(f"Warning: Failed to download {name}: {{e}}")', '', ]) lines.extend([ '', 'if __name__ == "__main__":', ' download_datasets()', ' print("Setup complete.")', ]) return "\n".join(lines) def _generate_requirements(self, required_pip: list[str]) -> str: """Generate requirements.txt content for additional packages.""" if not required_pip: return "" # Filter out packages that are already in the Docker image builtin = { "torch", "torchvision", "torchaudio", "numpy", "scipy", "sklearn", "scikit-learn", "pandas", "matplotlib", "seaborn", "tqdm", "gymnasium", "networkx", "timm", "einops", "torchmetrics", "transformers", "datasets", "accelerate", "peft", "trl", "bitsandbytes", "tokenizers", "safetensors", "h5py", "tensorboard", "pillow", "pyyaml", "kornia", "albumentations", } extra = [p for p in required_pip if p.lower() not in builtin] return "\n".join(extra) if extra else "" # -- Code cleanup ------------------------------------------------------ @staticmethod def _strip_fences(code: str) -> str: """Remove markdown code fences if present.""" code = code.strip() if code.startswith("```"): # Remove opening fence first_nl = code.index("\n") if "\n" in code else len(code) code = code[first_nl + 1:] if code.endswith("```"): code = code[:-3].rstrip() return code # -- Main entry point -------------------------------------------------- def execute(self, context: dict[str, Any]) -> AgentStepResult: """Generate data loading, baseline, and download code. Context keys: topic (str): Research topic selection (dict): Output from SelectorAgent """ topic = context.get("topic", "") selection = context.get("selection", {}) benchmarks = selection.get("selected_benchmarks", []) baselines = selection.get("selected_baselines", []) required_pip = selection.get("required_pip", []) if not benchmarks: return self._make_result(False, error="No benchmarks selected") # 1. 
Generate data loading code self.logger.info("Generating data loading code for %d datasets", len(benchmarks)) data_loader_code = self._strip_fences( self._generate_data_loader(benchmarks, topic) ) # 2. Generate baseline code baseline_code = "" if baselines: self.logger.info("Generating baseline code for %d methods", len(baselines)) baseline_code = self._strip_fences( self._generate_baseline_code(baselines, benchmarks, topic) ) # 3. Generate setup.py setup_code = self._generate_setup_script(benchmarks, required_pip) # 4. Generate requirements.txt requirements = self._generate_requirements(required_pip) result = { "data_loader_code": data_loader_code, "baseline_code": baseline_code, "setup_code": setup_code, "requirements": requirements, "benchmark_names": [b.get("name", "Unknown") for b in benchmarks], "baseline_names": [bl.get("name", "Unknown") for bl in baselines], } self.logger.info("Acquirer complete: %d code artifacts generated", sum(1 for v in result.values() if v)) return self._make_result(True, data=result) ================================================ FILE: researchclaw/agents/benchmark_agent/orchestrator.py ================================================ """BenchmarkAgent Orchestrator — coordinates the four sub-agents. Flow: Surveyor → Selector → Acquirer → Validator (→ retry if failed) Produces a ``BenchmarkPlan`` consumed by experiment design and code generation stages. 
"""

from __future__ import annotations

import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from researchclaw.agents.base import AgentOrchestrator
from researchclaw.agents.benchmark_agent.acquirer import AcquirerAgent
from researchclaw.agents.benchmark_agent.selector import SelectorAgent
from researchclaw.agents.benchmark_agent.surveyor import SurveyorAgent
from researchclaw.agents.benchmark_agent.validator import ValidatorAgent

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class BenchmarkAgentConfig:
    """Configuration for the BenchmarkAgent system."""

    enabled: bool = True
    # Surveyor
    enable_hf_search: bool = True
    max_hf_results: int = 10
    enable_web_search: bool = False
    max_web_results: int = 5
    web_search_min_local: int = 3
    # Selector
    tier_limit: int = 2
    min_benchmarks: int = 1
    min_baselines: int = 2
    prefer_cached: bool = True
    # Orchestrator
    max_iterations: int = 2  # max Acquirer→Validator retry loops


# ---------------------------------------------------------------------------
# Output data structure
# ---------------------------------------------------------------------------


@dataclass
class BenchmarkPlan:
    """Final output from the BenchmarkAgent system.

    Consumed by:
    - Experiment design stage (selected benchmarks/baselines for plan)
    - Code generation stage (data_loader_code, baseline_code)
    - Docker sandbox (setup_code, requirements)
    """

    # Selected items
    selected_benchmarks: list[dict[str, Any]] = field(default_factory=list)
    selected_baselines: list[dict[str, Any]] = field(default_factory=list)
    matched_domains: list[str] = field(default_factory=list)
    # Generated code
    data_loader_code: str = ""
    baseline_code: str = ""
    setup_code: str = ""
    requirements: str = ""
    # Metadata
    rationale: str = ""
    experiment_notes: str = ""
    validation_passed: bool = False
    validation_warnings: list[str] = field(default_factory=list)
    # Stats
    total_llm_calls: int = 0
    total_tokens: int = 0
    elapsed_sec: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-safe dict."""
        return {
            "selected_benchmarks": self.selected_benchmarks,
            "selected_baselines": self.selected_baselines,
            "matched_domains": self.matched_domains,
            "data_loader_code": self.data_loader_code,
            "baseline_code": self.baseline_code,
            "setup_code": self.setup_code,
            "requirements": self.requirements,
            "rationale": self.rationale,
            "experiment_notes": self.experiment_notes,
            "validation_passed": self.validation_passed,
            "validation_warnings": self.validation_warnings,
            "total_llm_calls": self.total_llm_calls,
            "total_tokens": self.total_tokens,
            "elapsed_sec": self.elapsed_sec,
        }

    def to_prompt_block(self) -> str:
        """Format as a prompt block for injection into code generation.

        Sections are emitted only when the corresponding field is non-empty.
        """
        parts = []
        # Benchmark summary
        if self.selected_benchmarks:
            parts.append("## Selected Benchmarks")
            for b in self.selected_benchmarks:
                role = b.get("role", "secondary")
                metrics = b.get("metrics", [])
                parts.append(
                    f"- **{b.get('name', 'Unknown')}** ({role}) — "
                    f"metrics: {', '.join(str(m) for m in metrics)}"
                )
                if b.get("api"):
                    parts.append(f"  API: `{b['api']}`")
                if b.get("note"):
                    parts.append(f"  Note: {b['note']}")
        # Baseline summary
        if self.selected_baselines:
            parts.append("\n## Selected Baselines")
            for bl in self.selected_baselines:
                parts.append(
                    f"- **{bl.get('name', 'Unknown')}**: {bl.get('paper', 'N/A')}"
                )
                if bl.get("source"):
                    parts.append(f"  Code: `{bl['source']}`")
        # Data loading code
        if self.data_loader_code:
            parts.append("\n## Data Loading Code (READY TO USE)")
            parts.append("```python")
            parts.append(self.data_loader_code)
            parts.append("```")
        # Baseline code
        if self.baseline_code:
            parts.append("\n## Baseline Methods Code (READY TO USE)")
            parts.append("```python")
            parts.append(self.baseline_code)
            parts.append("```")
        # Experiment notes
        if self.experiment_notes:
            parts.append(f"\n## Experiment Notes\n{self.experiment_notes}")
        return "\n".join(parts)


# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------


class BenchmarkOrchestrator(AgentOrchestrator):
    """Coordinates Surveyor → Selector → Acquirer → Validator pipeline."""

    def __init__(
        self,
        llm: Any,
        config: BenchmarkAgentConfig | None = None,
        *,
        gpu_memory_mb: int = 49000,
        time_budget_sec: int = 300,
        network_policy: str = "setup_only",
        stage_dir: Path | None = None,
    ) -> None:
        cfg = config or BenchmarkAgentConfig()
        super().__init__(llm, max_iterations=cfg.max_iterations)
        self._config = cfg
        self._stage_dir = stage_dir
        # Initialize sub-agents
        self._surveyor = SurveyorAgent(
            llm,
            enable_hf_search=cfg.enable_hf_search,
            max_hf_results=cfg.max_hf_results,
        )
        self._selector = SelectorAgent(
            llm,
            gpu_memory_mb=gpu_memory_mb,
            time_budget_sec=time_budget_sec,
            network_policy=network_policy,
            tier_limit=cfg.tier_limit,
            min_benchmarks=cfg.min_benchmarks,
            min_baselines=cfg.min_baselines,
            prefer_cached=cfg.prefer_cached,
        )
        self._acquirer = AcquirerAgent(llm)
        self._validator = ValidatorAgent(llm)

    def _save_artifact(self, name: str, data: Any) -> None:
        """Save intermediate artifact to stage directory.

        No-op when no stage directory is configured. Strings are written
        verbatim; everything else is JSON-serialized.
        """
        if self._stage_dir is None:
            return
        self._stage_dir.mkdir(parents=True, exist_ok=True)
        path = self._stage_dir / name
        if isinstance(data, str):
            path.write_text(data, encoding="utf-8")
        else:
            path.write_text(
                json.dumps(data, indent=2, ensure_ascii=False, default=str),
                encoding="utf-8",
            )

    def orchestrate(self, context: dict[str, Any]) -> BenchmarkPlan:
        """Run the full benchmark selection pipeline.

        Context keys:
            topic (str): Research topic/title
            hypothesis (str): Research hypothesis
            experiment_plan (str): Experiment plan text

        Returns a BenchmarkPlan; on survey/selection failure the plan is
        returned early with only the stats fields populated.
        """
        t0 = time.monotonic()
        topic = context.get("topic", "")
        hypothesis = context.get("hypothesis", "")
        self.logger.info("BenchmarkAgent starting for: %s", topic[:80])
        plan = BenchmarkPlan()

        # ── Phase 1: Survey ───────────────────────────────────────
        self.logger.info("Phase 1: Surveying benchmarks")
        survey_result = self._surveyor.execute({
            "topic": topic,
            "hypothesis": hypothesis,
            "experiment_plan": context.get("experiment_plan", ""),
        })
        self._accumulate(survey_result)
        if not survey_result.success:
            self.logger.warning("Survey failed: %s", survey_result.error)
            plan.elapsed_sec = time.monotonic() - t0
            plan.total_llm_calls = self.total_llm_calls
            plan.total_tokens = self.total_tokens
            return plan
        survey = survey_result.data
        plan.matched_domains = survey.get("matched_domains", [])
        self._save_artifact("survey_results.json", survey)

        # ── Phase 2: Select ───────────────────────────────────────
        self.logger.info("Phase 2: Selecting benchmarks and baselines")
        select_result = self._selector.execute({
            "topic": topic,
            "survey": survey,
        })
        self._accumulate(select_result)
        if not select_result.success:
            self.logger.warning("Selection failed: %s", select_result.error)
            plan.elapsed_sec = time.monotonic() - t0
            plan.total_llm_calls = self.total_llm_calls
            plan.total_tokens = self.total_tokens
            return plan
        selection = select_result.data
        plan.selected_benchmarks = selection.get("selected_benchmarks", [])
        plan.selected_baselines = selection.get("selected_baselines", [])
        plan.rationale = selection.get("rationale", "")
        plan.experiment_notes = selection.get("experiment_notes", "")
        self._save_artifact("selection_results.json", selection)

        # ── Phase 3+4: Acquire + Validate (with retry) ───────────
        for iteration in range(self.max_iterations):
            self.logger.info(
                "Phase 3: Acquiring code (iteration %d/%d)",
                iteration + 1,
                self.max_iterations,
            )
            # Acquire
            acq_result = self._acquirer.execute({
                "topic": topic,
                "selection": selection,
            })
            self._accumulate(acq_result)
            if not acq_result.success:
                self.logger.warning("Acquisition failed: %s", acq_result.error)
                continue
            acquisition = acq_result.data
            # Code fields are excluded from the JSON artifact (saved elsewhere).
            self._save_artifact(
                f"acquisition_{iteration}.json",
                {k: v for k, v in acquisition.items()
                 if k not in ("data_loader_code", "baseline_code", "setup_code")},
            )
            # Validate
            self.logger.info("Phase 4: Validating code (iteration %d/%d)",
                             iteration + 1, self.max_iterations)
            val_result = self._validator.execute({
                "acquisition": acquisition,
            })
            self._accumulate(val_result)
            validation = val_result.data
            self._save_artifact(f"validation_{iteration}.json", validation)
            # Store results (latest iteration wins even if validation fails)
            plan.data_loader_code = acquisition.get("data_loader_code", "")
            plan.baseline_code = acquisition.get("baseline_code", "")
            plan.setup_code = acquisition.get("setup_code", "")
            plan.requirements = acquisition.get("requirements", "")
            plan.validation_passed = validation.get("passed", False)
            plan.validation_warnings = validation.get("warnings", [])
            if plan.validation_passed:
                self.logger.info("Validation passed on iteration %d", iteration + 1)
                break
            self.logger.warning(
                "Validation failed (iteration %d): %s",
                iteration + 1,
                validation.get("errors", []),
            )

        # ── Finalize ──────────────────────────────────────────────
        plan.total_llm_calls = self.total_llm_calls
        plan.total_tokens = self.total_tokens
        plan.elapsed_sec = time.monotonic() - t0
        # Save final plan
        self._save_artifact("benchmark_plan.json", plan.to_dict())
        self.logger.info(
            "BenchmarkAgent complete: %d benchmarks, %d baselines, "
            "validation=%s, %d LLM calls, %.1fs",
            len(plan.selected_benchmarks),
            len(plan.selected_baselines),
            "PASS" if
plan.validation_passed else "FAIL",
            plan.total_llm_calls,
            plan.elapsed_sec,
        )
        return plan


================================================
FILE: researchclaw/agents/benchmark_agent/selector.py
================================================
"""Selector Agent — filters and ranks benchmark candidates.

Applies hardware constraints, time budget, network policy, and tier
priorities to select the optimal combination of datasets and baselines.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import yaml

from researchclaw.agents.base import AgentStepResult, BaseAgent

logger = logging.getLogger(__name__)

# Local knowledge base shipped with the package (researchclaw/data/).
_KNOWLEDGE_PATH = Path(__file__).resolve().parent.parent.parent / "data" / "benchmark_knowledge.yaml"

# Maximum dataset size (MB) by tier and network policy
_SIZE_LIMITS: dict[str, int] = {
    "none": 0,           # No download allowed — tier 1 only
    "setup_only": 5000,  # Can download during setup phase
    "pip_only": 0,       # pip only, no data download
    "full": 50000,       # Generous limit
}


class SelectorAgent(BaseAgent):
    """Filters and ranks datasets/baselines based on constraints."""

    name = "selector"

    def __init__(
        self,
        llm: Any,
        *,
        gpu_memory_mb: int = 49000,
        time_budget_sec: int = 300,
        network_policy: str = "setup_only",
        tier_limit: int = 2,
        min_benchmarks: int = 1,
        min_baselines: int = 2,
        prefer_cached: bool = True,
    ) -> None:
        super().__init__(llm)
        self._gpu_mb = gpu_memory_mb
        self._time_sec = time_budget_sec
        self._network_policy = network_policy
        self._tier_limit = tier_limit
        self._min_bench = min_benchmarks
        self._min_base = min_baselines
        self._prefer_cached = prefer_cached

    # -- Filtering ---------------------------------------------------------

    def _filter_benchmarks(
        self,
        benchmarks: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Filter benchmarks by tier, size, and network policy.

        Unknown policies fall back to the 5000 MB setup_only limit.
        """
        max_size = _SIZE_LIMITS.get(self._network_policy, 5000)
        filtered: list[dict[str, Any]] = []
        for b in benchmarks:
            tier = b.get("tier", 3)
            size = b.get("size_mb", 0)
            # Tier filter
            if tier > self._tier_limit:
                continue
            # Network policy filter (tier 2+ requires downloading)
            if tier >= 2 and self._network_policy in ("none", "pip_only"):
                continue
            # Size filter (tier 2+ only — tier 1 is pre-cached)
            if tier >= 2 and size > max_size:
                continue
            filtered.append(b)
        return filtered

    def _filter_baselines(
        self,
        baselines: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Filter baselines by pip availability."""
        filtered: list[dict[str, Any]] = []
        for bl in baselines:
            pip_deps = bl.get("pip", [])
            # If no network, only allow baselines with no extra pip deps
            if self._network_policy == "none" and pip_deps:
                continue
            filtered.append(bl)
        return filtered

    # -- Ranking -----------------------------------------------------------

    def _rank_benchmarks(
        self,
        benchmarks: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Sort benchmarks by preference: tier 1 > tier 2, knowledge_base > hf, downloads."""

        def _score(b: dict[str, Any]) -> tuple[int, int, int]:
            tier = b.get("tier", 3)
            # Prefer lower tier (cached first); with reverse=True below,
            # -tier places tier 1 ahead of tier 2.
            tier_score = -tier if self._prefer_cached else 0
            # Prefer knowledge_base over hf/llm
            origin_score = {
                "knowledge_base": 2,
                "huggingface_hub": 1,
                "llm_suggestion": 0,
            }.get(b.get("origin", ""), 0)
            # Downloads as tiebreaker
            downloads = b.get("downloads", 0)
            return (tier_score, origin_score, downloads)

        return sorted(benchmarks, key=_score, reverse=True)

    def _rank_baselines(
        self,
        baselines: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Sort baselines: knowledge_base first, fewer deps preferred."""

        def _score(bl: dict[str, Any]) -> tuple[int, int]:
            origin_score = 1 if bl.get("origin") == "knowledge_base" else 0
            dep_score = -len(bl.get("pip", []))
            return (origin_score, dep_score)

        return sorted(baselines, key=_score, reverse=True)

    # -- Selection ---------------------------------------------------------

    def _select_with_llm(
        self,
        topic: str,
        benchmarks: list[dict[str, Any]],
        baselines: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Ask LLM to make final selection from filtered candidates.

        Candidate lists are capped (15 benchmarks / 10 baselines) to bound
        prompt size; callers should pass ranked lists so the cap keeps the
        best candidates.
        """
        bench_summary = "\n".join(
            f"- {b.get('name', 'Unknown')} (tier {b.get('tier', '?')}, "
            f"origin: {b.get('origin', '?')}, "
            f"metrics: {b.get('metrics', [])})"
            for b in benchmarks[:15]
        )
        base_summary = "\n".join(
            f"- {bl.get('name', 'Unknown')}: {bl.get('paper', 'N/A')}"
            for bl in baselines[:10]
        )
        system = (
            "You are an ML experiment design expert. Select the BEST combination "
            "of benchmarks and baselines for a research paper.\n\n"
            "Return JSON:\n"
            "{\n"
            '  "primary_benchmark": "name",\n'
            '  "secondary_benchmarks": ["name1", "name2"],\n'
            '  "selected_baselines": ["name1", "name2", "name3"],\n'
            '  "rationale": "why these choices are optimal",\n'
            '  "experiment_notes": "specific setup guidance"\n'
            "}\n\n"
            "RULES:\n"
            "- Select 1 primary benchmark (the main evaluation dataset)\n"
            "- Select 0-2 secondary benchmarks (additional validation)\n"
            "- Select 2-4 baselines (must include at least 1 classic + 1 recent)\n"
            "- Primary benchmark MUST be the domain standard\n"
            "- Prefer benchmarks that top-venue papers commonly use\n"
            "- Consider dataset size vs time budget\n"
            "- CRITICAL: Only select benchmarks that are RELEVANT to the research "
            "topic's domain. Do NOT select image classification datasets (CIFAR, "
            "MNIST) for non-image tasks like PDE solvers, RL, or optimization.\n"
            "- CRITICAL: Baselines must be COMPETING METHODS, not optimizers. "
            "SGD/Adam/AdamW/Cosine LR are NOT baselines — they are training "
            "tools. Baselines must be alternative approaches to the same problem."
        )
        user = (
            f"Research Topic: {topic}\n\n"
            f"Available Benchmarks:\n{bench_summary}\n\n"
            f"Available Baselines:\n{base_summary}\n\n"
            f"Constraints: GPU={self._gpu_mb}MB, "
            f"time_budget={self._time_sec}s, "
            f"network_policy={self._network_policy}\n\n"
            "Make your selection."
        )
        return self._chat_json(system, user, max_tokens=2048)

    def _resolve_selection(
        self,
        selection: dict[str, Any],
        benchmarks: list[dict[str, Any]],
        baselines: list[dict[str, Any]],
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """Resolve LLM-selected names back to full benchmark/baseline dicts.

        Names the LLM returns that do not match a candidate are silently
        dropped. The primary benchmark is tagged role="primary", secondary
        ones role="secondary".
        """
        # Build name lookup
        bench_map = {b.get("name", f"bench_{i}"): b for i, b in enumerate(benchmarks)}
        base_map = {bl.get("name", f"base_{i}"): bl for i, bl in enumerate(baselines)}
        selected_bench: list[dict[str, Any]] = []
        primary = selection.get("primary_benchmark", "")
        if primary and primary in bench_map:
            entry = {**bench_map[primary], "role": "primary"}
            selected_bench.append(entry)
        for name in selection.get("secondary_benchmarks", []):
            if name in bench_map and name != primary:
                entry = {**bench_map[name], "role": "secondary"}
                selected_bench.append(entry)
        selected_base: list[dict[str, Any]] = []
        for name in selection.get("selected_baselines", []):
            if name in base_map:
                selected_base.append(base_map[name])
        return selected_bench, selected_base

    # -- Required baselines injection --------------------------------------

    def _inject_required_baselines(
        self,
        topic: str,
        selected: list[dict[str, Any]],
        ranked: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Load required_baselines from knowledge base and inject missing ones.

        Returns the list of newly injected baseline dicts.
""" try: kb = yaml.safe_load(_KNOWLEDGE_PATH.read_text(encoding="utf-8")) domains = kb.get("domains", {}) if isinstance(kb, dict) else {} except Exception: # noqa: BLE001 return [] topic_lower = topic.lower() injected: list[dict[str, Any]] = [] selected_names = {b.get("name", "").lower() for b in selected} for _domain_id, domain_data in domains.items(): if not isinstance(domain_data, dict): continue keywords = domain_data.get("keywords", []) if not any(kw.lower() in topic_lower for kw in keywords): continue required = domain_data.get("required_baselines", []) if not required: continue # Find each required baseline in ranked list or create stub all_baselines = domain_data.get("common_baselines", []) bl_by_name = {b.get("name", ""): b for b in all_baselines} for req_name in required: if req_name.lower() in selected_names: continue # Try to find full entry from knowledge base if req_name in bl_by_name: entry = {**bl_by_name[req_name], "origin": "required_baseline"} else: entry = {"name": req_name, "origin": "required_baseline", "pip": []} selected.append(entry) selected_names.add(req_name.lower()) injected.append(entry) return injected # -- Main entry point -------------------------------------------------- def execute(self, context: dict[str, Any]) -> AgentStepResult: """Select optimal benchmarks and baselines from survey results. Context keys: topic (str): Research topic survey (dict): Output from SurveyorAgent """ topic = context.get("topic", "") survey = context.get("survey", {}) benchmarks = survey.get("benchmarks", []) baselines = survey.get("baselines", []) if not benchmarks and not baselines: return self._make_result(False, error="No candidates to select from") # 1. Filter by constraints filtered_bench = self._filter_benchmarks(benchmarks) filtered_base = self._filter_baselines(baselines) self.logger.info( "Filtered: %d/%d benchmarks, %d/%d baselines", len(filtered_bench), len(benchmarks), len(filtered_base), len(baselines), ) # 2. 
Rank ranked_bench = self._rank_benchmarks(filtered_bench) ranked_base = self._rank_baselines(filtered_base) # 3. LLM-assisted final selection (if enough candidates) if len(ranked_bench) >= 2 or len(ranked_base) >= 2: selection = self._select_with_llm(topic, ranked_bench, ranked_base) selected_bench, selected_base = self._resolve_selection( selection, ranked_bench, ranked_base, ) else: # Not enough to warrant LLM call — use top ranked # BUG-DA6-06: Create copies to avoid mutating input dicts selected_bench = [{**b, "role": "primary"} if i == 0 else {**b, "role": "secondary"} for i, b in enumerate(ranked_bench[:3])] selected_base = ranked_base[:self._min_base] selection = {} # 4. Fallback: ensure minimums if len(selected_bench) < self._min_bench and ranked_bench: for b in ranked_bench: if b not in selected_bench: selected_bench.append({**b, "role": "secondary"}) if len(selected_bench) >= self._min_bench: break if len(selected_base) < self._min_base and ranked_base: for bl in ranked_base: if bl not in selected_base: selected_base.append(bl) if len(selected_base) >= self._min_base: break # 4b. Improvement E: Inject required baselines from knowledge base _injected_required = self._inject_required_baselines( topic, selected_base, ranked_base, ) if _injected_required: self.logger.info( "Injected %d required baselines: %s", len(_injected_required), [b.get("name") for b in _injected_required], ) # 5. 
Collect required pip packages required_pip: list[str] = [] seen_pip: set[str] = set() for item in selected_bench + selected_base: for pkg in item.get("pip", []): if pkg not in seen_pip: seen_pip.add(pkg) required_pip.append(pkg) result = { "selected_benchmarks": selected_bench, "selected_baselines": selected_base, "required_pip": required_pip, "rationale": selection.get("rationale", ""), "experiment_notes": selection.get("experiment_notes", ""), "total_filtered": len(filtered_bench), } self.logger.info( "Selected: %d benchmarks, %d baselines, %d pip packages", len(selected_bench), len(selected_base), len(required_pip), ) return self._make_result(True, data=result) ================================================ FILE: researchclaw/agents/benchmark_agent/surveyor.py ================================================ """Surveyor Agent — searches for domain-relevant benchmarks and baselines. Data sources (in priority order): 1. Local ``benchmark_knowledge.yaml`` — always available, no network. 2. HuggingFace Hub API (``huggingface_hub``) — dataset discovery by task/keyword. 3. LLM fallback — asks the LLM to suggest benchmarks when APIs unavailable. 
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import yaml

from researchclaw.agents.base import AgentStepResult, BaseAgent

logger = logging.getLogger(__name__)

# Path to the bundled knowledge base: researchclaw/data/benchmark_knowledge.yaml
_KNOWLEDGE_PATH = Path(__file__).resolve().parent.parent.parent / "data" / "benchmark_knowledge.yaml"


# ---------------------------------------------------------------------------
# HuggingFace Hub helpers (optional dependency)
# ---------------------------------------------------------------------------

_HF_AVAILABLE = False
try:
    from huggingface_hub import HfApi  # type: ignore[import-untyped]
    _HF_AVAILABLE = True
except ImportError:
    pass

# Mapping from our domain keywords to HuggingFace task_categories filters
_DOMAIN_TO_HF_TASK: dict[str, list[str]] = {
    "image_classification": ["image-classification"],
    "text_classification": ["text-classification", "sentiment-analysis"],
    "language_modeling": ["text-generation"],
    "question_answering": ["question-answering"],
    "generative_models": ["unconditional-image-generation"],
    "graph_neural_networks": ["graph-ml"],
    "reinforcement_learning": ["reinforcement-learning"],
    "tabular_learning": ["tabular-classification", "tabular-regression"],
    "llm_finetuning": ["text-generation"],
}


class SurveyorAgent(BaseAgent):
    """Searches local knowledge base and HuggingFace Hub for benchmarks."""

    name = "surveyor"

    def __init__(
        self,
        llm: Any,
        *,
        enable_hf_search: bool = True,
        max_hf_results: int = 10,
    ) -> None:
        super().__init__(llm)
        # HF search is active only if requested AND huggingface_hub imported.
        self._enable_hf = enable_hf_search and _HF_AVAILABLE
        self._max_hf = max_hf_results
        self._knowledge = self._load_knowledge()

    # -- Knowledge base ----------------------------------------------------

    @staticmethod
    def _load_knowledge() -> dict[str, Any]:
        """Load the local benchmark knowledge base.

        Returns the ``domains`` mapping, or ``{}`` if the file is missing
        or malformed (failure is logged, never raised).
        """
        try:
            data = yaml.safe_load(_KNOWLEDGE_PATH.read_text(encoding="utf-8"))
            return data.get("domains", {}) if isinstance(data, dict) else {}
        except Exception:  # noqa: BLE001
            logger.warning("Failed to load benchmark_knowledge.yaml", exc_info=True)
            return {}

    def _match_domains(self, topic: str) -> list[str]:
        """Return domain IDs whose keywords appear in the topic."""
        topic_lower = topic.lower()
        matched: list[str] = []
        for domain_id, info in self._knowledge.items():
            keywords = info.get("keywords", [])
            for kw in keywords:
                # Substring match; first hit is enough for the domain.
                if kw in topic_lower:
                    matched.append(domain_id)
                    break
        return matched

    def _get_local_candidates(self, domain_ids: list[str]) -> dict[str, Any]:
        """Retrieve benchmarks and baselines from local knowledge base.

        Entries are deduplicated by name across domains and tagged with
        their source domain and ``origin: knowledge_base``.
        """
        benchmarks: list[dict[str, Any]] = []
        baselines: list[dict[str, Any]] = []
        seen_bench: set[str] = set()
        seen_base: set[str] = set()
        for did in domain_ids:
            info = self._knowledge.get(did, {})
            for b in info.get("standard_benchmarks", []):
                name = b.get("name", "")
                if name not in seen_bench:
                    seen_bench.add(name)
                    benchmarks.append({**b, "source_domain": did, "origin": "knowledge_base"})
            for bl in info.get("common_baselines", []):
                name = bl.get("name", "")
                if name not in seen_base:
                    seen_base.add(name)
                    baselines.append({**bl, "source_domain": did, "origin": "knowledge_base"})
        return {"benchmarks": benchmarks, "baselines": baselines}

    # -- HuggingFace Hub ---------------------------------------------------

    def _search_hf_datasets(self, topic: str, domain_ids: list[str]) -> list[dict[str, Any]]:
        """Search HuggingFace Hub for relevant datasets.

        Uses two strategies (task-category filter, then keyword search);
        all failures are logged and swallowed so the survey continues.
        """
        if not self._enable_hf:
            return []
        results: list[dict[str, Any]] = []
        seen: set[str] = set()
        try:
            api = HfApi()
            # Strategy 1: Search by task category
            for did in domain_ids:
                for task_cat in _DOMAIN_TO_HF_TASK.get(did, []):
                    try:
                        datasets = api.list_datasets(
                            filter=[f"task_categories:{task_cat}"],
                            sort="downloads",
                            direction=-1,
                            limit=self._max_hf,
                        )
                        for ds in datasets:
                            if ds.id not in seen:
                                seen.add(ds.id)
                                results.append({
                                    "name": ds.id,
                                    "downloads": getattr(ds, "downloads", 0),
                                    "origin": "huggingface_hub",
                                    "api": f"datasets.load_dataset('{ds.id}', cache_dir='/workspace/data/hf')",
                                    "tier": 2,
                                })
                    except Exception:  # noqa: BLE001
                        logger.debug("HF task search failed for %s", task_cat)
            # Strategy 2: Keyword search on topic
            keywords = self._extract_search_keywords(topic)
            for kw in keywords[:3]:
                try:
                    datasets = api.list_datasets(
                        search=kw,
                        sort="downloads",
                        direction=-1,
                        limit=self._max_hf,
                    )
                    for ds in datasets:
                        if ds.id not in seen:
                            seen.add(ds.id)
                            results.append({
                                "name": ds.id,
                                "downloads": getattr(ds, "downloads", 0),
                                "origin": "huggingface_hub",
                                "api": f"datasets.load_dataset('{ds.id}', cache_dir='/workspace/data/hf')",
                                "tier": 2,
                            })
                except Exception:  # noqa: BLE001
                    logger.debug("HF keyword search failed for %s", kw)
        except Exception as exc:  # noqa: BLE001
            logger.warning("HuggingFace Hub search failed: %s", exc)
        return results

    @staticmethod
    def _extract_search_keywords(topic: str) -> list[str]:
        """Extract 1-3 word search keywords from a topic string."""
        # Remove common filler words to get meaningful search terms
        stop = {
            "a", "an", "the", "for", "in", "on", "of", "to", "with", "and",
            "or", "is", "are", "using", "via", "based", "towards", "novel",
            "new", "improved", "approach", "method", "methods", "study",
        }
        words = [w.lower().strip(".,;:!?()[]") for w in topic.split()]
        filtered = [w for w in words if w and w not in stop and len(w) > 2]
        # Return 2-3 keyword phrases (bigram, trigram, then single word)
        keywords: list[str] = []
        if len(filtered) >= 2:
            keywords.append(" ".join(filtered[:2]))
        if len(filtered) >= 3:
            keywords.append(" ".join(filtered[:3]))
        if filtered:
            keywords.append(filtered[0])
        return keywords

    # -- LLM fallback ------------------------------------------------------

    def _llm_suggest_benchmarks(self, topic: str, hypothesis: str) -> dict[str, Any]:
        """Ask LLM to suggest benchmarks and baselines when APIs unavailable."""
        system = (
            "You are an expert ML researcher.
Given a research topic and hypothesis, "
            "suggest appropriate benchmarks, datasets, and baseline methods.\n\n"
            "Return a JSON object with:\n"
            "- benchmarks: array of {name, domain, metrics: [], api (Python one-liner), "
            " tier (1=pre-cached, 2=downloadable), size_mb}\n"
            "- baselines: array of {name, source (Python code), paper (citation), pip: []}\n"
            "- rationale: string explaining why these are the right choices\n\n"
            "CRITICAL RULES:\n"
            "- Benchmarks and baselines MUST be DOMAIN-APPROPRIATE for the topic.\n"
            "- Do NOT suggest image classification datasets (CIFAR, ImageNet, MNIST) "
            "for non-image topics like PDE solvers, RL, combinatorial optimization, etc.\n"
            "- Do NOT suggest optimizers (SGD, Adam, AdamW) as METHOD baselines — "
            "optimizers are training tools, NOT research methods to compare against.\n"
            "- Baselines must be COMPETING METHODS from the same research domain.\n\n"
            "DOMAIN-SPECIFIC GUIDANCE:\n"
            "- Physics/PDE/Scientific computing: Use SYNTHETIC data (Burgers eq, "
            "Darcy flow, Navier-Stokes, heat equation). Baselines: FNO, DeepONet, "
            "PINN, spectral methods.\n"
            "- Combinatorial optimization (TSP, graph coloring, scheduling): Use "
            "SYNTHETIC instances (random TSP, Erdos-Renyi graphs). Baselines: "
            "classical MCTS, LKH, OR-Tools, Concorde, RL-based methods.\n"
            "- Reinforcement learning: Use Gymnasium environments (CartPole, "
            "LunarLander, HalfCheetah). Baselines: PPO, SAC, DQN, TD3.\n"
            "- Graph learning: Use standard graph benchmarks (Cora, CiteSeer, "
            "ogbn-arxiv). Baselines: GCN, GAT, GraphSAGE.\n"
            "- If the domain naturally requires SYNTHETIC data (PDE, optimization, "
            "theoretical analysis), explicitly set tier=1 and api='synthetic' and "
            "describe the data generation procedure in the 'source' field.\n\n"
            "- Prefer well-known, widely-used benchmarks from top venues\n"
            "- Prefer baselines with open-source PyTorch implementations\n"
            "- Include at least 2 datasets and 2 baselines"
        )
        user = (
            f"Research Topic: {topic}\n"
            f"Hypothesis: {hypothesis}\n\n"
            "Suggest appropriate benchmarks, datasets, and baseline methods. "
            "Make sure they are relevant to the specific domain of this research."
        )
        result = self._chat_json(system, user, max_tokens=4096)
        return result

    # -- Main entry point --------------------------------------------------

    def execute(self, context: dict[str, Any]) -> AgentStepResult:
        """Survey available benchmarks and baselines for the given topic.

        Context keys:
            topic (str): Research topic/title
            hypothesis (str): Research hypothesis
            experiment_plan (str): Experiment plan from previous stages
        """
        topic = context.get("topic", "")
        hypothesis = context.get("hypothesis", "")
        if not topic:
            return self._make_result(False, error="No topic provided")

        self.logger.info("Surveying benchmarks for topic: %s", topic[:80])

        # 1. Match domains from knowledge base (topic first, hypothesis appended,
        #    order-preserving dedup via dict.fromkeys)
        domain_ids = self._match_domains(topic)
        if hypothesis:
            domain_ids = list(dict.fromkeys(
                domain_ids + self._match_domains(hypothesis)
            ))
        self.logger.info("Matched domains: %s", domain_ids)

        # 2. Get local candidates
        local = self._get_local_candidates(domain_ids)

        # 3. Search HuggingFace Hub (if available)
        hf_datasets = self._search_hf_datasets(topic, domain_ids)

        # 4. LLM fallback if no local matches
        llm_suggestions: dict[str, Any] = {}
        if not local["benchmarks"] and not hf_datasets:
            self.logger.info("No local/HF matches — falling back to LLM")
            llm_suggestions = self._llm_suggest_benchmarks(topic, hypothesis)

        # 5. Combine results; LLM-sourced entries are tagged with their origin
        all_benchmarks = local["benchmarks"] + hf_datasets
        if llm_suggestions.get("benchmarks"):
            for b in llm_suggestions["benchmarks"]:
                b["origin"] = "llm_suggestion"
                all_benchmarks.append(b)
        all_baselines = local["baselines"]
        if llm_suggestions.get("baselines"):
            for bl in llm_suggestions["baselines"]:
                bl["origin"] = "llm_suggestion"
                all_baselines.append(bl)

        survey_result = {
            "matched_domains": domain_ids,
            "benchmarks": all_benchmarks,
            "baselines": all_baselines,
            "hf_datasets_found": len(hf_datasets),
            "llm_fallback_used": bool(llm_suggestions),
            "rationale": llm_suggestions.get("rationale", ""),
        }
        self.logger.info(
            "Survey complete: %d benchmarks, %d baselines, %d HF datasets",
            len(all_benchmarks), len(all_baselines), len(hf_datasets),
        )
        return self._make_result(True, data=survey_result)


================================================
FILE: researchclaw/agents/benchmark_agent/validator.py
================================================
"""Validator Agent — validates generated code for correctness.

Performs three levels of validation:

1. **Syntax check** — ``ast.parse()`` on generated Python code.
2. **Import check** — verifies that referenced modules are importable or
   listed in requirements.
3. **LLM review** — asks the LLM to review code for common pitfalls
   (wrong API usage, missing transforms, incorrect splits).
"""

from __future__ import annotations

import ast
import logging
import re
from typing import Any

from researchclaw.agents.base import AgentStepResult, BaseAgent

logger = logging.getLogger(__name__)

# Packages available in Docker image (no pip install needed)
_BUILTIN_MODULES = {
    "torch", "torchvision", "torchaudio", "numpy", "scipy", "sklearn",
    "pandas", "matplotlib", "seaborn", "tqdm", "gymnasium", "networkx",
    "timm", "einops", "torchmetrics", "transformers", "datasets",
    "accelerate", "peft", "trl", "bitsandbytes", "tokenizers",
    "safetensors", "h5py", "tensorboard", "PIL", "yaml", "kornia",
    "albumentations", "cv2", "mujoco",
    "os", "sys", "json", "re", "pathlib", "typing", "collections",
    "functools", "itertools", "math", "random", "copy", "dataclasses",
    "abc", "io", "csv", "glob", "shutil", "time", "datetime", "logging",
    "warnings", "argparse", "pickle", "struct", "hashlib",
}


class ValidatorAgent(BaseAgent):
    """Validates generated code artifacts for syntax and API correctness."""

    name = "validator"

    def _check_syntax(self, code: str, label: str) -> list[str]:
        """Check Python syntax via ast.parse. Returns list of errors.

        Empty/whitespace-only code is treated as valid (no errors).
        """
        if not code.strip():
            return []
        try:
            ast.parse(code)
            return []
        except SyntaxError as e:
            return [f"{label}: SyntaxError at line {e.lineno}: {e.msg}"]

    def _check_imports(
        self,
        code: str,
        label: str,
        extra_requirements: list[str],
    ) -> list[str]:
        """Check that imported modules are available or declared.

        Returns warning strings for imports that are neither in the Docker
        image's builtin set nor covered by ``extra_requirements``.
        """
        if not code.strip():
            return []
        warnings: list[str] = []
        # Extract import statements (captures only the top-level module name)
        import_pattern = re.compile(
            r"^\s*(?:import|from)\s+(\w+)",
            re.MULTILINE,
        )
        imports = set(import_pattern.findall(code))
        # Build allowed set
        allowed = set(_BUILTIN_MODULES)
        # Map pip package names to import names
        pip_to_import = {
            "torch-geometric": "torch_geometric",
            "ogb": "ogb",
            "stable-baselines3": "stable_baselines3",
            "xgboost": "xgboost",
            "opencv-python": "cv2",
            "scikit-learn": "sklearn",
            "gymnasium[mujoco]": "gymnasium",
            "huggingface_hub": "huggingface_hub",
        }
        for pkg in extra_requirements:
            # Fallback heuristic: pip name with dashes -> import name with underscores
            import_name = pip_to_import.get(pkg, pkg.replace("-", "_"))
            allowed.add(import_name)
        for mod in imports:
            if mod not in allowed:
                warnings.append(
                    f"{label}: import '{mod}' not in Docker image or requirements"
                )
        return warnings

    def _llm_review(
        self,
        data_code: str,
        baseline_code: str,
        setup_code: str,
        benchmark_names: list[str],
        baseline_names: list[str],
    ) -> dict[str, Any]:
        """Ask LLM to review generated code for common pitfalls."""
        system = (
            "You are a code reviewer specializing in ML experiment code. "
            "Review the following generated code for correctness.\n\n"
            "Check for:\n"
            "1. Correct API usage (torchvision, HuggingFace datasets, PyG, etc.)\n"
            "2. Proper data transforms and normalization\n"
            "3. Correct train/val/test split handling\n"
            "4. Compatible input/output dimensions between data and models\n"
            "5. Missing error handling for optional dependencies\n"
            "6. Hardcoded paths that should use variables\n"
            "7. Missing download=True in setup.py for tier 2 datasets\n\n"
            "Return JSON:\n"
            "{\n"
            ' "passed": true/false,\n'
            ' "issues": ["issue 1", "issue 2"],\n'
            ' "suggestions": ["suggestion 1"],\n'
            ' "severity": "none" | "warning" | "error"\n'
            "}"
        )
        # Only include sections that actually contain code.
        code_sections = []
        if data_code:
            code_sections.append(f"## Data Loading Code\n```python\n{data_code}\n```")
        if baseline_code:
            code_sections.append(f"## Baseline Code\n```python\n{baseline_code}\n```")
        if setup_code:
            code_sections.append(f"## Setup Script\n```python\n{setup_code}\n```")
        user = (
            f"Benchmarks: {', '.join(benchmark_names)}\n"
            f"Baselines: {', '.join(baseline_names)}\n\n"
            + "\n\n".join(code_sections)
        )
        return self._chat_json(system, user, max_tokens=2048)

    # -- Main entry point --------------------------------------------------

    def execute(self, context: dict[str, Any]) -> AgentStepResult:
        """Validate all generated code artifacts.

        Context keys:
            acquisition (dict): Output from AcquirerAgent
        """
        acq = context.get("acquisition", {})
        data_code = acq.get("data_loader_code", "")
        baseline_code = acq.get("baseline_code", "")
        setup_code = acq.get("setup_code", "")
        requirements = acq.get("requirements", "")
        benchmark_names = acq.get("benchmark_names", [])
        baseline_names = acq.get("baseline_names", [])

        # requirements is a newline-separated requirements.txt style string
        extra_pip = [r.strip() for r in requirements.split("\n") if r.strip()]

        all_errors: list[str] = []
        all_warnings: list[str] = []

        # 1. Syntax checks
        for code, label in [
            (data_code, "data_loader"),
            (baseline_code, "baseline"),
            (setup_code, "setup"),
        ]:
            errors = self._check_syntax(code, label)
            all_errors.extend(errors)

        # 2. Import checks (setup script is exempt)
        for code, label in [
            (data_code, "data_loader"),
            (baseline_code, "baseline"),
        ]:
            warnings = self._check_imports(code, label, extra_pip)
            all_warnings.extend(warnings)

        # 3.
        # LLM review (only if no syntax errors)
        llm_review: dict[str, Any] = {}
        if not all_errors:
            llm_review = self._llm_review(
                data_code, baseline_code, setup_code,
                benchmark_names, baseline_names,
            )
            # "error" severity escalates issues to errors; otherwise warnings.
            if llm_review.get("severity") == "error":
                all_errors.extend(llm_review.get("issues", []))
            elif llm_review.get("issues"):
                all_warnings.extend(llm_review.get("issues", []))

        passed = len(all_errors) == 0
        severity = "error" if all_errors else ("warning" if all_warnings else "none")

        result = {
            "passed": passed,
            "errors": all_errors,
            "warnings": all_warnings,
            "severity": severity,
            "llm_review": llm_review,
            "suggestions": llm_review.get("suggestions", []),
        }
        self.logger.info(
            "Validation %s: %d errors, %d warnings",
            "PASSED" if passed else "FAILED",
            len(all_errors), len(all_warnings),
        )
        return self._make_result(passed, data=result)


================================================
FILE: researchclaw/agents/code_searcher/__init__.py
================================================
"""Code Searcher agent — searches GitHub for reference code before generation.

This agent searches GitHub repositories and code to find relevant examples
that inform the blueprint generation process, especially for domains where
the LLM's internal knowledge may be insufficient.
"""

from researchclaw.agents.code_searcher.agent import CodeSearchAgent, CodeSearchResult

__all__ = ["CodeSearchAgent", "CodeSearchResult"]


================================================
FILE: researchclaw/agents/code_searcher/agent.py
================================================
"""Code Search Agent — orchestrates GitHub search, pattern extraction, and caching.

This is the main entry point for code search. It:

1. Checks cache for existing results
2. Generates search queries (LLM or heuristic)
3. Searches GitHub for repos and code
4. Reads key files from top repos
5. Extracts patterns using LLM
6. Caches results for future use
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

from researchclaw.agents.code_searcher.cache import SearchCache
from researchclaw.agents.code_searcher.github_client import (
    CodeSnippet,
    GitHubClient,
    RepoAnalysis,
    RepoInfo,
)
from researchclaw.agents.code_searcher.pattern_extractor import CodePatterns, extract_patterns
from researchclaw.agents.code_searcher.query_gen import generate_search_queries
from researchclaw.domains.detector import DomainProfile

logger = logging.getLogger(__name__)


@dataclass
class CodeSearchResult:
    """Complete result from a code search operation."""

    patterns: CodePatterns = field(default_factory=CodePatterns)
    repos_found: list[RepoInfo] = field(default_factory=list)
    snippets_found: list[CodeSnippet] = field(default_factory=list)
    repo_analyses: list[RepoAnalysis] = field(default_factory=list)
    queries_used: list[str] = field(default_factory=list)
    from_cache: bool = False
    github_requests: int = 0

    def to_prompt_context(self) -> str:
        """Format as context block for injection into code generation prompts."""
        if not self.patterns.has_content:
            return ""
        return self.patterns.to_prompt_context()

    def to_cache_dict(self) -> dict[str, Any]:
        """Serialize for caching.

        Only the top 5 repos and pattern data are persisted; snippets and
        full repo analyses are not cached.
        """
        return {
            "api_patterns": self.patterns.api_patterns,
            "file_structure": self.patterns.file_structure,
            "evaluation_patterns": self.patterns.evaluation_patterns,
            "library_versions": self.patterns.library_versions,
            "repos": [
                {
                    "full_name": r.full_name,
                    "description": r.description,
                    "stars": r.stars,
                    "html_url": r.html_url,
                }
                for r in self.repos_found[:5]
            ],
            "queries": self.queries_used,
        }

    @classmethod
    def from_cache_dict(cls, data: dict[str, Any]) -> CodeSearchResult:
        """Deserialize from cache. The result is marked ``from_cache=True``."""
        patterns = CodePatterns(
            api_patterns=data.get("api_patterns", []),
            file_structure=data.get("file_structure", {}),
            evaluation_patterns=data.get("evaluation_patterns", []),
            library_versions=data.get("library_versions", {}),
        )
        repos = [
            RepoInfo(
                full_name=r.get("full_name", ""),
                description=r.get("description", ""),
                stars=r.get("stars", 0),
                html_url=r.get("html_url", ""),
            )
            for r in data.get("repos", [])
        ]
        return cls(
            patterns=patterns,
            repos_found=repos,
            queries_used=data.get("queries", []),
            from_cache=True,
        )


class CodeSearchAgent:
    """Orchestrates code search for reference material before code generation.

    Usage::

        agent = CodeSearchAgent(llm=llm_client)
        result = agent.search(
            topic="PDE solver comparison",
            domain=domain_profile,
            specific_needs=["finite element method", "convergence test"],
        )
        context = result.to_prompt_context()
    """

    def __init__(
        self,
        llm: Any | None = None,
        github_token: str | None = None,
        cache: SearchCache | None = None,
        max_repos_to_analyze: int = 3,
        max_code_searches: int = 3,
    ) -> None:
        self._llm = llm
        self._github = GitHubClient(token=github_token)
        self._cache = cache or SearchCache()
        self._max_repos = max_repos_to_analyze
        self._max_code_searches = max_code_searches

    def search(
        self,
        topic: str,
        domain: DomainProfile,
        specific_needs: list[str] | None = None,
    ) -> CodeSearchResult:
        """Execute a complete code search for a research topic.

        Flow:
        1. Check cache
        2. Generate search queries
        3. Search GitHub repos + code
        4. Read key files from top repos
        5. Extract patterns
        6. Cache results

        Parameters
        ----------
        topic : str
            Research topic.
        domain : DomainProfile
            Detected domain profile.
        specific_needs : list[str], optional
            Specific library/API needs.

        Returns
        -------
        CodeSearchResult
        """
        logger.info("Code search started for: %.60s (domain=%s)", topic, domain.domain_id)

        # 1. Check cache
        cached = self._cache.get(domain.domain_id, topic)
        if cached:
            logger.info("Using cached code search results")
            return CodeSearchResult.from_cache_dict(cached)

        # 2.
        # Generate search queries
        queries = generate_search_queries(
            topic=topic,
            domain_name=domain.display_name,
            core_libraries=domain.core_libraries,
            specific_needs=specific_needs,
            llm=self._llm,
        )
        # Add domain-specific search terms from profile
        if domain.github_search_terms:
            for term in domain.github_search_terms[:2]:
                if term not in queries:
                    queries.append(term)

        result = CodeSearchResult(queries_used=queries)

        # 3. Search GitHub repos (use first query)
        if queries:
            try:
                repos = self._github.search_repos(queries[0], max_results=10)
                # Filter: recent, well-starred
                repos = [
                    r for r in repos
                    if r.stars >= 10  # minimum quality threshold
                ]
                result.repos_found = repos[:self._max_repos * 2]
            except Exception:
                logger.warning("Repo search failed, continuing", exc_info=True)

        # 4. Search GitHub code (use remaining queries)
        code_snippets: list[str] = []
        for query in queries[1:self._max_code_searches + 1]:
            try:
                snippets = self._github.search_code(query, max_results=5)
                result.snippets_found.extend(snippets)
            except Exception:
                logger.warning("Code search failed for query: %s", query)

        # 5. Read key files from top repos
        for repo in result.repos_found[:self._max_repos]:
            try:
                analysis = self._analyze_repo(repo)
                if analysis:
                    result.repo_analyses.append(analysis)
                    # Collect code snippets
                    for content in analysis.key_files.values():
                        if content:
                            code_snippets.append(content)
            except Exception:
                logger.warning("Failed to analyze repo: %s", repo.full_name)

        # Also fetch content for code search results
        for snippet in result.snippets_found[:5]:
            try:
                content = self._github.get_file_content(
                    snippet.repo_full_name,
                    snippet.file_path,
                )
                if content:
                    snippet.content = content
                    code_snippets.append(content)
            except Exception:
                # Best-effort fetch; a missing file is not an error.
                pass

        # 6. Extract patterns
        if code_snippets:
            result.patterns = extract_patterns(
                code_snippets=code_snippets,
                topic=topic,
                domain_name=domain.display_name,
                llm=self._llm,
            )

        result.github_requests = self._github.request_count

        # 7. Cache results (only when pattern extraction produced something)
        if result.patterns.has_content:
            self._cache.put(domain.domain_id, topic, result.to_cache_dict())

        logger.info(
            "Code search complete: %d repos, %d snippets, %d patterns, %d API calls",
            len(result.repos_found),
            len(result.snippets_found),
            len(result.patterns.api_patterns),
            result.github_requests,
        )
        return result

    def _analyze_repo(self, repo: RepoInfo) -> RepoAnalysis | None:
        """Analyze a repository by reading key files."""
        analysis = RepoAnalysis(repo=repo)
        # Get README
        readme = self._github.get_readme(repo.full_name)
        if readme:
            analysis.readme = readme[:3000]  # truncate
        # Get file tree
        file_tree = self._github.get_repo_tree(
            repo.full_name,
            repo.default_branch,
        )
        analysis.file_tree = file_tree
        # Identify and read key files
        key_patterns = [
            "main.py", "run.py", "train.py", "experiment.py",
            "requirements.txt", "setup.py", "pyproject.toml",
        ]
        for pattern in key_patterns:
            matches = [f for f in file_tree if f.endswith(pattern)]
            for match in matches[:1]:  # first match only
                content = self._github.get_file_content(
                    repo.full_name,
                    match,
                    max_size_kb=50,
                )
                if content:
                    analysis.key_files[match] = content
        # Parse requirements (strip version pins, comments, blank lines)
        req_content = analysis.key_files.get("requirements.txt", "")
        if req_content:
            analysis.requirements = [
                line.strip().split("==")[0].split(">=")[0]
                for line in req_content.splitlines()
                if line.strip() and not line.startswith("#")
            ]
        return analysis


================================================
FILE: researchclaw/agents/code_searcher/cache.py
================================================
"""Disk-based cache for code search results.

Caches search results by domain + topic hash with a configurable TTL
(default 30 days). This avoids redundant GitHub API calls for similar
topics within the same domain.
""" from __future__ import annotations import hashlib import json import logging import time from dataclasses import asdict from pathlib import Path from typing import Any logger = logging.getLogger(__name__) _DEFAULT_CACHE_DIR = Path(__file__).parent.parent.parent / "data" / "code_search_cache" _DEFAULT_TTL_DAYS = 30 class SearchCache: """Disk-based cache for code search results. Cache structure:: code_search_cache/ {domain_id}/ {topic_hash}.json """ def __init__( self, cache_dir: Path | None = None, ttl_days: int = _DEFAULT_TTL_DAYS, ) -> None: self._cache_dir = cache_dir or _DEFAULT_CACHE_DIR self._ttl_sec = ttl_days * 86400 def get(self, domain_id: str, topic: str) -> dict[str, Any] | None: """Get cached result if it exists and is not expired.""" cache_path = self._cache_path(domain_id, topic) if not cache_path.exists(): return None try: data = json.loads(cache_path.read_text(encoding="utf-8")) timestamp = data.get("_cached_at", 0) if time.time() - timestamp > self._ttl_sec: logger.debug("Cache expired for %s/%s", domain_id, topic[:40]) cache_path.unlink(missing_ok=True) return None logger.info("Cache hit for %s/%s", domain_id, topic[:40]) return data except Exception: logger.warning("Failed to read cache", exc_info=True) return None def put(self, domain_id: str, topic: str, data: dict[str, Any]) -> None: """Store a result in the cache.""" cache_path = self._cache_path(domain_id, topic) cache_path.parent.mkdir(parents=True, exist_ok=True) data["_cached_at"] = time.time() data["_domain_id"] = domain_id data["_topic_hash"] = self._topic_hash(topic) try: cache_path.write_text( json.dumps(data, indent=2, default=str), encoding="utf-8", ) logger.debug("Cached result for %s/%s", domain_id, topic[:40]) except Exception: logger.warning("Failed to write cache", exc_info=True) def clear(self, domain_id: str | None = None) -> int: """Clear cache. 
Returns number of entries removed.""" count = 0 if domain_id: domain_dir = self._cache_dir / domain_id if domain_dir.is_dir(): for f in domain_dir.glob("*.json"): f.unlink() count += 1 else: if self._cache_dir.is_dir(): for f in self._cache_dir.rglob("*.json"): f.unlink() count += 1 return count def stats(self) -> dict[str, int]: """Return cache statistics.""" total = 0 expired = 0 by_domain: dict[str, int] = {} if not self._cache_dir.is_dir(): return {"total": 0, "expired": 0} for f in self._cache_dir.rglob("*.json"): total += 1 domain = f.parent.name by_domain[domain] = by_domain.get(domain, 0) + 1 try: data = json.loads(f.read_text(encoding="utf-8")) if time.time() - data.get("_cached_at", 0) > self._ttl_sec: expired += 1 except Exception: pass return {"total": total, "expired": expired, **by_domain} def _cache_path(self, domain_id: str, topic: str) -> Path: return self._cache_dir / domain_id / f"{self._topic_hash(topic)}.json" @staticmethod def _topic_hash(topic: str) -> str: return hashlib.sha256(topic.lower().strip().encode()).hexdigest()[:16] ================================================ FILE: researchclaw/agents/code_searcher/github_client.py ================================================ """GitHub REST API client for code and repository search. 
Handles rate limiting, authentication, and response parsing for: - Repository search (``/search/repositories``) - Code search (``/search/code``) - File content retrieval (``/repos/{owner}/{repo}/contents/{path}``) - README retrieval Rate limits: - Authenticated: 30 req/min for search, 5000 req/hr for core - Code search: 10 req/min - Unauthenticated: 10 req/min for search """ from __future__ import annotations import logging import os import time from dataclasses import dataclass, field from typing import Any from urllib.parse import quote logger = logging.getLogger(__name__) _GITHUB_API = "https://api.github.com" @dataclass class RepoInfo: """Summary of a GitHub repository.""" full_name: str # "owner/repo" description: str = "" stars: int = 0 language: str = "" updated_at: str = "" html_url: str = "" default_branch: str = "main" topics: list[str] = field(default_factory=list) @dataclass class CodeSnippet: """A code snippet found via GitHub code search.""" repo_full_name: str file_path: str file_url: str = "" content: str = "" # populated after fetching score: float = 0.0 @dataclass class RepoAnalysis: """Analysis of a repository's structure and content.""" repo: RepoInfo readme: str = "" requirements: list[str] = field(default_factory=list) key_files: dict[str, str] = field(default_factory=dict) # path -> content file_tree: list[str] = field(default_factory=list) class GitHubClient: """GitHub REST API client with rate limiting and caching. Uses ``GITHUB_TOKEN`` env var for authentication (strongly recommended). Falls back to unauthenticated access (much lower rate limits). 
""" def __init__(self, token: str | None = None) -> None: self._token = token or os.environ.get("GITHUB_TOKEN", "") self._last_search_time: float = 0 self._search_interval: float = 6.0 # 10 req/min → 6s between requests self._request_count: int = 0 def _headers(self) -> dict[str, str]: headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", } if self._token: headers["Authorization"] = f"Bearer {self._token}" return headers def _rate_limit_wait(self) -> None: """Enforce rate limiting between search requests.""" elapsed = time.time() - self._last_search_time if elapsed < self._search_interval: wait = self._search_interval - elapsed logger.debug("Rate limit: waiting %.1fs", wait) time.sleep(wait) self._last_search_time = time.time() def _get(self, url: str, params: dict[str, str] | None = None) -> dict[str, Any] | None: """Make a GET request to the GitHub API.""" import urllib.request import urllib.error import json if params: query_str = "&".join(f"{k}={quote(str(v))}" for k, v in params.items()) url = f"{url}?{query_str}" req = urllib.request.Request(url, headers=self._headers()) self._request_count += 1 try: with urllib.request.urlopen(req, timeout=15) as resp: return json.loads(resp.read().decode("utf-8")) except urllib.error.HTTPError as e: if e.code == 403: logger.warning("GitHub API rate limited (403). Skipping.") return None if e.code == 422: logger.warning("GitHub API validation error (422): %s", url) return None logger.warning("GitHub API error %d: %s", e.code, url) return None except Exception: logger.warning("GitHub API request failed: %s", url, exc_info=True) return None def search_repos( self, query: str, language: str = "Python", sort: str = "stars", max_results: int = 10, ) -> list[RepoInfo]: """Search for repositories matching a query. Parameters ---------- query : str Search query (e.g., "PDE solver finite element"). language : str Filter by programming language. 
sort : str Sort order: "stars", "updated", "best-match". max_results : int Maximum number of results to return. Returns ------- list[RepoInfo] """ self._rate_limit_wait() search_q = f"{query} language:{language}" params = { "q": search_q, "sort": sort, "order": "desc", "per_page": str(min(max_results, 30)), } data = self._get(f"{_GITHUB_API}/search/repositories", params) if data is None: return [] repos: list[RepoInfo] = [] for item in data.get("items", [])[:max_results]: repos.append(RepoInfo( full_name=item.get("full_name", ""), description=item.get("description", "") or "", stars=item.get("stargazers_count", 0), language=item.get("language", "") or "", updated_at=item.get("updated_at", ""), html_url=item.get("html_url", ""), default_branch=item.get("default_branch", "main"), topics=item.get("topics", []), )) logger.info("Found %d repos for query: %.60s", len(repos), query) return repos def search_code( self, query: str, language: str = "Python", max_results: int = 10, ) -> list[CodeSnippet]: """Search for code snippets matching a query. Note: Code search has stricter rate limits (10 req/min). Parameters ---------- query : str Search query (e.g., "from pyscf import gto scf"). language : str Filter by programming language. max_results : int Maximum results. 
Returns ------- list[CodeSnippet] """ self._rate_limit_wait() search_q = f"{query} language:{language}" params = { "q": search_q, "per_page": str(min(max_results, 30)), } data = self._get(f"{_GITHUB_API}/search/code", params) if data is None: return [] snippets: list[CodeSnippet] = [] for item in data.get("items", [])[:max_results]: repo = item.get("repository", {}) snippets.append(CodeSnippet( repo_full_name=repo.get("full_name", ""), file_path=item.get("path", ""), file_url=item.get("html_url", ""), score=item.get("score", 0.0), )) logger.info("Found %d code snippets for query: %.60s", len(snippets), query) return snippets def get_file_content( self, repo_full_name: str, path: str, max_size_kb: int = 100, ) -> str | None: """Get the content of a file from a repository. Parameters ---------- repo_full_name : str Repository in "owner/repo" format. path : str File path within the repository. max_size_kb : int Skip files larger than this. Returns ------- str or None File content, or None if not found/too large. 
""" import base64 url = f"{_GITHUB_API}/repos/{repo_full_name}/contents/{quote(path, safe='/')}" data = self._get(url) if data is None: return None size = data.get("size", 0) if size > max_size_kb * 1024: logger.debug("File too large (%d KB): %s/%s", size // 1024, repo_full_name, path) return None content = data.get("content", "") encoding = data.get("encoding", "") if encoding == "base64": try: return base64.b64decode(content).decode("utf-8", errors="replace") except Exception: return None return content def get_readme(self, repo_full_name: str) -> str | None: """Get the README content of a repository.""" import base64 url = f"{_GITHUB_API}/repos/{repo_full_name}/readme" data = self._get(url) if data is None: return None content = data.get("content", "") encoding = data.get("encoding", "") if encoding == "base64": try: return base64.b64decode(content).decode("utf-8", errors="replace") except Exception: return None return content def get_repo_tree( self, repo_full_name: str, branch: str = "main", ) -> list[str]: """Get the file tree of a repository (flat list of paths).""" url = f"{_GITHUB_API}/repos/{repo_full_name}/git/trees/{branch}" params = {"recursive": "1"} data = self._get(url, params) if data is None: return [] tree = data.get("tree", []) return [item["path"] for item in tree if item.get("type") == "blob"] @property def request_count(self) -> int: return self._request_count @property def has_token(self) -> bool: return bool(self._token) ================================================ FILE: researchclaw/agents/code_searcher/pattern_extractor.py ================================================ """Extract reusable code patterns from GitHub search results. 
Uses LLM to analyze reference code and extract: - API call patterns (how to use a specific library) - File organization patterns (project structure) - Data processing patterns (data loading / preprocessing) - Evaluation patterns (how to compute and report metrics) """ from __future__ import annotations import json import logging import re from dataclasses import dataclass, field from typing import Any logger = logging.getLogger(__name__) @dataclass class CodePatterns: """Extracted patterns from reference code.""" api_patterns: list[str] = field(default_factory=list) file_structure: dict[str, str] = field(default_factory=dict) data_patterns: list[str] = field(default_factory=list) evaluation_patterns: list[str] = field(default_factory=list) library_versions: dict[str, str] = field(default_factory=dict) raw_snippets: list[str] = field(default_factory=list) def to_prompt_context(self) -> str: """Format patterns as context for code generation prompts.""" parts: list[str] = [] if self.api_patterns: parts.append("## Reference API Usage Patterns") for i, pattern in enumerate(self.api_patterns[:5], 1): parts.append(f"### Pattern {i}") parts.append(f"```python\n{pattern}\n```") if self.file_structure: parts.append("\n## Reference Project Structure") for fname, desc in self.file_structure.items(): parts.append(f"- `{fname}`: {desc}") if self.evaluation_patterns: parts.append("\n## Reference Evaluation Patterns") for pattern in self.evaluation_patterns[:3]: parts.append(f"```python\n{pattern}\n```") return "\n".join(parts) @property def has_content(self) -> bool: return bool(self.api_patterns or self.file_structure or self.evaluation_patterns) _EXTRACT_PROMPT = """\ You are analyzing reference code to extract reusable patterns for a research project. 
Research topic: {topic} Domain: {domain_name} Here are code snippets from relevant GitHub repositories: {code_snippets} Extract the following patterns as JSON: {{ "api_patterns": [ "# Short, self-contained code snippet showing key API usage", "# Each should be 3-10 lines showing one specific API call pattern" ], "file_structure": {{ "filename.py": "what this file does" }}, "evaluation_patterns": [ "# How results are computed and reported" ], "library_versions": {{ "library_name": "recommended version" }} }} Focus on: 1. How the core libraries are imported and used 2. Common data loading / preprocessing patterns 3. How experiments are structured 4. How results are computed and reported Return ONLY valid JSON.""" def extract_patterns( code_snippets: list[str], topic: str, domain_name: str, llm: Any | None = None, ) -> CodePatterns: """Extract code patterns from reference snippets. Parameters ---------- code_snippets : list[str] Code content from GitHub repos. topic : str Research topic for context. domain_name : str Domain name for context. llm : LLMClient, optional LLM for pattern extraction. Falls back to heuristic if not provided. 
Returns ------- CodePatterns """ if not code_snippets: return CodePatterns() if llm is not None: return _llm_extract(code_snippets, topic, domain_name, llm) return _heuristic_extract(code_snippets) def _llm_extract( snippets: list[str], topic: str, domain_name: str, llm: Any, ) -> CodePatterns: """Extract patterns using LLM analysis.""" try: # Truncate snippets to fit context combined = "" for i, snippet in enumerate(snippets[:5]): truncated = snippet[:2000] if len(snippet) > 2000 else snippet combined += f"\n--- Snippet {i+1} ---\n{truncated}\n" prompt = _EXTRACT_PROMPT.format( topic=topic, domain_name=domain_name, code_snippets=combined, ) if hasattr(llm, "chat"): import asyncio try: loop = asyncio.get_running_loop() except RuntimeError: loop = None if loop and loop.is_running(): return _heuristic_extract(snippets) resp = llm.chat( [{"role": "user", "content": prompt}], system="You extract code patterns as JSON.", max_tokens=1500, ) else: return _heuristic_extract(snippets) content = resp.content if hasattr(resp, "content") else str(resp) # Parse JSON from response json_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", content, re.DOTALL) if json_match: data = json.loads(json_match.group()) return CodePatterns( api_patterns=data.get("api_patterns", []), file_structure=data.get("file_structure", {}), evaluation_patterns=data.get("evaluation_patterns", []), library_versions=data.get("library_versions", {}), raw_snippets=snippets[:5], ) except Exception: logger.warning("LLM pattern extraction failed", exc_info=True) return _heuristic_extract(snippets) def _heuristic_extract(snippets: list[str]) -> CodePatterns: """Extract patterns using regex heuristics (no LLM needed).""" patterns = CodePatterns(raw_snippets=snippets[:5]) for snippet in snippets: # Extract import statements as API patterns imports = re.findall(r"^(?:from|import)\s+.+$", snippet, re.MULTILINE) for imp in imports[:10]: if imp not in patterns.api_patterns: patterns.api_patterns.append(imp) # 
Extract function/class definitions for structure hints defs = re.findall(r"^(?:def|class)\s+(\w+)", snippet, re.MULTILINE) for d in defs[:5]: if d not in patterns.file_structure: patterns.file_structure[d] = "detected function/class" # Deduplicate patterns.api_patterns = list(dict.fromkeys(patterns.api_patterns))[:10] return patterns ================================================ FILE: researchclaw/agents/code_searcher/query_gen.py ================================================ """LLM-based search query generation for code search. Given a research topic and domain, generates targeted search queries for GitHub repository and code search. """ from __future__ import annotations import json import logging import re from typing import Any logger = logging.getLogger(__name__) _QUERY_GEN_PROMPT = """\ You are generating GitHub search queries to find reference code for a research experiment. Research topic: {topic} Domain: {domain_name} Core libraries: {libraries} Specific needs: {needs} Generate 3-5 search queries that will help find: 1. Example implementations using the domain's core libraries 2. Similar research projects or experiments 3. Specific API usage patterns needed for this experiment Rules: - Each query should be 3-8 words (GitHub search works best with short queries) - Include library names when searching for API usage - Include domain-specific terms - Focus on FINDING CODE, not documentation Respond as a JSON array of strings. Example: ["pyscf DFT hartree fock example", "molecular energy calculation python"] Queries:""" def generate_search_queries( topic: str, domain_name: str, core_libraries: list[str], specific_needs: list[str] | None = None, llm: Any | None = None, ) -> list[str]: """Generate search queries for GitHub code search. If no LLM is provided, generates queries from topic keywords and library names using heuristic rules. Parameters ---------- topic : str Research topic. domain_name : str Domain display name. 
core_libraries : list[str] Domain's core libraries. specific_needs : list[str], optional Specific API/library needs. llm : LLMClient, optional LLM for query generation. Returns ------- list[str] 3-5 search queries. """ if llm is not None: return _llm_generate(topic, domain_name, core_libraries, specific_needs or [], llm) return _heuristic_generate(topic, domain_name, core_libraries, specific_needs or []) def _heuristic_generate( topic: str, domain_name: str, libraries: list[str], needs: list[str], ) -> list[str]: """Generate queries without LLM using keyword extraction.""" queries: list[str] = [] # Clean topic: extract key phrases topic_words = _extract_key_phrases(topic) # Query 1: Topic + main library if libraries: queries.append(f"{topic_words} {libraries[0]}") # Query 2: Domain + "python example" queries.append(f"{domain_name.lower()} python example") # Query 3: Specific library usage for lib in libraries[:2]: queries.append(f"{lib} example tutorial python") # Query 4: Specific needs for need in needs[:2]: queries.append(f"{need} python") # Deduplicate and limit seen: set[str] = set() unique: list[str] = [] for q in queries: q_norm = q.lower().strip() if q_norm not in seen: seen.add(q_norm) unique.append(q) return unique[:5] def _llm_generate( topic: str, domain_name: str, libraries: list[str], needs: list[str], llm: Any, ) -> list[str]: """Generate queries using LLM.""" try: prompt = _QUERY_GEN_PROMPT.format( topic=topic, domain_name=domain_name, libraries=", ".join(libraries), needs=", ".join(needs) if needs else "general usage", ) # Synchronous LLM call — LLMClient.chat() is sync and takes # (messages, *, system=, max_tokens=) signature. 
if hasattr(llm, "chat"): resp = llm.chat( [{"role": "user", "content": prompt}], system="You generate concise GitHub search queries.", max_tokens=200, ) else: return _heuristic_generate(topic, domain_name, libraries, needs) content = resp.content if hasattr(resp, "content") else str(resp) # Parse JSON array from response json_match = re.search(r"\[.*\]", content, re.DOTALL) if json_match: queries = json.loads(json_match.group()) if isinstance(queries, list) and all(isinstance(q, str) for q in queries): return queries[:5] logger.warning("Failed to parse LLM query response, using heuristic") return _heuristic_generate(topic, domain_name, libraries, needs) except Exception: logger.warning("LLM query generation failed", exc_info=True) return _heuristic_generate(topic, domain_name, libraries, needs) def _extract_key_phrases(text: str, max_words: int = 5) -> str: """Extract key phrases from a research topic.""" # Remove common filler words stop_words = { "a", "an", "the", "of", "for", "in", "on", "with", "and", "or", "to", "by", "is", "are", "using", "based", "via", "through", "novel", "new", "improved", "efficient", "towards", } words = text.lower().split() key_words = [w for w in words if w not in stop_words and len(w) > 2] return " ".join(key_words[:max_words]) ================================================ FILE: researchclaw/agents/figure_agent/__init__.py ================================================ """FigureAgent — multi-agent intelligent chart generation system. Architecture ------------ 1. **Planner** — analyzes experiment results and determines which charts to generate, their types, layouts, and captions. 2. **CodeGen** — generates Python matplotlib plotting scripts using academic styling (SciencePlots, 300 DPI, colorblind-safe palettes). 3. **Renderer** — executes plotting scripts and verifies output files. 4. **Critic** — tri-modal review: numerical accuracy, text correctness, and visual quality assessment. 5. 
**Integrator** — determines figure placement in the paper and generates markdown references with captions. The ``FigureOrchestrator`` coordinates all agents and produces a ``FigurePlan`` consumed by downstream pipeline stages (paper draft, paper export). """ from researchclaw.agents.figure_agent.orchestrator import ( FigureOrchestrator, FigurePlan, ) __all__ = ["FigureOrchestrator", "FigurePlan"] ================================================ FILE: researchclaw/agents/figure_agent/codegen.py ================================================ """CodeGen Agent — generates visualization code for each figure. Takes the Planner's figure specifications and experiment data, then generates either: - Standalone Python scripts (Matplotlib/Seaborn) — run by Renderer - LaTeX code (TikZ/PGFPlots) — embedded directly in the paper Architecture follows Visual ChatGPT (Wu et al., 2023): the LLM acts as a *controller* calling deterministic render tools instead of generating pixels directly. """ from __future__ import annotations import json import logging import re from pathlib import Path from typing import Any from researchclaw.agents.base import BaseAgent, AgentStepResult from researchclaw.agents.figure_agent.style_config import get_style_preamble from researchclaw.utils.sanitize import sanitize_figure_id from researchclaw.utils.thinking_tags import strip_thinking_tags logger = logging.getLogger(__name__) def _esc(s: str) -> str: """Escape curly braces in user-provided strings for str.format().""" return s.replace("{", "{{").replace("}", "}}") # --------------------------------------------------------------------------- # Degenerate data detection # --------------------------------------------------------------------------- def _is_degenerate_data(values: list[float]) -> bool: """Return True if data values are too degenerate to produce a useful chart. Rejects: empty lists, all-zero, all-identical, or single-value data. 
""" if not values or len(values) < 1: return True if all(v == 0 for v in values): return True if len(values) >= 2 and len(set(round(v, 6) for v in values)) <= 1: return True return False # --------------------------------------------------------------------------- # Metric name humanization # --------------------------------------------------------------------------- _METRIC_DISPLAY_NAMES: dict[str, str] = { "primary_metric": "Performance", "accuracy": "Accuracy (%)", "loss": "Loss", "f1_score": "F1 Score", "precision": "Precision", "recall": "Recall", "reward": "Reward", "return": "Return", "mse": "MSE", "mae": "MAE", "rmse": "RMSE", "bleu": "BLEU", "rouge": "ROUGE", "perplexity": "Perplexity", "auc": "AUC", } def _humanize_label(raw: str) -> str: """Convert raw metric names like 'primary_metric' to human-readable labels.""" if not raw: return "" low = raw.lower().strip() if low in _METRIC_DISPLAY_NAMES: return _METRIC_DISPLAY_NAMES[low] # Convert snake_case to Title Case return raw.replace("_", " ").title() # --------------------------------------------------------------------------- # Built-in chart templates # --------------------------------------------------------------------------- _TEMPLATE_BAR_COMPARISON = ''' {style_preamble} # Data conditions = {conditions} values = {values} ci_low = {ci_low} ci_high = {ci_high} # Plot fig, ax = plt.subplots(figsize=({width}, {height}), constrained_layout=True) x = np.arange(len(conditions)) bar_colors = [COLORS[i % len(COLORS)] for i in range(len(conditions))] yerr_lo = [max(0, v - lo) for v, lo in zip(values, ci_low)] yerr_hi = [max(0, hi - v) for v, hi in zip(values, ci_high)] bars = ax.bar(x, values, color=bar_colors, alpha=0.85, edgecolor="white", linewidth=0.5) ax.errorbar(x, values, yerr=[yerr_lo, yerr_hi], fmt="none", ecolor="#333", capsize=4, capthick=1.2, linewidth=1.2) # Value labels offset = max(yerr_hi) * 0.08 if yerr_hi and max(yerr_hi) > 0 else max(values) * 0.02 for i, v in enumerate(values): ax.text(i, v 
+ offset, f"{{v:.4f}}", ha="center", va="bottom", fontweight="bold") ax.set_xlabel("{x_label}") ax.set_ylabel("{y_label}") ax.set_title("{title}") ax.set_xticks(x) ax.set_xticklabels([c.replace("_", " ") for c in conditions], rotation=25, ha="right") ax.grid(True, axis="y", alpha=0.3) ax.set_axisbelow(True) fig.savefig("{output_path}") plt.close(fig) print(f"Saved: {output_path}") ''' _TEMPLATE_GROUPED_BAR = ''' {style_preamble} # Data: conditions x metrics conditions = {conditions} metric_names = {metric_names} # data_matrix[i][j] = value for condition i, metric j data_matrix = {data_matrix} # Plot n_groups = len(conditions) n_bars = len(metric_names) fig, ax = plt.subplots(figsize=({width}, {height}), constrained_layout=True) x = np.arange(n_groups) bar_width = 0.8 / n_bars for j, metric in enumerate(metric_names): offset = (j - n_bars / 2 + 0.5) * bar_width vals = [data_matrix[i][j] for i in range(n_groups)] ax.bar(x + offset, vals, bar_width, label=metric.replace("_", " "), color=COLORS[j % len(COLORS)], alpha=0.85, edgecolor="white", linewidth=0.5) ax.set_xlabel("{x_label}") ax.set_ylabel("{y_label}") ax.set_title("{title}") ax.set_xticks(x) ax.set_xticklabels([c.replace("_", " ") for c in conditions], rotation=25, ha="right") ax.legend(loc="upper left", bbox_to_anchor=(0, 1), framealpha=0.9, edgecolor="gray") ax.grid(True, axis="y", alpha=0.3) ax.set_axisbelow(True) fig.savefig("{output_path}") plt.close(fig) print(f"Saved: {output_path}") ''' _TEMPLATE_TRAINING_CURVE = ''' {style_preamble} # Data: each series is (label, epochs, values, [optional std]) series_data = {series_data} fig, ax = plt.subplots(figsize=({width}, {height}), constrained_layout=True) for idx, series in enumerate(series_data): label = series["label"] epochs = series["epochs"] values = series["values"] color = COLORS[idx % len(COLORS)] ls = LINE_STYLES[idx % len(LINE_STYLES)] marker = MARKERS[idx % len(MARKERS)] ax.plot(epochs, values, linestyle=ls, color=color, linewidth=1.5, 
marker=marker, markersize=4, markevery=max(1, len(epochs)//10), label=label.replace("_", " ")) if "std" in series and series["std"]: std = series["std"] lower = [v - s for v, s in zip(values, std)] upper = [v + s for v, s in zip(values, std)] ax.fill_between(epochs, lower, upper, alpha=0.15, color=color) ax.set_xlabel("{x_label}") ax.set_ylabel("{y_label}") ax.set_title("{title}") ax.legend(loc="best", framealpha=0.9, edgecolor="gray") ax.grid(True, alpha=0.3) fig.savefig("{output_path}") plt.close(fig) print(f"Saved: {output_path}") ''' _TEMPLATE_HEATMAP = ''' {style_preamble} # Data row_labels = {row_labels} col_labels = {col_labels} data = np.array({data_matrix}) fig, ax = plt.subplots(figsize=({width}, {height}), constrained_layout=True) im = ax.imshow(data, cmap="cividis", aspect="auto") ax.set_xticks(np.arange(len(col_labels))) ax.set_yticks(np.arange(len(row_labels))) ax.set_xticklabels(col_labels, rotation=45, ha="right") ax.set_yticklabels(row_labels) # Annotate cells for i in range(len(row_labels)): for j in range(len(col_labels)): val = data[i, j] color = "white" if val > (data.max() + data.min()) / 2 else "black" ax.text(j, i, f"{{val:.3f}}", ha="center", va="center", color=color) ax.set_xlabel("{x_label}") ax.set_ylabel("{y_label}") ax.set_title("{title}") fig.colorbar(im, ax=ax, shrink=0.8) fig.savefig("{output_path}") plt.close(fig) print(f"Saved: {output_path}") ''' _TEMPLATE_LINE_MULTI = ''' {style_preamble} # Data: list of series dicts with label, x, y, [std] series_data = {series_data} fig, ax = plt.subplots(figsize=({width}, {height}), constrained_layout=True) for idx, series in enumerate(series_data): label = series["label"] x = series["x"] y = series["y"] color = COLORS[idx % len(COLORS)] ls = LINE_STYLES[idx % len(LINE_STYLES)] marker = MARKERS[idx % len(MARKERS)] ax.plot(x, y, linestyle=ls, color=color, linewidth=1.5, marker=marker, markersize=4, markevery=max(1, len(x)//8), label=label.replace("_", " ")) if "std" in series and 
series["std"]: std = series["std"] lower = [v - s for v, s in zip(y, std)] upper = [v + s for v, s in zip(y, std)] ax.fill_between(x, lower, upper, alpha=0.15, color=color) ax.set_xlabel("{x_label}") ax.set_ylabel("{y_label}") ax.set_title("{title}") ax.legend(loc="best", framealpha=0.9, edgecolor="gray") ax.grid(True, alpha=0.3) fig.savefig("{output_path}") plt.close(fig) print(f"Saved: {output_path}") ''' _TEMPLATE_SCATTER = ''' {style_preamble} # Data: list of groups with label, x, y groups = {groups} fig, ax = plt.subplots(figsize=({width}, {height}), constrained_layout=True) for idx, group in enumerate(groups): label = group["label"] x = group["x"] y = group["y"] color = COLORS[idx % len(COLORS)] marker = MARKERS[idx % len(MARKERS)] ax.scatter(x, y, c=color, marker=marker, s=40, alpha=0.7, label=label.replace("_", " ")) ax.set_xlabel("{x_label}") ax.set_ylabel("{y_label}") ax.set_title("{title}") ax.legend(loc="best", framealpha=0.9, edgecolor="gray") ax.grid(True, alpha=0.3) fig.savefig("{output_path}") plt.close(fig) print(f"Saved: {output_path}") ''' _TEMPLATES: dict[str, str] = { "bar_comparison": _TEMPLATE_BAR_COMPARISON, "ablation_grouped": _TEMPLATE_BAR_COMPARISON, # Same template, different data "grouped_bar": _TEMPLATE_GROUPED_BAR, "training_curve": _TEMPLATE_TRAINING_CURVE, "loss_curve": _TEMPLATE_TRAINING_CURVE, "heatmap": _TEMPLATE_HEATMAP, "confusion_matrix": _TEMPLATE_HEATMAP, "line_multi": _TEMPLATE_LINE_MULTI, "scatter_plot": _TEMPLATE_SCATTER, } # --------------------------------------------------------------------------- # LaTeX / PGFPlots templates — for direct LaTeX embedding # --------------------------------------------------------------------------- _LATEX_TEMPLATE_BAR_COMPARISON = r''' \begin{{figure}}[htbp] \centering \begin{{tikzpicture}} \begin{{axis}}[ ybar, bar width=15pt, width={width}cm, height={height}cm, xlabel={{{x_label}}}, ylabel={{{y_label}}}, title={{{title}}}, symbolic x coords={{{x_coords}}}, xtick=data, x tick label 
style={{rotate=25, anchor=east, font=\small}}, ymin=0, nodes near coords, nodes near coords align={{vertical}}, every node near coord/.append style={{font=\tiny}}, grid=major, grid style={{dashed, gray!30}}, ] \addplot[fill=blue!60, draw=blue!80] coordinates {{{coords}}}; \end{{axis}} \end{{tikzpicture}} \caption{{{caption}}} \label{{fig:{figure_id}}} \end{{figure}} ''' _LATEX_TEMPLATE_LINE = r''' \begin{{figure}}[htbp] \centering \begin{{tikzpicture}} \begin{{axis}}[ width={width}cm, height={height}cm, xlabel={{{x_label}}}, ylabel={{{y_label}}}, title={{{title}}}, legend pos=north west, grid=major, grid style={{dashed, gray!30}}, cycle list name=color list, ] {plot_commands} \end{{axis}} \end{{tikzpicture}} \caption{{{caption}}} \label{{fig:{figure_id}}} \end{{figure}} ''' _LATEX_TEMPLATE_HEATMAP = r''' \begin{{figure}}[htbp] \centering \begin{{tikzpicture}} \begin{{axis}}[ colormap/viridis, colorbar, width={width}cm, height={height}cm, xlabel={{{x_label}}}, ylabel={{{y_label}}}, title={{{title}}}, point meta min={meta_min}, point meta max={meta_max}, xtick={{{xtick}}}, ytick={{{ytick}}}, xticklabels={{{xticklabels}}}, yticklabels={{{yticklabels}}}, x tick label style={{rotate=45, anchor=east, font=\small}}, ] \addplot[matrix plot*, mesh/cols={cols}, mesh/rows={rows}, point meta=explicit] coordinates {{ {matrix_coords} }}; \end{{axis}} \end{{tikzpicture}} \caption{{{caption}}} \label{{fig:{figure_id}}} \end{{figure}} ''' _LATEX_TEMPLATES: dict[str, str] = { "bar_comparison": _LATEX_TEMPLATE_BAR_COMPARISON, "ablation_grouped": _LATEX_TEMPLATE_BAR_COMPARISON, "training_curve": _LATEX_TEMPLATE_LINE, "loss_curve": _LATEX_TEMPLATE_LINE, "line_multi": _LATEX_TEMPLATE_LINE, "heatmap": _LATEX_TEMPLATE_HEATMAP, "confusion_matrix": _LATEX_TEMPLATE_HEATMAP, } class CodeGenAgent(BaseAgent): """Generates visualization code (Python or LaTeX) for each planned figure. 
Supports two output formats: - ``"python"`` (default): Matplotlib/Seaborn scripts executed by Renderer - ``"latex"``: TikZ/PGFPlots code embedded directly in the paper """ name = "figure_codegen" def __init__(self, llm: Any, *, output_format: str = "python", use_docker: bool = False) -> None: super().__init__(llm) self._output_format = output_format # "python" or "latex" self._use_docker = use_docker # BUG-60: generate Docker paths when True # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def execute(self, context: dict[str, Any]) -> AgentStepResult: """Generate plotting scripts for all planned figures. Context keys: figures (list[dict]): Figure plan from Planner experiment_results (dict): Raw experiment data condition_summaries (dict): Per-condition aggregated stats metrics_summary (dict): Per-metric aggregated stats metric_key (str): Primary metric name output_dir (str): Directory for output scripts critic_feedback (list[dict], optional): Previous Critic feedback """ try: figures = context.get("figures", []) experiment_results = context.get("experiment_results", {}) condition_summaries = context.get("condition_summaries", {}) metrics_summary = context.get("metrics_summary", {}) metric_key = context.get("metric_key", "primary_metric") output_dir = context.get("output_dir", "charts") critic_feedback = context.get("critic_feedback", []) scripts: list[dict[str, Any]] = [] for fig_spec in figures: # BUG-36: skip non-dict entries (LLM may return strings) if not isinstance(fig_spec, dict): self.logger.warning("Skipping non-dict fig_spec: %s", type(fig_spec)) continue figure_id = fig_spec.get("figure_id", "unknown") chart_type = fig_spec.get("chart_type", "bar_comparison") # Check for critic feedback on this specific figure fig_feedback = None for fb in critic_feedback: # BUG-FIX: guard against non-dict entries in feedback if isinstance(fb, dict) and fb.get("figure_id") 
== figure_id: fig_feedback = fb break script = self._generate_script( fig_spec=fig_spec, chart_type=chart_type, condition_summaries=condition_summaries, metrics_summary=metrics_summary, experiment_results=experiment_results, metric_key=metric_key, output_dir=output_dir, critic_feedback=fig_feedback, ) scripts.append({ "figure_id": figure_id, "chart_type": chart_type, "script": script, "output_filename": f"{figure_id}.png", "title": fig_spec.get("title", ""), "caption": fig_spec.get("caption", ""), "section": fig_spec.get("section", "results"), "width": fig_spec.get("width", "single_column"), }) return self._make_result(True, data={"scripts": scripts}) except Exception as exc: self.logger.error("CodeGen failed: %s", exc) return self._make_result(False, error=str(exc)) # ------------------------------------------------------------------ # Script generation # ------------------------------------------------------------------ def _generate_script( self, *, fig_spec: dict[str, Any], chart_type: str, condition_summaries: dict[str, Any], metrics_summary: dict[str, Any], experiment_results: dict[str, Any], metric_key: str, output_dir: str, critic_feedback: dict[str, Any] | None, ) -> str: """Generate a plotting script for a single figure.""" figure_id = sanitize_figure_id(fig_spec.get("figure_id", "figure")) # BUG-20: Use absolute path to avoid CWD-relative savefig errors # BUG-60: When running in Docker, use container path directly so # renderer doesn't need fragile regex rewriting of host paths. if self._use_docker: output_path = f"/workspace/output/{figure_id}.png" else: output_path = str((Path(output_dir) / f"{figure_id}.png").resolve()) title = fig_spec.get("title", "") x_label = fig_spec.get("x_label", "") y_label = fig_spec.get("y_label", "") width_key = fig_spec.get("width", "single_column") # BUG-FIX: LLM may return data_source as a plain string (e.g. # "condition_comparison") instead of a dict. Normalize to dict. 
_raw_ds = fig_spec.get("data_source", {}) if isinstance(_raw_ds, str): data_source = {"type": _raw_ds} elif isinstance(_raw_ds, dict): data_source = _raw_ds else: data_source = {} from researchclaw.agents.figure_agent.style_config import FIGURE_WIDTH, DEFAULT_FIGURE_HEIGHT width = FIGURE_WIDTH.get(width_key, FIGURE_WIDTH["single_column"]) height = DEFAULT_FIGURE_HEIGHT # Try template-based generation first template = _TEMPLATES.get(chart_type) if template and not critic_feedback: script = self._fill_template( template=template, chart_type=chart_type, data_source=data_source, condition_summaries=condition_summaries, metrics_summary=metrics_summary, experiment_results=experiment_results, metric_key=metric_key, output_path=output_path, title=title, x_label=x_label, y_label=y_label, width=width, height=height, width_key=width_key, ) if script: return script # Fall back to LLM-generated script return self._llm_generate_script( fig_spec=fig_spec, chart_type=chart_type, condition_summaries=condition_summaries, metrics_summary=metrics_summary, experiment_results=experiment_results, metric_key=metric_key, output_path=output_path, width=width, height=height, critic_feedback=critic_feedback, width_key=width_key, ) def _fill_template( self, *, template: str, chart_type: str, data_source: dict[str, Any], condition_summaries: dict[str, Any], metrics_summary: dict[str, Any], experiment_results: dict[str, Any], metric_key: str, output_path: str, title: str, x_label: str, y_label: str, width: float, height: float, width_key: str = "single_column", ) -> str: """Fill a template with actual data values.""" style_preamble = get_style_preamble(width_key=width_key) source_type = data_source.get("type", "condition_comparison") if chart_type in ("bar_comparison", "ablation_grouped"): return self._fill_bar_template( template=template, condition_summaries=condition_summaries, metric_key=data_source.get("metric", metric_key), output_path=output_path, title=title, x_label=x_label, 
y_label=y_label, width=width, height=height, style_preamble=style_preamble, ) if chart_type == "grouped_bar" and source_type == "multi_metric": # BUG-37: LLM may return nested lists in metrics — flatten to list[str] _raw_metrics = data_source.get("metrics", []) _flat_metrics: list[str] = [] for _mi in (_raw_metrics if isinstance(_raw_metrics, list) else []): if isinstance(_mi, str): _flat_metrics.append(_mi) elif isinstance(_mi, list): _flat_metrics.extend(str(x) for x in _mi) else: _flat_metrics.append(str(_mi)) return self._fill_grouped_bar_template( template=template, condition_summaries=condition_summaries, metrics=_flat_metrics, output_path=output_path, title=title, x_label=x_label, y_label=y_label, width=width, height=height, style_preamble=style_preamble, ) if chart_type in ("heatmap", "confusion_matrix"): return self._fill_heatmap_template( template=template, condition_summaries=condition_summaries, metrics_summary=metrics_summary, output_path=output_path, title=title, x_label=x_label, y_label=y_label, width=width, height=height, style_preamble=style_preamble, ) # For other types, fall through to LLM generation return "" def _fill_bar_template( self, *, template: str, condition_summaries: dict[str, Any], metric_key: str, output_path: str, title: str, x_label: str, y_label: str, width: float, height: float, style_preamble: str, ) -> str: """Fill bar comparison template with condition data.""" conditions: list[str] = [] values: list[float] = [] ci_low: list[float] = [] ci_high: list[float] = [] for cond, cdata in condition_summaries.items(): if not isinstance(cdata, dict): continue metrics = cdata.get("metrics", {}) val = metrics.get(f"{metric_key}_mean") or metrics.get(metric_key) if val is None: continue try: fval = float(val) except (ValueError, TypeError): continue conditions.append(cond) values.append(fval) ci_low.append(float(cdata.get("ci95_low", fval))) ci_high.append(float(cdata.get("ci95_high", fval))) if not conditions: return "" # Skip degenerate 
data (all zeros, all identical) if _is_degenerate_data(values): logger.warning("Skipping degenerate bar chart: all values are identical or zero") return "" # Humanize empty/raw labels if not y_label or y_label.lower().replace("_", "") in ("primarymetric", "metric"): y_label = _humanize_label(metric_key) if not x_label: x_label = "Method" return template.format( style_preamble=style_preamble, conditions=repr(conditions), values=repr(values), ci_low=repr(ci_low), ci_high=repr(ci_high), output_path=output_path, title=_esc(title), x_label=_esc(x_label), y_label=_esc(y_label), width=width, height=height, ) def _fill_grouped_bar_template( self, *, template: str, condition_summaries: dict[str, Any], metrics: list[str], output_path: str, title: str, x_label: str, y_label: str, width: float, height: float, style_preamble: str, ) -> str: """Fill grouped bar template with multi-metric data.""" conditions: list[str] = list(condition_summaries.keys()) if not conditions or not metrics: return "" data_matrix: list[list[float]] = [] for cond in conditions: cdata = condition_summaries.get(cond, {}) cmetrics = cdata.get("metrics", {}) if isinstance(cdata, dict) else {} row = [] for m in metrics: val = cmetrics.get(f"{m}_mean") or cmetrics.get(m, 0) try: row.append(float(val)) except (ValueError, TypeError): row.append(0.0) data_matrix.append(row) return template.format( style_preamble=style_preamble, conditions=repr(conditions), metric_names=repr(metrics), data_matrix=repr(data_matrix), output_path=output_path, title=_esc(title), x_label=_esc(x_label), y_label=_esc(y_label), width=width, height=height, ) def _fill_heatmap_template( self, *, template: str, condition_summaries: dict[str, Any], metrics_summary: dict[str, Any], output_path: str, title: str, x_label: str, y_label: str, width: float, height: float, style_preamble: str, ) -> str: """Fill heatmap template — rows=conditions, cols=metrics.""" conditions = list(condition_summaries.keys()) # Select non-timing metrics 
metric_names = [ m for m in metrics_summary if not any(t in m.lower() for t in ["time", "elapsed", "seed", "runtime"]) ][:8] if not conditions or not metric_names: return "" data_matrix: list[list[float]] = [] for cond in conditions: cdata = condition_summaries.get(cond, {}) cmetrics = cdata.get("metrics", {}) if isinstance(cdata, dict) else {} row = [] for m in metric_names: val = cmetrics.get(f"{m}_mean") or cmetrics.get(m, 0) try: row.append(round(float(val), 4)) except (ValueError, TypeError): row.append(0.0) data_matrix.append(row) # Skip degenerate heatmaps (all values identical) all_vals = [v for row in data_matrix for v in row] if _is_degenerate_data(all_vals): logger.warning("Skipping degenerate heatmap: all values are identical or zero") return "" # Also skip single-row heatmaps (meaningless) if len(conditions) < 2: logger.warning("Skipping heatmap with only %d row(s)", len(conditions)) return "" return template.format( style_preamble=style_preamble, row_labels=repr(conditions), col_labels=repr(metric_names), data_matrix=repr(data_matrix), output_path=output_path, title=_esc(title), x_label=_esc(x_label or "Metric"), y_label=_esc(y_label or "Method"), width=max(width, len(metric_names) * 0.8), height=max(height, len(conditions) * 0.6), ) # ------------------------------------------------------------------ # LLM-based script generation # ------------------------------------------------------------------ def _llm_generate_script( self, *, fig_spec: dict[str, Any], chart_type: str, condition_summaries: dict[str, Any], metrics_summary: dict[str, Any], experiment_results: dict[str, Any], metric_key: str, output_path: str, width: float, height: float, critic_feedback: dict[str, Any] | None, width_key: str = "single_column", ) -> str: """Generate a plotting script using LLM.""" if self._output_format == "latex": return self._llm_generate_latex( fig_spec=fig_spec, chart_type=chart_type, condition_summaries=condition_summaries, metrics_summary=metrics_summary, 
metric_key=metric_key, width=width, height=height, critic_feedback=critic_feedback, ) style_preamble = get_style_preamble(width_key=width_key) system_prompt = ( "You are an expert scientific visualization programmer. " "Generate a standalone Python script that creates a publication-quality " "matplotlib chart.\n\n" "RULES:\n" "- The script must be completely self-contained (no external imports " "beyond matplotlib, numpy, seaborn)\n" "- All data values must be hardcoded in the script (no file I/O)\n" "- Use the provided style preamble at the top of the script\n" "- Output format: PNG at 300 DPI\n" "- Use colorblind-safe colors from the COLORS list\n" "- Include descriptive axis labels and title\n" "- Use constrained_layout=True in plt.subplots() — do NOT call fig.tight_layout()\n" "- Call fig.savefig() and plt.close(fig) at the end\n" "- Print 'Saved:
```
"""
files: dict[str, str] = {}
# Try named blocks first
for match in _CODE_BLOCK_RE.finditer(text):
fname = match.group(1).strip()
code = match.group(2).strip()
if fname and code:
# Normalize filename — strip path prefixes
fname = Path(fname).name
files[fname] = code
# If no named blocks, try unnamed and assume main.py
if not files:
for match in _UNNAMED_BLOCK_RE.finditer(text):
code = match.group(1).strip()
if code and len(code) > 50: # Skip tiny snippets
files["main.py"] = code
break
return files
# ---------------------------------------------------------------------------
# Helper: run experiment in sandbox
# ---------------------------------------------------------------------------
def _run_experiment_in_sandbox(
    exp_dir: Path,
    config: Any,
    work_dir: Path,
    timeout_sec: int = 600,
) -> dict | None:
    """Execute experiment code inside the configured Docker/sandbox backend.

    On success, returns a dict with keys: stdout, stderr, returncode,
    metrics, elapsed_sec, timed_out. Returns None when sandbox creation or
    execution fails for any reason (failure is logged, never raised).
    """
    try:
        from researchclaw.experiment.factory import create_sandbox

        # The sandbox gets its own subdirectory under the work dir.
        sandbox_root = work_dir / "sandbox"
        sandbox_root.mkdir(parents=True, exist_ok=True)

        sandbox = create_sandbox(config.experiment, sandbox_root)
        run = sandbox.run_project(exp_dir, timeout_sec=timeout_sec)
        return {
            "stdout": run.stdout,
            "stderr": run.stderr,
            "returncode": run.returncode,
            # Copy metrics into a plain dict; empty dict when none captured.
            "metrics": dict(run.metrics) if run.metrics else {},
            "elapsed_sec": run.elapsed_sec,
            "timed_out": run.timed_out,
        }
    except Exception as exc:
        logger.warning("Sandbox execution failed: %s", exc)
        return None
def _build_experiment_summary_from_run(
run_result: dict,
code: dict[str, str],
) -> dict:
"""Build an experiment_summary.json from a single sandbox run.
Parses condition-level metrics from stdout and builds the standard
summary format expected by ``assess_experiment_quality()``.
"""
metrics = run_result.get("metrics", {})
stdout = run_result.get("stdout", "")
# Also parse metrics from stdout if sandbox didn't capture them
if not metrics and stdout:
try:
from researchclaw.experiment.sandbox import parse_metrics
metrics = parse_metrics(stdout)
except ImportError:
pass
# Group metrics by condition
condition_summaries: dict[str, dict] = {}
for key, value in metrics.items():
if not isinstance(value, (int, float)):
continue
parts = key.split("/")
if len(parts) >= 3:
# Format: condition_name/seed/metric_name
cond_name = parts[0]
metric_name = parts[-1]
if cond_name not in condition_summaries:
condition_summaries[cond_name] = {"metrics": {}, "seeds": {}}
condition_summaries[cond_name]["metrics"][metric_name] = value
seed_key = "/".join(parts[1:-1])
condition_summaries[cond_name]["seeds"].setdefault(seed_key, {})[metric_name] = value
elif len(parts) == 2:
# BUG-199: Stage 13 refinement produces 2-part keys
# (condition_name/metric_name) without a seed component.
# Treat as a single-seed result.
cond_name, metric_name = parts
if cond_name not in condition_summaries:
condition_summaries[cond_name] = {"metrics": {}, "seeds": {}}
condition_summaries[cond_name]["metrics"][metric_name] = value
condition_summaries[cond_name]["seeds"].setdefault("0", {})[metric_name] = value
elif len(parts) == 1:
# Top-level metric like "primary_metric"
pass
# Compute per-condition mean metrics
for cond_name, cdata in condition_summaries.items():
seeds = cdata.get("seeds", {})
if seeds:
cdata["n_seeds"] = len(seeds)
# Average each metric across seeds
all_metrics: dict[str, list[float]] = {}
for seed_data in seeds.values():
for mk, mv in seed_data.items():
if isinstance(mv, (int, float)):
all_metrics.setdefault(mk, []).append(float(mv))
for mk, values in all_metrics.items():
if values:
cdata["metrics"][mk] = sum(values) / len(values)
# Remove seeds from final output (not standard format)
cdata.pop("seeds", None)
return {
"condition_summaries": condition_summaries,
"best_run": {
"metrics": metrics,
"status": "completed" if run_result.get("returncode") == 0 else "failed",
"stdout": stdout[:5000],
"stderr": run_result.get("stderr", "")[:2000],
},
"metrics_summary": {},
"total_conditions": len(condition_summaries),
"total_metric_keys": len(metrics),
}
================================================
FILE: researchclaw/pipeline/opencode_bridge.py
================================================
"""OpenCode 'Beast Mode' bridge — routes complex code generation to OpenCode CLI.
OpenCode (https://github.com/anomalyco/opencode) is an external AI coding agent
invoked via ``opencode run --format json "prompt"``. This module provides:
1. **ComplexityScore / score_complexity()** — analyses an experiment plan to
decide whether beast mode is warranted.
2. **OpenCodeBridge** — manages workspace creation, OpenCode invocation, file
collection, and cleanup.
"""
from __future__ import annotations
import ast
import json
import logging
import os
import re
import shutil
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Complexity scoring
# ---------------------------------------------------------------------------
# Keywords that indicate multi-component architectures.
# Each hit contributes to the "component_count" signal in score_complexity()
# (5 hits saturate the signal at 1.0). Matched as case-insensitive substrings.
_COMPONENT_KEYWORDS: tuple[str, ...] = (
    "encoder",
    "decoder",
    "discriminator",
    "generator",
    "critic",
    "actor",
    "teacher",
    "student",
    "backbone",
    "head",
    "neck",
    "classifier",
    "embedder",
    "attention",
    "transformer",
    "tokenizer",
    "vae",
    "autoencoder",
)
# Indicators that multi-file generation is needed
# ("file_count_hint" signal; 3 hits saturate at 1.0).
_FILE_HINT_KEYWORDS: tuple[str, ...] = (
    "model.py",
    "trainer.py",
    "dataset.py",
    "utils.py",
    "config.py",
    "multiple files",
    "modular",
    "separate module",
    "multi-file",
)
# Domain-complexity keywords
# ("domain_complexity" signal; 3 hits saturate at 1.0).
_DOMAIN_COMPLEX_KEYWORDS: tuple[str, ...] = (
    "multi-modal",
    "multimodal",
    "distributed",
    "gan",
    "diffusion",
    "nerf",
    "mixture of experts",
    "moe",
    "meta-learning",
    "meta learning",
    "maml",
    "neural ode",
    "neural sde",
    "physics-informed",
    "pinn",
    "graph neural",
    "gnn",
    "reinforcement learning",
    "multi-agent",
    "world model",
    "vision-language",
    "text-to-image",
    "image-to-text",
)
# Patterns suggesting deep dependency chains
# ("dependency_depth" signal; 3 hits saturate at 1.0).
_DEPENDENCY_KEYWORDS: tuple[str, ...] = (
    "custom layer",
    "custom loss",
    "wrapper",
    "registry",
    "hook",
    "callback",
    "scheduler",
    "custom optimizer",
    "custom dataset",
    "custom sampler",
    "custom transform",
)
@dataclass
class ComplexityScore:
    """Result of complexity analysis on an experiment plan.

    Produced by ``score_complexity()``; ``recommendation`` tells the caller
    which code-generation path to use for the experiment.
    """

    # Weighted overall complexity; rounded to 4 decimal places.
    score: float  # 0.0-1.0
    # Per-signal sub-scores (each in [0.0, 1.0], rounded to 3 decimals),
    # keyed e.g. "component_count", "dependency_depth".
    signals: dict[str, float] = field(default_factory=dict)
    # "legacy" is only emitted for empty input; otherwise the score-vs-threshold
    # comparison yields "beast_mode" or "code_agent".
    recommendation: str = ""  # "beast_mode" | "code_agent" | "legacy"
    # Human-readable explanation of the recommendation.
    reason: str = ""
def _count_keyword_hits(text: str, keywords: tuple[str, ...]) -> int:
text_lower = text.lower()
return sum(1 for kw in keywords if kw in text_lower)
def score_complexity(
    exp_plan: str,
    topic: str = "",
    *,
    historical_failures: int = 0,
    threshold: float = 0.6,
) -> ComplexityScore:
    """Score the complexity of an experiment to determine if beast mode is warranted.

    Combines six keyword/heuristic signals into a weighted score in
    [0.0, 1.0] and compares it against *threshold* to pick a recommendation.

    Returns a ComplexityScore with score in [0.0, 1.0].
    """
    if not exp_plan and not topic:
        return ComplexityScore(
            score=0.0,
            signals={},
            recommendation="legacy",
            reason="Empty plan",
        )

    combined = f"{topic}\n{exp_plan}"

    def _saturate(hits: int, divisor: float) -> float:
        # Linear ramp that caps at 1.0 once `divisor` hits are reached.
        return min(hits / divisor, 1.0)

    # Keyword-driven signals (weights: 0.25 / 0.20 / 0.20 / 0.10).
    component_score = _saturate(_count_keyword_hits(combined, _COMPONENT_KEYWORDS), 5.0)
    file_score = _saturate(_count_keyword_hits(combined, _FILE_HINT_KEYWORDS), 3.0)
    domain_score = _saturate(_count_keyword_hits(combined, _DOMAIN_COMPLEX_KEYWORDS), 3.0)
    dep_score = _saturate(_count_keyword_hits(combined, _DEPENDENCY_KEYWORDS), 3.0)

    # Condition count (weight 0.15): numbered conditions/ablations/variants
    # plus every "baseline" mention.
    condition_pattern = re.compile(
        r"(?:condition|ablation|variant|experiment)\s*[\-_:]?\s*\d+",
        re.IGNORECASE,
    )
    n_conditions = len(condition_pattern.findall(combined)) + combined.lower().count("baseline")
    condition_score = _saturate(n_conditions, 8.0)

    # Historical failures (weight 0.10): three prior failures saturate.
    failure_score = _saturate(historical_failures, 3.0)

    # Weighted sum; weights total 1.0, then clamp defensively.
    weighted = (
        0.25 * component_score
        + 0.20 * file_score
        + 0.20 * domain_score
        + 0.15 * condition_score
        + 0.10 * failure_score
        + 0.10 * dep_score
    )
    final_score = min(max(weighted, 0.0), 1.0)

    signals = {
        "component_count": round(component_score, 3),
        "file_count_hint": round(file_score, 3),
        "domain_complexity": round(domain_score, 3),
        "condition_count": round(condition_score, 3),
        "historical_failure": round(failure_score, 3),
        "dependency_depth": round(dep_score, 3),
    }

    if final_score >= threshold:
        recommendation = "beast_mode"
        top_three = sorted(signals.items(), key=lambda x: -x[1])[:3]
        reason = (
            f"Complexity {final_score:.2f} >= threshold {threshold:.2f}: "
            f"top signals: "
            + ", ".join(f"{k}={v:.2f}" for k, v in top_three)
        )
    else:
        recommendation = "code_agent"
        reason = f"Complexity {final_score:.2f} < threshold {threshold:.2f}"

    return ComplexityScore(
        score=round(final_score, 4),
        signals=signals,
        recommendation=recommendation,
        reason=reason,
    )
# ---------------------------------------------------------------------------
# OpenCode bridge
# ---------------------------------------------------------------------------
@dataclass
class OpenCodeResult:
"""Result from an OpenCode invocation."""
success: bool
files: dict[str, str] = field(default_factory=dict)
opencode_log: str = ""
elapsed_sec: float = 0.0
error: str = ""
# Prompt passed to ``opencode run`` in beast mode. Format placeholders:
#   {metric}          — primary metric name that main.py must print
#   {time_budget_sec} — total wall-clock budget in seconds (appears twice)
# The referenced EXPERIMENT_PLAN.yaml and GUIDANCE.md files are written into
# the workspace by OpenCodeBridge._prepare_workspace before invocation.
_MEGA_PROMPT_TEMPLATE = """\
You are implementing a complete, runnable ML/science experiment.
Read the files in the current workspace:
- EXPERIMENT_PLAN.yaml — the full experiment design
- GUIDANCE.md — topic, metric, environment constraints, domain-specific guidance
Your task:
1. Design the file structure (main.py is the required entry point).
2. Implement ALL files with complete, runnable code. No placeholders or TODOs.
3. main.py must be the entry point and print the primary metric as:
{metric}:
4. Include numerical stability guards (gradient clipping, NaN detection, etc.).
5. Use multi-seed evaluation (seeds 0, 1, 2) and report mean ± std.
6. Each ablation/condition MUST be genuinely different — not copy-paste with a renamed variable.
7. Implement a time guard: stop gracefully at 80% of the time budget ({time_budget_sec} seconds).
8. Write requirements.txt listing any extra pip packages needed.
9. If the experiment needs dataset downloads, write a setup.py that handles them.
IMPORTANT CONSTRAINTS:
- The code will run in an isolated Docker container with PyTorch, torchvision, and common ML packages pre-installed.
- Do NOT use argparse or CLI arguments — hardcode all configuration.
- All output must go to stdout (print statements).
- Keep the experiment feasible within {time_budget_sec} seconds total.
"""
class OpenCodeBridge:
"""Manages OpenCode CLI invocations for beast mode code generation."""
    def __init__(
        self,
        *,
        model: str = "",
        llm_base_url: str = "",
        api_key_env: str = "",
        llm_provider: str = "openai-compatible",
        timeout_sec: int = 600,
        max_retries: int = 1,
        workspace_cleanup: bool = True,
    ) -> None:
        """Store configuration; no I/O or validation happens here.

        Args:
            model: Model id, either a bare name or "provider/name"
                (see _resolve_opencode_model for how it is qualified).
            llm_base_url: OpenAI-compatible endpoint base URL; "" lets the
                OpenCode CLI use its own defaults (see _build_opencode_config).
            api_key_env: Name of the environment variable holding the API key;
                forwarded as OPENAI_API_KEY by _invoke_opencode.
            llm_provider: Provider label; inspected by _is_azure().
            timeout_sec: Timeout for each ``opencode run`` subprocess call.
            max_retries: Retry budget — consumed outside this view (TODO confirm).
            workspace_cleanup: Whether workspaces are deleted afterwards —
                consumed outside this view (TODO confirm).
        """
        self._model = model
        self._llm_base_url = llm_base_url
        self._api_key_env = api_key_env
        self._llm_provider = llm_provider
        self._timeout_sec = timeout_sec
        self._max_retries = max_retries
        self._workspace_cleanup = workspace_cleanup
# -- availability check ---------------------------------------------------
@staticmethod
def check_available() -> bool:
"""Return True if the ``opencode`` CLI is installed and callable."""
opencode_cmd = shutil.which("opencode")
if not opencode_cmd:
return False
try:
result = subprocess.run(
[opencode_cmd, "--version"],
capture_output=True,
text=True,
timeout=15,
)
return result.returncode == 0
except FileNotFoundError:
return False
except subprocess.TimeoutExpired:
return False
except Exception: # noqa: BLE001
return False
# -- workspace preparation ------------------------------------------------
def _prepare_workspace(
self,
stage_dir: Path,
topic: str,
exp_plan: str,
metric: str,
pkg_hint: str,
extra_guidance: str,
time_budget_sec: int,
) -> Path:
"""Create a temporary workspace directory with context files."""
ws = stage_dir / f"opencode_beast_{int(time.time())}_{time.monotonic_ns() % 100000}"
ws.mkdir(parents=True, exist_ok=True)
# Write experiment plan
(ws / "EXPERIMENT_PLAN.yaml").write_text(
exp_plan or "# No experiment plan provided\n",
encoding="utf-8",
)
# Write guidance document
guidance_parts = [
f"# Experiment Guidance\n",
f"## Topic\n{topic}\n",
f"## Primary Metric\n{metric}\n",
f"## Time Budget\n{time_budget_sec} seconds\n",
]
if pkg_hint:
guidance_parts.append(f"## Environment\n{pkg_hint}\n")
if extra_guidance:
guidance_parts.append(f"## Additional Guidance\n{extra_guidance}\n")
(ws / "GUIDANCE.md").write_text(
"\n".join(guidance_parts), encoding="utf-8",
)
# Write opencode.json config
opencode_cfg = self._build_opencode_config()
(ws / "opencode.json").write_text(
json.dumps(opencode_cfg, indent=2), encoding="utf-8",
)
# OpenCode requires a git repository — initialise one with
# a single commit so that ``opencode run`` doesn't hang.
# BUG-OB-01/OB-02: Check return codes and catch TimeoutExpired.
try:
r = subprocess.run(
["git", "init"],
cwd=str(ws), capture_output=True, timeout=10,
)
if r.returncode != 0:
raise OSError(f"git init failed: {r.stderr}")
subprocess.run(
["git", "add", "-A"],
cwd=str(ws), capture_output=True, timeout=10,
)
subprocess.run(
["git", "-c", "user.email=beast@researchclaw",
"-c", "user.name=BeastMode",
"commit", "-m", "init workspace"],
cwd=str(ws), capture_output=True, timeout=10,
)
except subprocess.TimeoutExpired as exc:
raise OSError(f"git workspace init timed out: {exc}") from exc
return ws
def _is_azure(self) -> bool:
"""Detect Azure OpenAI from base URL or provider string."""
return (
"azure" in (self._llm_base_url or "").lower()
or "azure" in (self._llm_provider or "").lower()
)
def _build_opencode_config(self) -> dict[str, Any]:
"""Build the opencode.json configuration.
Always uses the "openai" provider — this works for both standard
OpenAI endpoints and Azure OpenAI (which accepts Bearer token auth
on the ``/openai/v1`` path and now supports the Responses API).
"""
cfg: dict[str, Any] = {
"$schema": "https://opencode.ai/config.json",
}
if self._llm_base_url:
if self._model:
cfg["model"] = (
self._model if "/" in self._model
else f"openai/{self._model}"
)
cfg["provider"] = {
"openai": {
"options": {
"baseURL": self._llm_base_url,
"apiKey": f"{{env:{self._api_key_env}}}"
if self._api_key_env
else "",
},
"models": {},
}
}
# Register the model so OpenCode knows it exists
if self._model:
model_name = self._model.split("/")[-1]
cfg["provider"]["openai"]["models"] = {
model_name: {
"name": model_name,
"modalities": {
"input": ["text"],
"output": ["text"],
},
}
}
elif self._model:
cfg["model"] = (
self._model if "/" in self._model
else f"openai/{self._model}"
)
return cfg
# -- model resolution -------------------------------------------------------
def _resolve_opencode_model(self) -> str:
"""Resolve the model identifier for OpenCode CLI's ``-m`` flag.
Resolution order:
1. If model already contains "/" (e.g. "anthropic/claude-sonnet-4-6") → use as-is
2. Otherwise → "openai/{model}" (works for both Azure and standard OpenAI)
Note: Azure AI Services now supports the Responses API with Bearer
token auth via the OpenAI-compatible endpoint, so we use the "openai"
provider universally — no Anthropic fallback needed.
"""
if not self._model:
return "anthropic/claude-sonnet-4-6"
if "/" in self._model:
return self._model
return f"openai/{self._model}"
    # -- invocation ------------------------------------------------------------
    def _invoke_opencode(
        self,
        workspace: Path,
        prompt: str,
    ) -> tuple[bool, str, float]:
        """Run ``opencode run`` in the workspace. Returns (success, log, elapsed).

        success is True only when the CLI exits with returncode 0. log is the
        combined stdout + stderr, or an error description on failure. elapsed
        is wall-clock seconds measured with time.monotonic().
        """
        env = os.environ.copy()
        # Pass API key via environment if configured
        if self._api_key_env:
            api_key = os.environ.get(self._api_key_env, "")
            if api_key:
                # We always use the "openai" provider for OpenCode now,
                # which reads OPENAI_API_KEY (works for Azure too via
                # Bearer token auth on the OpenAI-compatible endpoint).
                env["OPENAI_API_KEY"] = api_key
        # Use -m flag to specify model (more reliable than opencode.json)
        resolved_model = self._resolve_opencode_model()
        # Fall back to the bare command name if PATH lookup fails; the
        # FileNotFoundError branch below handles a truly missing binary.
        opencode_cmd = shutil.which("opencode") or "opencode"
        cmd = [opencode_cmd, "run", "-m", resolved_model, "--format", "json", prompt]
        t0 = time.monotonic()
        try:
            result = subprocess.run(
                cmd,
                cwd=str(workspace),
                capture_output=True,
                text=True,
                timeout=self._timeout_sec,
                env=env,
            )
            elapsed = time.monotonic() - t0
            log = result.stdout + "\n" + result.stderr
            return result.returncode == 0, log, elapsed
        except subprocess.TimeoutExpired as exc:
            elapsed = time.monotonic() - t0
            log = f"TIMEOUT after {elapsed:.1f}s"
            # exc.stdout may be str or bytes depending on how the capture was
            # interrupted — decode defensively and cap at 2000 chars.
            if exc.stdout:
                log += f"\nstdout: {exc.stdout[:2000] if isinstance(exc.stdout, str) else exc.stdout.decode(errors='replace')[:2000]}"
            return False, log, elapsed
        except FileNotFoundError:
            # The "opencode" fallback command itself was not found.
            return False, "opencode CLI not found", 0.0
        except Exception as exc:  # noqa: BLE001
            elapsed = time.monotonic() - t0
            return False, f"Unexpected error: {exc}", elapsed
# -- file collection -------------------------------------------------------
@staticmethod
def _collect_files(workspace: Path) -> dict[str, str]:
"""Collect generated Python files, requirements.txt, and setup.py.
File names are flattened to basenames (e.g. ``src/main.py`` → ``main.py``)
because the downstream executor expects a flat file dict. If two files
share the same basename, the one closer to the workspace root wins.
"""
files: dict[str, str] = {}
# Sort by depth (fewer parts first) so root-level files take priority
py_files = sorted(
workspace.rglob("*.py"),
key=lambda p: len(p.relative_to(workspace).parts),
)
for py_file in py_files:
rel = py_file.relative_to(workspace)
parts = rel.parts
if any(p.startswith("__pycache__") or p.startswith(".") for p in parts):
continue
# Flatten to basename — executor expects flat structure
basename = rel.name
if basename not in files:
try:
files[basename] = py_file.read_text(encoding="utf-8", errors="replace")
except OSError as exc:
logger.warning("Beast mode: failed to read %s: %s", py_file, exc)
# Also collect requirements.txt and setup.py at root
for extra in ("requirements.txt", "setup.py"):
p = workspace / extra
if p.exists() and extra not in files:
files[extra] = p.read_text(encoding="utf-8", errors="replace")
return files
# -- entry-point validation ------------------------------------------------
@staticmethod
def _has_main_guard(source: str) -> bool:
"""Return True if *source* contains ``if __name__ == "__main__":``."""
try:
tree = ast.parse(source)
except SyntaxError:
return False
for node in ast.walk(tree):
if isinstance(node, ast.If):
test = node.test
if isinstance(test, ast.Compare) and isinstance(test.left, ast.Name):
if test.left.id == "__name__" and len(test.comparators) == 1:
comp = test.comparators[0]
if isinstance(comp, ast.Constant) and comp.value == "__main__":
return True
return False
@staticmethod
def _ensure_main_entry_point(files: dict[str, str]) -> dict[str, str]:
    """Ensure ``main.py`` has an ``if __name__ == "__main__"`` guard.

    Beast Mode often generates multi-file projects where ``main.py`` is a
    library module and the real entry point lives in another file (e.g.
    ``run_experiment.py``). Since the Docker sandbox always executes
    ``python3 main.py``, a library-only ``main.py`` exits immediately with
    no output.

    Strategy:
    1. If ``main.py`` already has the guard → return unchanged.
    2. Find the first other ``.py`` file that **does** have the guard.
    3. Swap: rename that file to ``main.py`` and store the old ``main.py``
       code under the other file's name.
    4. If no file has a guard, append a minimal stub to ``main.py`` that
       calls the most likely entry function (``main()``, ``run()``, etc.).
       Async entry functions are driven via ``asyncio.run`` — a bare call
       would only create a coroutine and never execute it.

    The input dict is never mutated; a new dict is returned when changed.
    """
    main_code = files.get("main.py", "")
    if not main_code:
        return files
    if OpenCodeBridge._has_main_guard(main_code):
        return files
    # -- Strategy 2/3: find another file with the guard and swap -----------
    for fname, code in files.items():
        if fname == "main.py" or not fname.endswith(".py"):
            continue
        if OpenCodeBridge._has_main_guard(code):
            logger.info(
                "Beast mode: main.py lacks __main__ guard; swapping "
                "entry point with %s",
                fname,
            )
            new_files = dict(files)
            # Old main.py becomes a helper module under the other file's
            # name, so intra-project imports keep resolving.
            new_files[fname] = main_code
            new_files["main.py"] = code
            return new_files
    # -- Strategy 4: inject a minimal entry point into main.py -------------
    # Look for common entry functions defined in main.py (first match in
    # ast.walk order wins, matching the previous behavior).
    entry_func: str | None = None
    entry_is_async = False
    try:
        tree = ast.parse(main_code)
        _known_names = (
            "main", "run", "run_experiment", "train",
            "run_experiments", "experiment", "run_all",
        )
        for n in ast.walk(tree):
            if (
                isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
                and n.name in _known_names
            ):
                entry_func = n.name
                entry_is_async = isinstance(n, ast.AsyncFunctionDef)
                break
    except SyntaxError:
        pass
    if entry_func:
        logger.info(
            "Beast mode: main.py lacks __main__ guard; injecting call "
            "to %s()",
            entry_func,
        )
        if entry_is_async:
            # BUGFIX: calling an async def bare would never await it —
            # drive the coroutine with asyncio.run().
            stub = (
                "\n\n\nif __name__ == \"__main__\":\n"
                + "    import asyncio\n"
                + f"    asyncio.run({entry_func}())\n"
            )
        else:
            stub = (
                "\n\n\nif __name__ == \"__main__\":\n"
                + f"    {entry_func}()\n"
            )
        new_files = dict(files)
        new_files["main.py"] = main_code.rstrip() + stub
        return new_files
    logger.warning(
        "Beast mode: main.py lacks __main__ guard and no known entry "
        "function found — experiment may exit without producing output",
    )
    return files
# -- main entry point ------------------------------------------------------
def generate(
    self,
    stage_dir: Path,
    topic: str,
    exp_plan: str,
    metric: str,
    pkg_hint: str = "",
    extra_guidance: str = "",
    time_budget_sec: int = 300,
) -> OpenCodeResult:
    """Run OpenCode to generate experiment code.

    Retries up to ``1 + self._max_retries`` times.  Each attempt prepares
    a fresh workspace, invokes the OpenCode CLI with the mega-prompt, and
    collects the generated files on success.  The workspace is removed
    after each attempt when ``self._workspace_cleanup`` is set.

    Parameters
    ----------
    stage_dir:
        Stage output directory; ``opencode_log.txt`` is written here.
    topic, exp_plan, metric, pkg_hint, extra_guidance:
        Prompt inputs forwarded to ``_prepare_workspace``.
    time_budget_sec:
        Runtime budget substituted into the mega-prompt.

    Returns
    -------
    OpenCodeResult
        ``success=True`` with the collected files, or ``success=False``
        carrying the last error once all attempts are exhausted.
    """
    # Check availability first — fail fast, no point retrying a missing CLI.
    if not self.check_available():
        return OpenCodeResult(
            success=False,
            error="OpenCode CLI not installed or not callable",
        )
    workspace: Path | None = None
    last_error = ""
    for attempt in range(1 + self._max_retries):
        # Prepare workspace (a fresh one per attempt)
        try:
            workspace = self._prepare_workspace(
                stage_dir=stage_dir,
                topic=topic,
                exp_plan=exp_plan,
                metric=metric,
                pkg_hint=pkg_hint,
                extra_guidance=extra_guidance,
                time_budget_sec=time_budget_sec,
            )
        except OSError as exc:
            last_error = f"Failed to prepare workspace: {exc}"
            logger.warning("Beast mode: %s", last_error)
            continue
        # Build the mega-prompt (use replace instead of .format() to
        # avoid KeyError when metric contains curly braces like "F{1}")
        prompt = _MEGA_PROMPT_TEMPLATE.replace(
            "{metric}", metric
        ).replace(
            "{time_budget_sec}", str(time_budget_sec)
        )
        logger.info(
            "Beast mode: invoking OpenCode (attempt %d/%d, timeout=%ds)",
            attempt + 1,
            1 + self._max_retries,
            self._timeout_sec,
        )
        success, log, elapsed = self._invoke_opencode(workspace, prompt)
        if success:
            files = self._collect_files(workspace)
            # A run without main.py is useless to the sandbox — treat as
            # a failed attempt and retry.
            if "main.py" not in files:
                logger.warning(
                    "Beast mode: OpenCode succeeded but no main.py found "
                    "(files: %s)", list(files.keys()),
                )
                last_error = "No main.py in OpenCode output"
                # Cleanup failed workspace
                if self._workspace_cleanup and workspace.exists():
                    shutil.rmtree(workspace, ignore_errors=True)
                continue
            # BUG-R52-01: Ensure main.py has an entry point
            files = self._ensure_main_entry_point(files)
            # Write log (best-effort; a failed write must not fail the run)
            try:
                (stage_dir / "opencode_log.txt").write_text(
                    log or "", encoding="utf-8",
                )
            except OSError as _wexc:
                logger.warning("Beast mode: failed to write log: %s", _wexc)
            # Cleanup workspace if configured
            if self._workspace_cleanup and workspace.exists():
                shutil.rmtree(workspace, ignore_errors=True)
            return OpenCodeResult(
                success=True,
                files=files,
                opencode_log=log,
                elapsed_sec=elapsed,
            )
        last_error = log
        logger.warning(
            "Beast mode: OpenCode attempt %d failed (%.1fs): %s",
            attempt + 1,
            elapsed,
            log[:500],
        )
        # Cleanup failed workspace (workspace may be None if every
        # _prepare_workspace call failed)
        if self._workspace_cleanup and workspace and workspace.exists():
            shutil.rmtree(workspace, ignore_errors=True)
    # All attempts failed
    return OpenCodeResult(
        success=False,
        opencode_log=last_error,
        error=f"OpenCode failed after {1 + self._max_retries} attempt(s)",
    )
# ---------------------------------------------------------------------------
# Helper: count historical failures
# ---------------------------------------------------------------------------
def count_historical_failures(run_dir: Path, stage_name: str = "stage-10") -> int:
    """Count past Stage 10 failures recorded under *run_dir*.

    A stage directory counts as one failure (at most once) when any of
    these indicators is present:

    - ``beast_mode_log.json`` with a falsy ``success`` flag,
    - ``stage_health.json`` with ``status == "FAILED"``,
    - ``validation_report.md`` containing ``BLOCKED`` or ``FAILED``.

    All reads are best-effort; unreadable or malformed files are ignored.
    """

    def _beast_mode_failed(stage_dir: Path) -> bool:
        # Beast-mode runs record an explicit success flag.
        path = stage_dir / "beast_mode_log.json"
        if not path.exists():
            return False
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
            return not payload.get("success", True)
        except Exception:  # noqa: BLE001
            return False

    def _health_failed(stage_dir: Path) -> bool:
        path = stage_dir / "stage_health.json"
        if not path.exists():
            return False
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
            return payload.get("status") == "FAILED"
        except Exception:  # noqa: BLE001
            return False

    def _validation_failed(stage_dir: Path) -> bool:
        path = stage_dir / "validation_report.md"
        if not path.exists():
            return False
        try:
            report = path.read_text(encoding="utf-8")
            return "BLOCKED" in report or "FAILED" in report
        except Exception:  # noqa: BLE001
            return False

    indicators = (_beast_mode_failed, _health_failed, _validation_failed)
    return sum(
        1
        for stage_dir in run_dir.glob(f"{stage_name}*")
        if any(check(stage_dir) for check in indicators)
    )
================================================
FILE: researchclaw/pipeline/paper_verifier.py
================================================
"""Post-generation paper verification gate.
Extracts all numeric values from a generated LaTeX paper, compares them
against the ``VerifiedRegistry``, and rejects the paper if unverified
numbers appear in strict sections (Results, Experiments, Tables).
This is the **hard, deterministic** defense against fabrication.
"""
from __future__ import annotations
import logging
import math
import re
from dataclasses import dataclass, field
from pathlib import Path
from researchclaw.pipeline.verified_registry import VerifiedRegistry
logger = logging.getLogger(__name__)
# Numbers that are always allowed (years, common constants, etc.)
_ALWAYS_ALLOWED: set[float] = {
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0, 200.0,
0.5, 0.01, 0.001, 0.0001, 0.1, 0.05, 0.95, 0.99,
2024.0, 2025.0, 2026.0, 2027.0,
8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0,
224.0, 299.0, 384.0, # Common image sizes
# BUG-192: Common hyperparameter values
0.0003, 3e-4, 0.0005, 5e-4, 0.002, 2e-3, # learning rates
0.2, 0.3, 0.25, 0.7, 0.6, 0.8, # clip epsilon, dropout, gradient clip, GCE q, common HP
0.9, 0.999, 0.9999, # Adam betas, momentum
0.02, 0.03, # weight init std
1e-5, 1e-6, 1e-8, # epsilon, weight decay
300.0, 400.0, 500.0, # epochs
4096.0, 8192.0, # larger batch sizes / hidden dims
}
# Regex for extracting decimal numbers (including negative, scientific notation)
# NOTE: lookbehind/lookahead must NOT exclude { } — numbers inside \textbf{91.5}
# must still be extracted. We only exclude letters, underscore, and backslash.
_NUMBER_RE = re.compile(
r"(? float:
"""Fraction of numbers that are unverified."""
if self.total_numbers_checked == 0:
return 0.0
return len(self.unverified_numbers) / self.total_numbers_checked
def verify_paper(
    tex_text: str,
    registry: VerifiedRegistry,
    *,
    tolerance: float = 0.01,
    strict_sections: set[str] | None = None,
    lenient_sections: set[str] | None = None,
) -> VerificationResult:
    """Verify that all numbers in the paper are grounded in experiment data.

    Parameters
    ----------
    tex_text:
        The full LaTeX source of the paper.
    registry:
        The verified value registry built from experiment data.
    tolerance:
        Relative tolerance for number matching (default 1%).
    strict_sections:
        Section names where unverified numbers cause REJECT.
    lenient_sections:
        Section names where unverified numbers cause WARNING only.

    Returns
    -------
    VerificationResult
        Contains pass/fail status, list of unverified numbers, and summary.
    """
    if strict_sections is None:
        strict_sections = _STRICT_SECTIONS
    if lenient_sections is None:
        lenient_sections = _LENIENT_SECTIONS
    result = VerificationResult(passed=True, severity="PASS")
    # 1. Parse sections
    sections = _parse_sections(tex_text)
    # 2. Find all tables (for in_table flag)
    table_ranges = _find_table_ranges(tex_text)
    # 3. Create skip mask (positions to ignore)
    skip_mask = _build_skip_mask(tex_text)
    # 4. Extract and verify numbers
    lines = tex_text.split("\n")
    # PERF: precompute each line's absolute character offset once.  The
    # previous code called _line_offset() for every number, re-summing all
    # preceding line lengths each time — O(lines^2) on long papers.
    line_starts: list[int] = []
    _offset = 0
    for _line in lines:
        line_starts.append(_offset)
        _offset += len(_line) + 1  # +1 for the newline removed by split("\n")
    for line_idx, line in enumerate(lines):
        line_num = line_idx + 1
        section = _section_at_line(sections, line_idx)
        section_lower = section.lower() if section else ""
        # Only *results* tables trigger strict handling; hyperparameter /
        # config tables carry is_results=False (BUG-192).
        in_table = any(
            start <= line_idx <= end and is_results
            for start, end, is_results in table_ranges
        )
        for m in _NUMBER_RE.finditer(line):
            num_str = m.group(1)
            char_pos = line_starts[line_idx] + m.start()
            # Skip if inside a skip zone
            if skip_mask[char_pos]:
                continue
            try:
                value = float(num_str)
            except ValueError:
                continue
            if not math.isfinite(value):
                continue
            result.total_numbers_checked += 1
            # Always-allowed numbers
            if value in _ALWAYS_ALLOWED:
                result.total_numbers_verified += 1
                continue
            # Integer-like small numbers (likely indices, counts, etc.)
            # BUG-23 P1: In strict sections or tables, only auto-pass very small
            # integers (≤5) — larger counts (e.g. "20 datasets") could be fabricated.
            is_strict_ctx = _is_strict_section(section_lower, strict_sections) or in_table
            _int_limit = 5 if is_strict_ctx else 20
            if value == int(value) and abs(value) <= _int_limit:
                result.total_numbers_verified += 1
                continue
            # Check against registry
            if registry.is_verified(value, tolerance=tolerance):
                result.total_numbers_verified += 1
                continue
            # UNVERIFIED — classify severity by section
            ctx = line.strip()[:120]
            unv = UnverifiedNumber(
                value=value,
                line_number=line_num,
                context=ctx,
                section=section or "(preamble)",
                in_table=in_table,
            )
            result.unverified_numbers.append(unv)
            is_strict = _is_strict_section(section_lower, strict_sections)
            if is_strict or in_table:
                result.strict_violations += 1
            else:
                result.lenient_violations += 1
    # 5. Check for fabricated conditions
    result.fabricated_conditions = _check_condition_names(tex_text, registry, lines)
    # 5b. BUG-23 P2: Check training config claims (epochs, dataset, etc.)
    result.config_warnings = _check_training_config(tex_text, registry)
    # 6. Determine severity
    if result.strict_violations > 0 or len(result.fabricated_conditions) > 0:
        result.passed = False
        result.severity = "REJECT"
    elif result.lenient_violations > 0:
        result.passed = True
        result.severity = "WARN"
    else:
        result.passed = True
        result.severity = "PASS"
    # 7. Build summary
    result.summary = _build_summary(result)
    logger.info("Paper verification: %s", result.summary)
    return result
def verify_paper_file(
    tex_path: Path,
    registry: VerifiedRegistry,
    **kwargs,
) -> VerificationResult:
    """Convenience wrapper: read *tex_path* and run :func:`verify_paper` on it."""
    source = tex_path.read_text(encoding="utf-8")
    return verify_paper(source, registry, **kwargs)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _parse_sections(tex_text: str) -> list[tuple[int, str]]:
    """Locate section headings in the LaTeX source.

    Returns ``(line_index, section_name)`` pairs, ordered by line index.
    """
    return [
        (idx, hit.group(1).strip())
        for idx, raw in enumerate(tex_text.split("\n"))
        for hit in [_SECTION_RE.search(raw)]
        if hit is not None
    ]
def _section_at_line(sections: list[tuple[int, str]], line_idx: int) -> str | None:
"""Return the section name that contains the given line."""
current = None
for sec_line, sec_name in sections:
if sec_line <= line_idx:
current = sec_name
else:
break
return current
# Keywords that exempt a section from strict checking even when its title
# matches a strict-section name (BUG-R49-02, see _is_strict_section):
# e.g. "Datasets and Evaluation Protocol" describes setup, not results.
_STRICT_EXEMPT_KEYWORDS: set[str] = {
    "dataset", "setup", "protocol", "hyperparameter", "implementation",
    "hardware", "infrastructure", "notation", "preliminaries",
}
def _is_strict_section(section_lower: str, strict_set: set[str]) -> bool:
    """Decide whether a (lower-cased) section title gets strict checking.

    BUG-R49-02: titles such as "Datasets and Evaluation Protocol" contain a
    strict keyword ("evaluation") yet describe protocol parameters, not
    results.  Any setup/protocol keyword therefore exempts the section.
    """
    if not section_lower:
        return False
    if not any(name in section_lower for name in strict_set):
        return False
    # Exemption: setup/protocol wording overrides the strict match.
    return not any(kw in section_lower for kw in _STRICT_EXEMPT_KEYWORDS)
def _find_table_ranges(tex_text: str) -> list[tuple[int, int, bool]]:
"""Find line ranges of table environments.
Returns ``(start_line, end_line, is_results_table)`` tuples.
Hyperparameter / configuration tables (detected by ``\\caption`` keywords)
are marked ``is_results_table=False`` so the verifier skips strict checks
on their numeric content (BUG-192).
"""
_HP_CAPTION_KW = {
"hyperparameter", "hyper-parameter", "configuration", "config",
"setting", "training detail", "implementation detail",
}
ranges: list[tuple[int, int, bool]] = []
lines = tex_text.split("\n")
in_table = False
start = 0
for i, line in enumerate(lines):
if r"\begin{table" in line:
in_table = True
start = i
elif r"\end{table" in line and in_table:
# Scan table block for \caption to determine type
table_block = "\n".join(lines[start : i + 1]).lower()
is_hp = any(kw in table_block for kw in _HP_CAPTION_KW)
ranges.append((start, i, not is_hp))
in_table = False
return ranges
def _build_skip_mask(tex_text: str) -> list[bool]:
    """Build a per-character mask; ``True`` marks positions to ignore."""
    mask = [False] * len(tex_text)
    for pattern in _SKIP_PATTERNS:
        for hit in pattern.finditer(tex_text):
            span = hit.end() - hit.start()
            # Matches never extend past the text, so slice assignment is safe.
            mask[hit.start() : hit.end()] = [True] * span
    return mask
def _line_offset(lines: list[str], line_idx: int) -> int:
"""Return the character offset of the start of a line."""
offset = 0
for i in range(line_idx):
offset += len(lines[i]) + 1 # +1 for newline
return offset
def _check_condition_names(
    tex_text: str,
    registry: VerifiedRegistry,
    lines: list[str],
) -> list[FabricatedCondition]:
    """Check if the paper mentions condition names that never ran.

    Two heuristic passes over the LaTeX source:

    1. First cells of table rows (``... & ... \\\\``) — likely condition
       labels.
    2. ``\\textbf{...}`` / ``\\textit{...}`` spans in prose within
       results-like sections (BUG-23 P2).

    A name is flagged when it is not a known condition from *registry*,
    not a generic table term, not a LaTeX command, and not numeric-looking.
    Each distinct name (case-insensitive) is reported at most once.
    """
    fabricated: list[FabricatedCondition] = []
    # Only check if we have known conditions
    if not registry.condition_names:
        return fabricated
    # Build pattern of known condition names (exact match in text)
    # Look for condition-like names that appear in tables or bold text
    # This is heuristic — we look for unknown names that look like conditions
    known_lower = {name.lower() for name in registry.condition_names}
    # Common generic terms that should NOT be flagged as fabricated conditions
    _GENERIC_TERMS = {
        "method", "metric", "condition", "---", "",
        "model", "approach", "variant", "architecture",
        "ours", "average", "mean", "std", "total",
        "baseline", "proposed", "ablation", "default",
        "results", "table", "figure", "section",
    }
    def _is_candidate(name: str) -> bool:
        """Check if a cleaned name looks like a real condition name."""
        return bool(
            name
            and name.lower() not in known_lower
            and name.lower() not in _GENERIC_TERMS
            and not name.startswith("\\")
            and len(name) > 1
            and not name.isdigit()
            # BUG-DA8-15: Reject numeric-looking strings (e.g. "91.5" from \textbf{91.5})
            and not re.match(r'^[\d.eE+\-]+$', name)
        )
    def _clean_latex(s: str) -> str:
        # Strip \textbf{...}/\textit{...} wrappers and unescape underscores.
        s = re.sub(r"\\textbf\{([^}]*)\}", r"\1", s)
        s = re.sub(r"\\textit\{([^}]*)\}", r"\1", s)
        return s.replace("\\_", "_").strip()
    # Tracks lower-cased names already reported (dedup across both passes).
    _seen_names: set[str] = set()
    # 1. Extract potential condition names from TABLE ROWS
    for i, line in enumerate(lines):
        if "&" in line and "\\\\" in line:
            cells = line.split("&")
            if cells:
                # First cell of a table row is the most likely condition label.
                cand_clean = _clean_latex(cells[0].strip().rstrip("\\").strip())
                if _is_candidate(cand_clean) and cand_clean.lower() not in _seen_names:
                    _seen_names.add(cand_clean.lower())
                    fabricated.append(
                        FabricatedCondition(
                            name=cand_clean,
                            line_number=i + 1,
                            context=line.strip()[:120],
                        )
                    )
    # 2. BUG-23 P2: Also check PROSE — bold/italic condition mentions in
    # Results/Experiments sections that don't match known conditions.
    _strict_sections_lower = {
        "results", "experiments", "experimental results",
        "evaluation", "ablation", "comparison",
    }
    sections = _parse_sections(tex_text)
    for i, line in enumerate(lines):
        section = _section_at_line(sections, i)
        # NOTE: exact (full-title) match here, unlike the substring match
        # used by _is_strict_section.
        if not section or section.lower() not in _strict_sections_lower:
            continue
        # Find \textbf{CondName} or \textit{CondName} in prose
        for m in re.finditer(r"\\text(?:bf|it)\{([^}]+)\}", line):
            cand_clean = _clean_latex(m.group(1)).strip()
            # Only flag multi-word or snake_case names that look like conditions
            # (cand_clean[0] is safe: _is_candidate requires len > 1).
            if (
                _is_candidate(cand_clean)
                and ("_" in cand_clean or cand_clean[0].isupper())
                and cand_clean.lower() not in _seen_names
            ):
                _seen_names.add(cand_clean.lower())
                fabricated.append(
                    FabricatedCondition(
                        name=cand_clean,
                        line_number=i + 1,
                        context=line.strip()[:120],
                    )
                )
    return fabricated
def _check_training_config(
    tex_text: str,
    registry: VerifiedRegistry,
) -> list[str]:
    """BUG-23 P2: Check if paper claims about training config match reality.

    Extracts epoch counts and condition/method counts from the paper text
    and compares them against ``registry.training_config`` /
    ``registry.condition_names``.  Returns a list of human-readable
    warning strings (empty when nothing is suspicious).
    """
    warnings: list[str] = []
    # Extract "trained for N epochs" or "N epochs" claims
    epoch_claims = re.findall(
        r"(?:trained?\s+(?:for\s+)?|over\s+|(?:for|with)\s+)(\d+)\s+epoch",
        tex_text,
        re.IGNORECASE,
    )
    if epoch_claims and registry.training_config:
        # NOTE(review): assumes these config values are numeric when present
        # — a string value would raise in the arithmetic below; confirm
        # against how VerifiedRegistry populates training_config.
        actual_steps = registry.training_config.get("TRAINING_STEPS")
        actual_epochs = registry.training_config.get("epochs")
        if actual_epochs is not None:
            for claim in epoch_claims:
                claimed = int(claim)
                # Allow 30% relative slack, but at least ±5 epochs.
                if abs(claimed - actual_epochs) > max(5, actual_epochs * 0.3):
                    warnings.append(
                        f"Paper claims {claimed} epochs but experiment ran {int(actual_epochs)} epochs"
                    )
        elif actual_steps is not None:
            # Can't compare epochs to steps directly, but flag very large claims
            for claim in epoch_claims:
                claimed = int(claim)
                if claimed > 500:
                    warnings.append(
                        f"Paper claims {claimed} epochs — verify against actual training steps ({int(actual_steps)})"
                    )
    # Check condition count claims ("N conditions" / "N methods" / "N baselines")
    count_claims = re.findall(
        r"(\d+)\s+(?:condition|method|baseline|approach|variant)s?\b",
        tex_text,
        re.IGNORECASE,
    )
    if count_claims and registry.condition_names:
        actual_count = len(registry.condition_names)
        for claim in count_claims:
            claimed = int(claim)
            # +1 slack: papers often count "ours" in addition to baselines.
            if claimed > actual_count + 1:
                warnings.append(
                    f"Paper claims {claimed} conditions/methods but only {actual_count} ran"
                )
    if warnings:
        logger.warning("Training config validation: %s", warnings)
    return warnings
def _build_summary(result: VerificationResult) -> str:
"""Build human-readable summary."""
parts = [f"severity={result.severity}"]
parts.append(
f"checked={result.total_numbers_checked}, "
f"verified={result.total_numbers_verified}, "
f"unverified={len(result.unverified_numbers)}"
)
if result.strict_violations:
parts.append(f"strict_violations={result.strict_violations}")
if result.fabricated_conditions:
names = [fc.name for fc in result.fabricated_conditions[:3]]
parts.append(f"fabricated_conditions={names}")
if result.config_warnings:
parts.append(f"config_warnings={len(result.config_warnings)}")
return "; ".join(parts)
================================================
FILE: researchclaw/pipeline/runner.py
================================================
from __future__ import annotations
import json
import importlib
import logging
import os
import shutil
import tempfile
import time as _time
from pathlib import Path
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.evolution import EvolutionStore, extract_lessons
from researchclaw.knowledge.base import write_stage_to_kb
from researchclaw.pipeline.executor import StageResult, execute_stage
from researchclaw.pipeline.stages import (
DECISION_ROLLBACK,
MAX_DECISION_PIVOTS,
NONCRITICAL_STAGES,
STAGE_SEQUENCE,
Stage,
StageStatus,
)
def _utcnow_iso() -> str:
from datetime import datetime, timezone
return datetime.now(timezone.utc).isoformat(timespec="seconds")
def _should_start(stage: Stage, from_stage: Stage, started: bool) -> bool:
if started:
return True
return stage == from_stage
def _build_pipeline_summary(
    *,
    run_id: str,
    results: list[StageResult],
    from_stage: Stage,
    run_dir: Path | None = None,
) -> dict[str, object]:
    """Assemble the ``pipeline_summary.json`` payload from stage results."""
    done_count = sum(1 for r in results if r.status == StageStatus.DONE)
    blocked_count = sum(
        1 for r in results if r.status == StageStatus.BLOCKED_APPROVAL
    )
    failed_count = sum(1 for r in results if r.status == StageStatus.FAILED)
    last = results[-1] if results else None
    return {
        "run_id": run_id,
        "stages_executed": len(results),
        "stages_done": done_count,
        "stages_blocked": blocked_count,
        "stages_failed": failed_count,
        "degraded": any(r.decision == "degraded" for r in results),
        "from_stage": int(from_stage),
        "final_stage": int(last.stage) if last is not None else int(from_stage),
        "final_status": last.status.value if last is not None else "no_stages",
        "generated": _utcnow_iso(),
        "content_metrics": _collect_content_metrics(run_dir),
    }
def _write_pipeline_summary(run_dir: Path, summary: dict[str, object]) -> None:
(run_dir / "pipeline_summary.json").write_text(
json.dumps(summary, indent=2),
encoding="utf-8",
)
def _write_checkpoint(run_dir: Path, stage: Stage, run_id: str) -> None:
    """Write checkpoint atomically (temp file + rename) to prevent corruption."""
    payload = json.dumps(
        {
            "last_completed_stage": int(stage),
            "last_completed_name": stage.name,
            "run_id": run_id,
            "timestamp": _utcnow_iso(),
        },
        indent=2,
    )
    target = run_dir / "checkpoint.json"
    # Write to a temp file in the same directory so the final rename is
    # atomic on the same filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=run_dir, suffix=".tmp", prefix="checkpoint_")
    os.close(fd)
    tmp = Path(tmp_name)
    try:
        with open(tmp, "w", encoding="utf-8") as fh:
            fh.write(payload)
        tmp.replace(target)
    except BaseException:
        # Never leave a stray temp file behind, even on KeyboardInterrupt.
        tmp.unlink(missing_ok=True)
        raise
def _write_heartbeat(run_dir: Path, stage: Stage, run_id: str) -> None:
    """Write ``heartbeat.json`` for sentinel watchdog monitoring.

    Records the worker PID, the most recent stage (number and name), the
    run id, and a UTC timestamp so an external watchdog can detect a
    stalled or dead run.
    """
    # ``os`` is already imported at module level — the previous
    # function-local ``import os`` was a redundant shadow and is removed.
    heartbeat = {
        "pid": os.getpid(),
        "last_stage": int(stage),
        "last_stage_name": stage.name,
        "run_id": run_id,
        "timestamp": _utcnow_iso(),
    }
    (run_dir / "heartbeat.json").write_text(
        json.dumps(heartbeat, indent=2), encoding="utf-8"
    )
def read_checkpoint(run_dir: Path) -> Stage | None:
    """Read checkpoint and return the NEXT stage to execute.

    Returns ``None`` when there is no checkpoint, the checkpoint is
    malformed or unreadable, the recorded stage number is unknown, or the
    last completed stage was already the final stage.  Never raises.
    """
    cp_path = run_dir / "checkpoint.json"
    if not cp_path.exists():
        return None
    try:
        data = json.loads(cp_path.read_text(encoding="utf-8"))
        last_num = data.get("last_completed_stage")
        if last_num is None:
            return None
        for i, stage in enumerate(STAGE_SEQUENCE):
            if int(stage) == last_num:
                # Advance to the successor unless this was the last stage.
                if i + 1 < len(STAGE_SEQUENCE):
                    return STAGE_SEQUENCE[i + 1]
                return None
        return None
    except (json.JSONDecodeError, OSError, AttributeError, TypeError, ValueError):
        # OSError: the file can vanish or become unreadable between the
        # exists() check and read_text() (crash-recovery race).
        # AttributeError: checkpoint JSON that is not an object (e.g. a
        # list) has no .get().
        return None
def resume_from_checkpoint(
    run_dir: Path, default_stage: Stage = Stage.TOPIC_INIT
) -> Stage:
    """Resolve the stage to resume from using checkpoint metadata."""
    checkpoint_stage = read_checkpoint(run_dir)
    if checkpoint_stage is None:
        return default_stage
    return checkpoint_stage
def _collect_content_metrics(run_dir: Path | None) -> dict[str, object]:
    """Collect content authenticity metrics from stage outputs.

    Best-effort, never raises.  Reads:

    - ``stage-17/paper_draft.md`` → ``template_ratio`` via
      ``researchclaw.quality.compute_template_ratio``;
    - ``stage-23/verification_report.json`` → citation counts and
      ``citation_verify_score``.

    Metric values stay ``None`` when their inputs are missing or
    unparsable.
    """
    metrics: dict[str, object] = {
        "template_ratio": None,
        "citation_verify_score": None,
        "total_citations": None,
        "verified_citations": None,
        "degraded_sources": [],
    }
    if run_dir is None:
        return metrics
    draft_path = run_dir / "stage-17" / "paper_draft.md"
    if draft_path.exists():
        try:
            # Imported lazily so a missing quality module only skips this
            # metric instead of breaking the summary.
            quality_module = importlib.import_module("researchclaw.quality")
            compute_template_ratio = quality_module.compute_template_ratio
            text = draft_path.read_text(encoding="utf-8")
            metrics["template_ratio"] = round(compute_template_ratio(text), 4)
        except (
            AttributeError,
            ModuleNotFoundError,
            UnicodeDecodeError,
            OSError,
            ValueError,
            TypeError,
        ):
            pass
    verify_path = run_dir / "stage-23" / "verification_report.json"
    if verify_path.exists():
        try:
            vdata = json.loads(verify_path.read_text(encoding="utf-8"))
            if isinstance(vdata, dict):
                # The report may nest counts under "summary" or be flat.
                summary = vdata.get("summary", vdata)
                total = summary.get("total", 0) if isinstance(summary, dict) else None
                verified = summary.get("verified", 0) if isinstance(summary, dict) else None
                if isinstance(total, int | float) and isinstance(verified, int | float):
                    total_num = int(total)
                    verified_num = int(verified)
                    metrics["total_citations"] = total_num
                    metrics["verified_citations"] = verified_num
                    if total_num > 0:
                        metrics["citation_verify_score"] = round(
                            verified_num / total_num, 4
                        )
        except (json.JSONDecodeError, OSError, TypeError, ValueError):
            pass
    return metrics
logger = logging.getLogger(__name__)
def _run_experiment_diagnosis(run_dir: Path, config: RCConfig, run_id: str) -> None:
    """Run experiment diagnosis after Stage 14 and save reports.

    Produces:
    - ``run_dir/experiment_diagnosis.json`` — structured diagnosis + quality assessment
    - ``run_dir/repair_prompt.txt`` — repair instructions (if quality is insufficient)

    Best-effort: the whole routine is wrapped in a broad try/except so a
    diagnosis failure never aborts the pipeline.
    """
    try:
        from researchclaw.pipeline.experiment_diagnosis import (
            diagnose_experiment,
            assess_experiment_quality,
        )
        # Find the most recent stage-14 experiment_summary.json
        # (sorted() + overwrite keeps the lexicographically last candidate).
        summary_path = None
        for candidate in sorted(run_dir.glob("stage-14*/experiment_summary.json")):
            summary_path = candidate
        if not summary_path or not summary_path.exists():
            return
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
        # Collect stdout/stderr from experiment runs (first 5 run files only)
        stdout, stderr = "", ""
        runs_dir = summary_path.parent / "runs"
        if runs_dir.is_dir():
            for run_file in sorted(runs_dir.glob("*.json"))[:5]:
                try:
                    run_data = json.loads(run_file.read_text(encoding="utf-8"))
                    if isinstance(run_data, dict):
                        stdout += run_data.get("stdout", "") + "\n"
                        stderr += run_data.get("stderr", "") + "\n"
                except (json.JSONDecodeError, OSError):
                    continue
        # Load experiment plan from stage-09 (YAML preferred; JSON fallback).
        # Each loop keeps the last successfully-parsed candidate.
        plan = None
        for candidate in sorted(run_dir.glob("stage-09*/exp_plan.yaml")):
            try:
                import yaml as _yaml_diag
                plan = _yaml_diag.safe_load(candidate.read_text(encoding="utf-8"))
            except Exception:
                pass
        if plan is None:
            for candidate in sorted(run_dir.glob("stage-09*/experiment_design.json")):
                try:
                    plan = json.loads(candidate.read_text(encoding="utf-8"))
                except (json.JSONDecodeError, OSError):
                    pass
        # Load refinement log if available
        ref_log = None
        for candidate in sorted(run_dir.glob("stage-13*/refinement_log.json")):
            try:
                ref_log = json.loads(candidate.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, OSError):
                pass
        # Run diagnosis
        diag = diagnose_experiment(
            experiment_summary=summary,
            experiment_plan=plan,
            refinement_log=ref_log,
            stdout=stdout.strip(),
            stderr=stderr.strip(),
        )
        # Run quality assessment
        qa = assess_experiment_quality(summary, ref_log)
        # Save diagnosis report
        diag_report = {
            "diagnosis": diag.to_dict(),
            "quality_assessment": {
                "mode": qa.mode.value,
                "sufficient": qa.sufficient,
                "repair_possible": qa.repair_possible,
                "deficiency_types": [d.type.value for d in qa.deficiencies],
            },
            "repair_needed": not qa.sufficient,
            "generated": _utcnow_iso(),
        }
        (run_dir / "experiment_diagnosis.json").write_text(
            json.dumps(diag_report, indent=2), encoding="utf-8"
        )
        if not qa.sufficient:
            # Generate repair prompt for the REFINE loop
            from researchclaw.pipeline.experiment_repair import build_repair_prompt
            code: dict[str, str] = {}
            # Try refined code first, then stage-10 experiment dir, then raw
            # stage-10; stop at the first glob pattern that yields any file.
            for _glob_pat in (
                "stage-13*/experiment_final/*.py",
                "stage-10*/experiment/*.py",
                "stage-10*/*.py",
            ):
                for candidate in sorted(run_dir.glob(_glob_pat)):
                    try:
                        code[candidate.name] = candidate.read_text(encoding="utf-8")
                    except (OSError, UnicodeDecodeError):
                        pass
                if code:
                    break
            repair_prompt = build_repair_prompt(
                diag, code, time_budget_sec=config.experiment.time_budget_sec
            )
            (run_dir / "repair_prompt.txt").write_text(
                repair_prompt, encoding="utf-8"
            )
            logger.info(
                "[%s] Experiment diagnosis: mode=%s, deficiencies=%d — repair prompt saved",
                run_id, qa.mode.value, len(diag.deficiencies),
            )
            print(
                f"[{run_id}] Experiment diagnosis: {qa.mode.value} "
                f"({len(diag.deficiencies)} issues found, repair needed)"
            )
        else:
            logger.info(
                "[%s] Experiment diagnosis: mode=%s, sufficient=True — quality OK",
                run_id, qa.mode.value,
            )
            print(f"[{run_id}] Experiment diagnosis: {qa.mode.value} — quality OK")
    except Exception as exc:
        logger.warning("Experiment diagnosis failed: %s", exc)
def _run_experiment_repair(run_dir: Path, config: RCConfig, run_id: str) -> None:
    """Execute the experiment repair loop when diagnosis finds quality issues.

    Calls the repair loop from ``experiment_repair.py`` which:
    1. Loads experiment code and diagnosis
    2. Gets fixes from LLM or OpenCode
    3. Re-runs experiment in sandbox
    4. Re-assesses quality
    5. Repeats up to max_cycles

    Best-effort: any failure is logged and printed, never raised.
    """
    try:
        from researchclaw.pipeline.experiment_repair import run_repair_loop
        repair_result = run_repair_loop(
            run_dir=run_dir,
            config=config,
            run_id=run_id,
        )
        # Save repair result
        (run_dir / "experiment_repair_result.json").write_text(
            json.dumps(repair_result.to_dict(), indent=2), encoding="utf-8"
        )
        # BUG-186: Promote best experiment summary to stage-14/ so
        # downstream stages (sanitizer, paper_verifier) see it.
        # BUG-198: Only promote if the repair summary is RICHER than
        # the existing stage-14 summary. The repair loop can produce
        # empty summaries (metrics: {}, 0 conditions) which would
        # overwrite enriched data from the analysis stage.
        if repair_result.best_experiment_summary:
            from researchclaw.pipeline.experiment_repair import (
                _summary_quality_score,
            )
            best_path = run_dir / "stage-14" / "experiment_summary.json"
            existing_score = 0.0
            if best_path.exists():
                try:
                    existing = json.loads(
                        best_path.read_text(encoding="utf-8")
                    )
                    existing_score = _summary_quality_score(existing)
                except (json.JSONDecodeError, OSError):
                    # Unreadable existing summary scores 0 → repair wins.
                    pass
            repair_score = _summary_quality_score(
                repair_result.best_experiment_summary
            )
            if repair_score > existing_score:
                best_path.write_text(
                    json.dumps(
                        repair_result.best_experiment_summary, indent=2
                    ),
                    encoding="utf-8",
                )
                logger.info(
                    "[%s] Promoted repair results to stage-14 "
                    "(score %.1f > %.1f, success=%s)",
                    run_id, repair_score, existing_score,
                    repair_result.success,
                )
            else:
                logger.info(
                    "[%s] Kept existing stage-14 summary (score %.1f >= "
                    "repair score %.1f)",
                    run_id, existing_score, repair_score,
                )
        if repair_result.success:
            # Re-run diagnosis with updated results
            _run_experiment_diagnosis(run_dir, config, run_id)
        else:
            logger.info(
                "[%s] Repair loop completed without reaching full_paper quality "
                "(best mode: %s, %d cycles)",
                run_id, repair_result.final_mode.value, repair_result.total_cycles,
            )
    except Exception as exc:
        logger.warning("[%s] Experiment repair failed: %s", run_id, exc)
        print(f"[{run_id}] Experiment repair failed: {exc}")
def execute_pipeline(
    *,
    run_dir: Path,
    run_id: str,
    config: RCConfig,
    adapters: AdapterBundle,
    from_stage: Stage = Stage.TOPIC_INIT,
    auto_approve_gates: bool = False,
    stop_on_gate: bool = False,
    skip_noncritical: bool = False,
    kb_root: Path | None = None,
) -> list[StageResult]:
    """Execute pipeline stages sequentially from `from_stage` and write summary.

    For every stage in ``STAGE_SEQUENCE`` (starting at ``from_stage``) this:
      1. runs the stage via :func:`execute_stage` and prints progress,
      2. optionally mirrors DONE-stage artifacts into the knowledge base,
      3. writes a checkpoint and heartbeat,
      4. after RESULT_ANALYSIS, runs experiment diagnosis and (if the
         diagnosis flags ``repair_needed``) the repair loop,
      5. on a PIVOT/REFINE decision at RESEARCH_DECISION, versions the
         affected stage dirs and recurses from the rollback target (bounded
         by ``MAX_DECISION_PIVOTS``; empty-metric cycles force PROCEED),
      6. stops on FAILED stages (unless noncritical + ``skip_noncritical``)
         or on BLOCKED_APPROVAL when ``stop_on_gate`` is set.

    Afterwards it writes the pipeline summary, extracts evolution lessons,
    runs the MetaClaw post-pipeline hook, and packages deliverables — all
    three best-effort/non-blocking.

    Returns the list of StageResult objects in execution order (including
    results from any recursive rollback passes).
    """
    results: list[StageResult] = []
    started = False
    total_stages = len(STAGE_SEQUENCE)
    for stage in STAGE_SEQUENCE:
        # _should_start flips `started` once we reach `from_stage`.
        started = _should_start(stage, from_stage, started)
        if not started:
            continue
        stage_num = int(stage)
        prefix = f"[{run_id}] Stage {stage_num:02d}/{total_stages}"
        print(f"{prefix} {stage.name} — running...")
        # BUG-218: Ensure the best stage-14 experiment data is promoted
        # BEFORE paper writing begins. Without this, the recursive REFINE
        # path writes the paper using the latest (potentially worse)
        # iteration's data, because the post-recursion promotion at line
        # ~547 runs only after the recursive call—i.e. after the paper
        # has already been written.
        if stage == Stage.PAPER_OUTLINE:
            _promote_best_stage14(run_dir, config)
        t0 = _time.monotonic()
        result = execute_stage(
            stage,
            run_dir=run_dir,
            run_id=run_id,
            config=config,
            adapters=adapters,
            auto_approve_gates=auto_approve_gates,
        )
        elapsed = _time.monotonic() - t0
        # --- Console progress reporting per status ---
        if result.status == StageStatus.DONE:
            arts = ", ".join(result.artifacts) if result.artifacts else "none"
            if result.decision == "degraded":
                print(
                    f"{prefix} {stage.name} — DEGRADED ({elapsed:.1f}s) "
                    f"— continuing with sanitization → {arts}"
                )
            else:
                print(f"{prefix} {stage.name} — done ({elapsed:.1f}s) → {arts}")
        elif result.status == StageStatus.FAILED:
            err = result.error or "unknown error"
            print(f"{prefix} {stage.name} — FAILED ({elapsed:.1f}s) — {err}")
        elif result.status == StageStatus.BLOCKED_APPROVAL:
            print(f"{prefix} {stage.name} — blocked (awaiting approval)")
        results.append(result)
        # Mirror successful stage output into the knowledge base (best-effort).
        if kb_root is not None and result.status == StageStatus.DONE:
            try:
                stage_dir = run_dir / f"stage-{int(stage):02d}"
                write_stage_to_kb(
                    kb_root,
                    stage_id=int(stage),
                    stage_name=stage.name.lower(),
                    run_id=run_id,
                    artifacts=list(result.artifacts),
                    stage_dir=stage_dir,
                    backend=config.knowledge_base.backend,
                    topic=config.research.topic,
                )
            except Exception:  # noqa: BLE001
                # KB mirroring is strictly non-blocking.
                pass
        if result.status == StageStatus.DONE:
            _write_checkpoint(run_dir, stage, run_id)
        # --- Experiment diagnosis + repair after Stage 14 (result_analysis) ---
        if (
            stage == Stage.RESULT_ANALYSIS
            and result.status == StageStatus.DONE
            and config.experiment.repair.enabled
        ):
            _run_experiment_diagnosis(run_dir, config, run_id)
            # Check if repair loop should run
            _diag_path = run_dir / "experiment_diagnosis.json"
            if _diag_path.exists():
                try:
                    _diag_data = json.loads(_diag_path.read_text(encoding="utf-8"))
                    if _diag_data.get("repair_needed"):
                        _run_experiment_repair(run_dir, config, run_id)
                except (json.JSONDecodeError, OSError):
                    # Unreadable diagnosis → skip repair silently.
                    pass
        # --- Heartbeat for sentinel watchdog ---
        _write_heartbeat(run_dir, stage, run_id)
        # --- PIVOT/REFINE decision handling ---
        if (
            stage == Stage.RESEARCH_DECISION
            and result.status == StageStatus.DONE
            and result.decision in DECISION_ROLLBACK
        ):
            pivot_count = _read_pivot_count(run_dir)
            # R6-4: Skip REFINE if experiment metrics are empty for consecutive cycles
            if pivot_count > 0 and _consecutive_empty_metrics(run_dir, pivot_count):
                logger.warning(
                    "Consecutive REFINE cycles produced empty metrics — forcing PROCEED"
                )
                print(
                    f"[{run_id}] Consecutive empty metrics across REFINE cycles — forcing PROCEED"
                )
                # BUG-211: Promote best stage-14 before proceeding with
                # empty data — an earlier iteration may have real metrics.
                _promote_best_stage14(run_dir, config)
            elif pivot_count < MAX_DECISION_PIVOTS:
                rollback_target = DECISION_ROLLBACK[result.decision]
                _record_decision_history(
                    run_dir, result.decision, rollback_target, pivot_count + 1
                )
                logger.info(
                    "Decision %s: rolling back to %s (attempt %d/%d)",
                    result.decision.upper(),
                    rollback_target.name,
                    pivot_count + 1,
                    MAX_DECISION_PIVOTS,
                )
                print(
                    f"[{run_id}] Decision: {result.decision.upper()} → "
                    f"rollback to {rollback_target.name} "
                    f"(attempt {pivot_count + 1}/{MAX_DECISION_PIVOTS})"
                )
                # Version existing stage directories before overwriting
                _version_rollback_stages(
                    run_dir, rollback_target, pivot_count + 1
                )
                # Recurse from rollback target
                pivot_results = execute_pipeline(
                    run_dir=run_dir,
                    run_id=run_id,
                    config=config,
                    adapters=adapters,
                    from_stage=rollback_target,
                    auto_approve_gates=auto_approve_gates,
                    stop_on_gate=stop_on_gate,
                    skip_noncritical=skip_noncritical,
                    kb_root=kb_root,
                )
                results.extend(pivot_results)
                # BUG-211: Promote best stage-14 after REFINE completes so
                # downstream stages use the best data, not just the latest.
                _promote_best_stage14(run_dir, config)
                break  # Exit current loop; recursive call handles the rest
            else:
                # Quality gate: check if experiment results are actually usable
                _quality_ok, _quality_msg = _check_experiment_quality(
                    run_dir, pivot_count
                )
                if not _quality_ok:
                    logger.warning(
                        "Max pivot attempts (%d) reached — forcing PROCEED "
                        "with quality warning: %s",
                        MAX_DECISION_PIVOTS,
                        _quality_msg,
                    )
                    print(
                        f"[{run_id}] QUALITY WARNING: {_quality_msg}"
                    )
                    # Write quality warning to run directory
                    _qw_path = run_dir / "quality_warning.txt"
                    _qw_path.write_text(
                        f"Max pivots ({MAX_DECISION_PIVOTS}) reached.\n"
                        f"Quality gate failed: {_quality_msg}\n"
                        f"Paper will be written but may have significant issues.\n",
                        encoding="utf-8",
                    )
                else:
                    logger.warning(
                        "Max pivot attempts (%d) reached — forcing PROCEED",
                        MAX_DECISION_PIVOTS,
                    )
                    print(
                        f"[{run_id}] Max pivot attempts reached — forcing PROCEED"
                    )
                # BUG-205: After forced PROCEED, promote the BEST stage-14
                # experiment summary across all REFINE iterations.
                _promote_best_stage14(run_dir, config)
        if result.status == StageStatus.FAILED:
            if skip_noncritical and stage in NONCRITICAL_STAGES:
                logger.warning("Noncritical stage %s failed - skipping", stage.name)
            else:
                break
        if result.status == StageStatus.BLOCKED_APPROVAL and stop_on_gate:
            break
    summary = _build_pipeline_summary(
        run_id=run_id,
        results=results,
        from_stage=from_stage,
        run_dir=run_dir,
    )
    _write_pipeline_summary(run_dir, summary)
    # --- Evolution: extract and store lessons ---
    lessons: list[object] = []
    try:
        lessons = extract_lessons(results, run_id=run_id, run_dir=run_dir)
        if lessons:
            store = EvolutionStore(run_dir / "evolution")
            store.append_many(lessons)
            logger.info("Extracted %d lessons from pipeline run", len(lessons))
    except Exception:  # noqa: BLE001
        logger.warning("Evolution lesson extraction failed (non-blocking)")
    # --- MetaClaw bridge: convert high-severity lessons to skills ---
    try:
        _metaclaw_post_pipeline(config, results, lessons, run_id, run_dir)
    except Exception:  # noqa: BLE001
        logger.warning("MetaClaw post-pipeline hook failed (non-blocking)")
    # --- Package deliverables into a single folder ---
    try:
        deliverables_dir = _package_deliverables(run_dir, run_id, config)
        if deliverables_dir is not None:
            print(f"[{run_id}] Deliverables packaged → {deliverables_dir}")
    except Exception:  # noqa: BLE001
        logger.warning("Deliverables packaging failed (non-blocking)")
    return results
def _package_deliverables(
    run_dir: Path,
    run_id: str,
    config: RCConfig,
) -> Path | None:
    """Collect all final user-facing deliverables into a single ``deliverables/`` folder.

    Returns the deliverables directory path, or None if nothing was packaged
    (in which case the empty ``deliverables/`` directory is removed again).

    Packaged artifacts (best-available version selected automatically):
    - paper_final.md — Final paper (Markdown)
    - paper.tex — Conference-ready LaTeX
    - references.bib — BibTeX bibliography
    - code/ — Experiment code package
    - verification_report.json — Citation verification report (if available)

    A ``manifest.json`` listing the packaged files is written last.
    """
    dest = run_dir / "deliverables"
    dest.mkdir(parents=True, exist_ok=True)
    packaged: list[str] = []
    # --- 1. Final paper (Markdown) ---
    # Prefer verified version (stage 23) over base version (stage 22);
    # zero-byte candidates are ignored.
    paper_md = None
    for candidate in [
        run_dir / "stage-23" / "paper_final_verified.md",
        run_dir / "stage-22" / "paper_final.md",
    ]:
        if candidate.exists() and candidate.stat().st_size > 0:
            paper_md = candidate
            break
    if paper_md is not None:
        shutil.copy2(paper_md, dest / "paper_final.md")
        packaged.append("paper_final.md")
    # --- 2. LaTeX paper ---
    # BUG-183: Stage 22's paper.tex has been sanitized (fabricated numbers
    # replaced with ---). Regenerating from Markdown would undo this because
    # the Markdown was never sanitized. Prefer Stage-22 paper.tex when a
    # sanitization report exists. Only regenerate from verified Markdown if
    # no sanitization was performed (i.e., the run was clean).
    tex_regenerated = False
    _sanitization_report = run_dir / "stage-22" / "sanitization_report.json"
    _was_sanitized = _sanitization_report.exists()
    verified_md = run_dir / "stage-23" / "paper_final_verified.md"
    if (
        not _was_sanitized
        and paper_md is not None
        and paper_md == verified_md
        and verified_md.exists()
        and verified_md.stat().st_size > 0
    ):
        try:
            from researchclaw.templates import get_template, markdown_to_latex
            from researchclaw.pipeline.executor import _extract_paper_title
            tpl = get_template(config.export.target_conference)
            v_text = verified_md.read_text(encoding="utf-8")
            tex_content = markdown_to_latex(
                v_text,
                tpl,
                title=_extract_paper_title(v_text),
                authors=config.export.authors,
                bib_file=config.export.bib_file,
            )
            # IMP-17: Quality check — ensure regenerated LaTeX has
            # proper structure (non-empty abstract, multiple sections)
            _has_abstract = (
                "\\begin{abstract}" in tex_content
                and tex_content.split("\\begin{abstract}")[1]
                .split("\\end{abstract}")[0]
                .strip()
            )
            _section_count = tex_content.count("\\section{")
            if _has_abstract and _section_count >= 3:
                (dest / "paper.tex").write_text(tex_content, encoding="utf-8")
                packaged.append("paper.tex")
                tex_regenerated = True
                logger.info(
                    "Deliverables: regenerated paper.tex from verified markdown"
                )
            else:
                logger.warning(
                    "Regenerated paper.tex has poor structure "
                    "(abstract=%s, sections=%d) — using Stage 22 version",
                    bool(_has_abstract),
                    _section_count,
                )
        except Exception:  # noqa: BLE001
            logger.debug("paper.tex regeneration from verified md failed")
    elif _was_sanitized:
        logger.info(
            "Deliverables: using Stage 22 paper.tex (sanitized) — "
            "skipping markdown regeneration to preserve sanitization"
        )
    # Fallback: copy the Stage-22 paper.tex when regeneration did not happen.
    if not tex_regenerated:
        tex_src = run_dir / "stage-22" / "paper.tex"
        if tex_src.exists() and tex_src.stat().st_size > 0:
            shutil.copy2(tex_src, dest / "paper.tex")
            packaged.append("paper.tex")
    # --- 3. References (BibTeX) ---
    # Prefer verified bib (stage 23) over base bib (stage 22)
    bib_src = None
    for candidate in [
        run_dir / "stage-23" / "references_verified.bib",
        run_dir / "stage-22" / "references.bib",
    ]:
        if candidate.exists() and candidate.stat().st_size > 0:
            bib_src = candidate
            break
    if bib_src is not None:
        shutil.copy2(bib_src, dest / "references.bib")
        packaged.append("references.bib")
    # --- 4. Experiment code package ---
    code_src = run_dir / "stage-22" / "code"
    if code_src.is_dir():
        code_dest = dest / "code"
        if code_dest.exists():
            # Replace any previous packaging run's copy wholesale.
            shutil.rmtree(code_dest)
        shutil.copytree(code_src, code_dest)
        packaged.append("code/")
    # --- 5. Verification report (optional) ---
    verify_src = run_dir / "stage-23" / "verification_report.json"
    if verify_src.exists() and verify_src.stat().st_size > 0:
        shutil.copy2(verify_src, dest / "verification_report.json")
        packaged.append("verification_report.json")
    # --- 5b. Sanitization report (degraded mode) ---
    san_src = run_dir / "stage-22" / "sanitization_report.json"
    if san_src.exists() and san_src.stat().st_size > 0:
        shutil.copy2(san_src, dest / "sanitization_report.json")
        packaged.append("sanitization_report.json")
    # --- 6. Charts (optional) ---
    charts_src = run_dir / "stage-22" / "charts"
    if charts_src.is_dir() and any(charts_src.iterdir()):
        charts_dest = dest / "charts"
        if charts_dest.exists():
            shutil.rmtree(charts_dest)
        shutil.copytree(charts_src, charts_dest)
        packaged.append("charts/")
    # --- 7. Conference style files (.sty, .bst) ---
    try:
        from researchclaw.templates import get_template
        tpl = get_template(config.export.target_conference)
        style_files = tpl.get_style_files()
        for sf in style_files:
            shutil.copy2(sf, dest / sf.name)
            packaged.append(sf.name)
        if style_files:
            logger.info(
                "Deliverables: bundled %d style files for %s",
                len(style_files),
                tpl.display_name,
            )
    except Exception:  # noqa: BLE001
        logger.debug("Style file bundling skipped (template lookup failed)")
    # --- 8. Verify & repair cite key coverage (IMP-12 + IMP-14) ---
    tex_path = dest / "paper.tex"
    bib_path = dest / "references.bib"
    if tex_path.exists() and bib_path.exists():
        try:
            tex_text = tex_path.read_text(encoding="utf-8")
            bib_text = bib_path.read_text(encoding="utf-8")
            import re as _re
            # IMP-15: Deduplicate .bib entries (first occurrence of a key wins)
            _seen_bib_keys: set[str] = set()
            _deduped_entries: list[str] = []
            for _bm in _re.finditer(
                r"(@\w+\{([^,]+),.*?\n\})", bib_text, _re.DOTALL
            ):
                _bkey = _bm.group(2).strip()
                if _bkey not in _seen_bib_keys:
                    _seen_bib_keys.add(_bkey)
                    _deduped_entries.append(_bm.group(1))
            # Rewrite the .bib only when duplicates were actually dropped.
            if len(_deduped_entries) < len(
                list(_re.finditer(r"@\w+\{", bib_text))
            ):
                bib_text = "\n\n".join(_deduped_entries) + "\n"
                bib_path.write_text(bib_text, encoding="utf-8")
                logger.info(
                    "Deliverables: deduplicated .bib → %d entries",
                    len(_deduped_entries),
                )
            # Collect all cite keys from \cite{key1, key2}
            all_cite_keys: set[str] = set()
            for cm in _re.finditer(r"\\cite\{([^}]+)\}", tex_text):
                all_cite_keys.update(k.strip() for k in cm.group(1).split(","))
            bib_keys = set(_re.findall(r"@\w+\{([^,]+),", bib_text))
            missing = all_cite_keys - bib_keys
            # IMP-14: Strip orphaned \cite{key} from paper.tex
            if missing:
                logger.warning(
                    "Deliverables: stripping %d orphaned cite keys from "
                    "paper.tex: %s",
                    len(missing),
                    sorted(missing)[:10],
                )
                def _filter_cite(m: _re.Match[str]) -> str:
                    # Keep only cite keys that exist in the .bib; drop the
                    # whole \cite{} when none survive.
                    keys = [k.strip() for k in m.group(1).split(",")]
                    kept = [k for k in keys if k not in missing]
                    if not kept:
                        return ""
                    return "\\cite{" + ", ".join(kept) + "}"
                tex_text = _re.sub(r"\\cite\{([^}]+)\}", _filter_cite, tex_text)
                # Clean up whitespace artifacts: double spaces, space before period
                tex_text = _re.sub(r" +", " ", tex_text)
                tex_text = _re.sub(r" ([.,;:)])", r"\1", tex_text)
                tex_path.write_text(tex_text, encoding="utf-8")
                logger.info(
                    "Deliverables: paper.tex repaired — all remaining cite "
                    "keys verified"
                )
            else:
                logger.info(
                    "Deliverables: all %d cite keys verified in references.bib",
                    len(all_cite_keys),
                )
        except Exception:  # noqa: BLE001
            logger.debug("Cite key verification/repair skipped")
    # --- 9. IMP-18: Compile LaTeX to verify paper.tex ---
    if tex_path.exists() and bib_path.exists():
        try:
            from researchclaw.templates.compiler import compile_latex
            compile_result = compile_latex(tex_path, max_attempts=3, timeout=120)
            if compile_result.success:
                logger.info("IMP-18: paper.tex compiles successfully")
                # Keep the generated PDF
                # NOTE(review): pdf_path appears unused — pdf_file below is
                # what is actually checked; likely leftover.
                pdf_path = dest / tex_path.stem
                pdf_file = dest / (tex_path.stem + ".pdf")
                if pdf_file.exists():
                    packaged.append(f"{tex_path.stem}.pdf")
            else:
                logger.warning(
                    "IMP-18: paper.tex compilation failed after %d attempts: %s",
                    compile_result.attempts,
                    compile_result.errors[:3],
                )
                if compile_result.fixes_applied:
                    logger.info(
                        "IMP-18: Applied %d auto-fixes: %s",
                        len(compile_result.fixes_applied),
                        compile_result.fixes_applied,
                    )
        except Exception:  # noqa: BLE001
            logger.debug("IMP-18: LaTeX compilation skipped (non-blocking)")
    if not packaged:
        # Nothing to package — remove empty dir
        dest.rmdir()
        return None
    # --- Write manifest ---
    manifest = {
        "run_id": run_id,
        "target_conference": config.export.target_conference,
        "files": packaged,
        "generated": _utcnow_iso(),
        "notes": {
            "paper_final.md": "Final paper in Markdown format",
            "paper.tex": f"Conference-ready LaTeX ({config.export.target_conference})",
            "references.bib": "BibTeX bibliography (verified citations only)",
            "code/": "Experiment source code with requirements.txt",
            "verification_report.json": "Citation integrity & relevance verification",
            "charts/": "Result visualizations",
        },
    }
    (dest / "manifest.json").write_text(
        json.dumps(manifest, indent=2), encoding="utf-8"
    )
    logger.info(
        "Deliverables packaged: %s (%d items)",
        dest,
        len(packaged),
    )
    return dest
def _version_rollback_stages(
    run_dir: Path, rollback_target: Stage, attempt: int
) -> None:
    """Rename stage directories that a PIVOT/REFINE pass is about to rerun.

    Every directory from the rollback target up to and including
    RESEARCH_DECISION is renamed from ``stage-NN/`` to ``stage-NN_v{attempt}/``
    so the new pass can write fresh output without destroying the old run.
    A pre-existing version directory for the same attempt is removed first.
    """
    import shutil
    first_stage = int(rollback_target)
    # Everything through RESEARCH_DECISION (stage 15) gets rerun.
    last_stage = int(Stage.RESEARCH_DECISION)
    for num in range(first_stage, last_stage + 1):
        live_dir = run_dir / f"stage-{num:02d}"
        if not live_dir.exists():
            continue
        archived = run_dir / f"stage-{num:02d}_v{attempt}"
        if archived.exists():
            shutil.rmtree(archived)
        live_dir.rename(archived)
        logger.debug(
            "Versioned %s → %s", live_dir.name, archived.name
        )
def _consecutive_empty_metrics(run_dir: Path, pivot_count: int) -> bool:
"""R6-4: Check if the current and previous REFINE cycles both produced empty metrics."""
# Check the most recent experiment_summary.json (stage-14) and its versioned predecessor.
# BUG-215: When stage-14/ doesn't exist (renamed to stage-14_v{N} without
# promotion), fall back to the latest versioned directory as "current".
current = run_dir / "stage-14" / "experiment_summary.json"
if not current.exists():
# Try the latest versioned directory
for _v in range(pivot_count + 1, 0, -1):
alt = run_dir / f"stage-14_v{_v}" / "experiment_summary.json"
if alt.exists():
current = alt
break
prev = run_dir / f"stage-14_v{pivot_count}" / "experiment_summary.json"
for path in (current, prev):
if not path.exists():
return False
try:
data = json.loads(path.read_text(encoding="utf-8"))
# Check all possible metric locations
has_metrics = False
ms = data.get("metrics_summary", {})
if isinstance(ms, dict) and ms:
has_metrics = True
br = data.get("best_run", {})
if isinstance(br, dict) and br.get("metrics"):
has_metrics = True
if has_metrics:
return False # At least one cycle had real metrics
except (json.JSONDecodeError, OSError, AttributeError):
return False
return True # Both cycles had empty metrics
def _promote_best_stage14(run_dir: Path, config: RCConfig) -> None:
    """BUG-205: After forced PROCEED, promote the best stage-14 experiment.

    Scans all ``stage-14*`` directories, scores them by primary metric,
    and copies the best experiment_summary.json into ``stage-14/`` if the
    current ``stage-14/`` is not already the best.

    Side effects:
      - Always writes ``experiment_summary_best.json`` (and, when present,
        ``analysis_best.md``) at the run root for the winning iteration
        (BUG-223 / BUG-225), even when no promotion into stage-14/ happens.
      - When promoting, also copies analysis/table/figure-plan files and the
        ``charts/`` directory from the winning iteration into ``stage-14/``.

    The metric key/direction come from ``config.experiment``; directories
    whose summary lacks a usable value for that metric are ignored.
    """
    import shutil
    metric_key = config.experiment.metric_key or "primary_metric"
    metric_dir = config.experiment.metric_direction or "maximize"
    # (metric value, stage directory) for every scorable iteration.
    candidates: list[tuple[float, Path]] = []
    for d in sorted(run_dir.glob("stage-14*")):
        summary_path = d / "experiment_summary.json"
        if not summary_path.exists():
            continue
        try:
            data = json.loads(summary_path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError):
            continue
        ms = data.get("metrics_summary", {})
        pm_val: float | None = None
        # BUG-DA8-03: Exact match first, then substring fallback
        # (avoids "accuracy" matching "balanced_accuracy")
        if metric_key in ms:
            _v = ms[metric_key]
            try:
                # Values may be aggregate dicts {min,max,mean,count} or scalars.
                pm_val = float(_v["mean"] if isinstance(_v, dict) else _v)
            except (TypeError, ValueError, KeyError):
                pass
        if pm_val is None:
            for k, v in ms.items():
                if metric_key in k:
                    try:
                        pm_val = float(v["mean"] if isinstance(v, dict) else v)
                    except (TypeError, ValueError, KeyError):
                        pass
                    break
        if pm_val is not None:
            candidates.append((pm_val, d))
    if not candidates:
        return  # nothing to promote
    current_dir = run_dir / "stage-14"
    # Sort: best first (descending when maximizing, ascending when minimizing).
    candidates.sort(key=lambda x: x[0], reverse=(metric_dir == "maximize"))
    # BUG-226: Detect degenerate near-zero metrics (broken normalization or
    # collapsed training). When minimising, a value >1000x smaller than the
    # second-best almost certainly comes from a degenerate iteration.
    if metric_dir == "minimize" and len(candidates) > 1:
        _bv, _bd = candidates[0]
        _sv = candidates[1][0]
        if 0 < _bv < _sv * 1e-3:
            logger.warning(
                "BUG-226: Degenerate best value %.6g is >1000× smaller than "
                "second-best %.6g — skipping degenerate iteration %s",
                _bv, _sv, _bd.name,
            )
            candidates.pop(0)
    best_val, best_dir = candidates[0]
    # BUG-223: Always write canonical best summary at run root BEFORE any
    # early return, so downstream consumers (Stage 17, Stage 20, Stage 22,
    # VerifiedRegistry) always find experiment_summary_best.json.
    _best_src = best_dir / "experiment_summary.json"
    if _best_src.exists():
        shutil.copy2(_best_src, run_dir / "experiment_summary_best.json")
        logger.info(
            "BUG-223: Wrote experiment_summary_best.json from %s (%.4f)",
            best_dir.name, best_val,
        )
    # BUG-225: Also copy analysis.md from the best iteration so Stage 17
    # doesn't read stale analysis from a degenerate non-versioned stage-14.
    _best_analysis = best_dir / "analysis.md"
    if _best_analysis.exists():
        shutil.copy2(_best_analysis, run_dir / "analysis_best.md")
    if best_dir == current_dir:
        logger.info("BUG-205: stage-14/ already has the best result (%.4f)", best_val)
        return
    # Promote: copy best summary into stage-14/
    current_summary = current_dir / "experiment_summary.json"
    best_summary = best_dir / "experiment_summary.json"
    # BUG-213: Also promote when stage-14/ is missing or empty
    if best_summary.exists():
        current_dir.mkdir(parents=True, exist_ok=True)
        logger.warning(
            "BUG-205: Promoting %s (%.4f) over stage-14/",
            best_dir.name, best_val,
        )
        shutil.copy2(best_summary, current_summary)
        # Also copy charts, analysis, and figure plans if they exist
        for fname in [
            "analysis.md",
            "results_table.tex",
            "figure_plan.json",  # BUG-213: must travel with metrics
            "figure_plan_final.json",  # BUG-213: ditto
        ]:
            src = best_dir / fname
            if src.exists():
                shutil.copy2(src, current_dir / fname)
        # Copy charts directory (replace any stale one wholesale)
        best_charts = best_dir / "charts"
        current_charts = current_dir / "charts"
        if best_charts.is_dir():
            if current_charts.is_dir():
                shutil.rmtree(current_charts)
            shutil.copytree(best_charts, current_charts)
def _check_experiment_quality(
run_dir: Path, pivot_count: int
) -> tuple[bool, str]:
"""Quality gate before forced PROCEED.
Returns (ok, message). ok=False means experiment results have critical
quality issues and the forced-PROCEED paper will likely be poor.
"""
# BUG-DA8-18: Check experiment_summary_best.json first (repair-promoted)
summary_path = run_dir / "experiment_summary_best.json"
if not summary_path.exists():
summary_path = run_dir / "stage-14" / "experiment_summary.json"
if not summary_path.exists():
for v in range(pivot_count, 0, -1):
alt = run_dir / f"stage-14_v{v}" / "experiment_summary.json"
if alt.exists():
summary_path = alt
break
if not summary_path.exists():
return False, "No experiment_summary.json found — no metrics produced"
try:
data = json.loads(summary_path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
return False, "experiment_summary.json is malformed"
# Check 1: Are all metrics zero?
ms = data.get("metrics_summary", {})
if isinstance(ms, dict):
values: list[float] = []
for k, v in ms.items():
if isinstance(v, (int, float)):
values.append(float(v))
# BUG-212: metrics_summary values are often dicts {min,max,mean,count}
elif isinstance(v, dict) and "mean" in v:
_mv = v["mean"]
if isinstance(_mv, (int, float)):
values.append(float(_mv))
if values and all(v == 0.0 for v in values):
return False, "All experiment metrics are zero — experiments likely failed"
# Check 2: Zero variance across conditions (R13-1)
# Look for ablation_warnings or condition comparison data
ablation_warnings = data.get("ablation_warnings", [])
# BUG-212: Key is "condition_summaries", not "conditions"
conditions = data.get(
"condition_summaries", data.get("condition_metrics", {})
)
if isinstance(conditions, dict) and len(conditions) >= 2:
primary_values: list[float] = []
for cond_name, cond_data in conditions.items():
if isinstance(cond_data, dict):
# BUG-212: Primary metric lives inside cond_data["metrics"]
_metrics = cond_data.get("metrics", cond_data)
pm = _metrics.get(
"primary_metric",
_metrics.get("primary_metric_mean"),
)
if isinstance(pm, (int, float)):
primary_values.append(float(pm))
if len(primary_values) >= 2 and len(set(primary_values)) == 1:
return False, (
f"All {len(primary_values)} conditions have identical primary_metric "
f"({primary_values[0]}) — condition implementations are likely broken"
)
# Check 3: Too many ablation warnings
if isinstance(ablation_warnings, list) and len(ablation_warnings) >= 3:
return False, (
f"{len(ablation_warnings)} ablation warnings — most conditions "
f"produce identical results"
)
# Check 4: Analysis quality score (if available)
quality = data.get("analysis_quality", data.get("quality_score"))
if isinstance(quality, (int, float)) and quality < 3.0:
return False, f"Analysis quality score {quality}/10 — below minimum threshold"
return True, "Quality checks passed"
def _read_pivot_count(run_dir: Path) -> int:
"""Read how many PIVOT/REFINE decisions have been made so far."""
history_path = run_dir / "decision_history.json"
if not history_path.exists():
return 0
try:
data = json.loads(history_path.read_text(encoding="utf-8"))
if isinstance(data, list):
return len(data)
except (json.JSONDecodeError, OSError):
pass
return 0
def _record_decision_history(
    run_dir: Path, decision: str, rollback_target: Stage, attempt: int
) -> None:
    """Append one PIVOT/REFINE decision event to ``decision_history.json``.

    Existing history is preserved; an unreadable or non-list file is
    treated as empty rather than raising.
    """
    history_file = run_dir / "decision_history.json"
    events: list[dict[str, object]] = []
    if history_file.exists():
        try:
            loaded = json.loads(history_file.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError):
            loaded = None
        if isinstance(loaded, list):
            events = loaded
    events.append({
        "decision": decision,
        "rollback_target": rollback_target.name,
        "rollback_stage_num": int(rollback_target),
        "attempt": attempt,
        "timestamp": _utcnow_iso(),
    })
    history_file.write_text(
        json.dumps(events, indent=2), encoding="utf-8"
    )
# Module-level logger shared by the pipeline helpers in this file. Defined
# mid-file, but earlier functions reference it only at call time, so the
# name resolves once the module has finished importing.
logger = logging.getLogger(__name__)
def _read_quality_score(run_dir: Path) -> float | None:
"""Extract quality score from the most recent quality_report.json."""
report_path = run_dir / "stage-20" / "quality_report.json"
if not report_path.exists():
return None
try:
data = json.loads(report_path.read_text(encoding="utf-8"))
if isinstance(data, dict):
# Try common keys: score_1_to_10, score, quality_score
for key in ("score_1_to_10", "score", "quality_score", "overall_score"):
if key in data:
return float(data[key])
except (json.JSONDecodeError, ValueError, TypeError):
pass
return None
def _write_iteration_context(
    run_dir: Path, iteration: int, reviews: str, quality_score: float | None
) -> None:
    """Persist iteration feedback (``iteration_context.json``) so the next
    improvement round can read the previous round's reviews and score.

    Review text is truncated to 3000 characters.
    """
    excerpt = reviews[:3000] if reviews else ""
    payload = {
        "iteration": iteration,
        "quality_score": quality_score,
        "reviews_excerpt": excerpt,
        "generated": _utcnow_iso(),
    }
    target = run_dir / "iteration_context.json"
    target.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def execute_iterative_pipeline(
    *,
    run_dir: Path,
    run_id: str,
    config: RCConfig,
    adapters: AdapterBundle,
    auto_approve_gates: bool = False,
    kb_root: Path | None = None,
    max_iterations: int = 3,
    quality_threshold: float = 7.0,
    convergence_rounds: int = 2,
) -> dict[str, object]:
    """Run the full pipeline with iterative quality improvement.

    After the first full pass (stages 1-22), if the quality gate score is below
    *quality_threshold*, re-run stages 16-22 (paper writing + finalization) with
    review feedback injected. Stop when:
    - Score >= quality_threshold, OR
    - Score hasn't improved for *convergence_rounds* consecutive iterations
      (max-min spread of the recent scores below 0.5), OR
    - *max_iterations* reached.

    Each improvement round writes ``iteration_context.json`` (reviews excerpt
    + last score) so the re-run stages can use the feedback. A final
    ``iteration_summary.json`` is written to *run_dir*, and deliverables are
    packaged (best-effort) at the end.

    Returns a summary dict with iteration history.
    """
    iteration_scores: list[float | None] = []
    all_results: list[list[StageResult]] = []
    # --- First full pass ---
    logger.info("Iteration 1/%d: running full pipeline (stages 1-22)", max_iterations)
    results = execute_pipeline(
        run_dir=run_dir,
        run_id=f"{run_id}-iter1",
        config=config,
        adapters=adapters,
        auto_approve_gates=auto_approve_gates,
        kb_root=kb_root,
    )
    all_results.append(results)
    score = _read_quality_score(run_dir)
    iteration_scores.append(score)
    logger.info("Iteration 1 score: %s", score)
    # --- Iterative improvement ---
    for iteration in range(2, max_iterations + 1):
        # Check if we've met quality threshold
        if score is not None and score >= quality_threshold:
            logger.info(
                "Quality threshold %.1f met (score=%.1f). Stopping.",
                quality_threshold,
                score,
            )
            break
        # Check convergence (score hasn't improved)
        if len(iteration_scores) >= convergence_rounds:
            recent = iteration_scores[-convergence_rounds:]
            if all(s is not None for s in recent):
                recent_scores = [float(s) for s in recent if s is not None]
                if max(recent_scores) - min(recent_scores) < 0.5:
                    logger.info(
                        "Convergence detected: scores %s unchanged for %d rounds. Stopping.",
                        recent,
                        convergence_rounds,
                    )
                    break
        # Write iteration context with feedback from reviews
        reviews_text = ""
        reviews_path = run_dir / "stage-18" / "reviews.md"
        if reviews_path.exists():
            reviews_text = reviews_path.read_text(encoding="utf-8")
        _write_iteration_context(run_dir, iteration, reviews_text, score)
        # Re-run from PAPER_OUTLINE (stage 16) through EXPORT_PUBLISH (stage 22)
        logger.info(
            "Iteration %d/%d: re-running stages 16-22 with feedback",
            iteration,
            max_iterations,
        )
        results = execute_pipeline(
            run_dir=run_dir,
            run_id=f"{run_id}-iter{iteration}",
            config=config,
            adapters=adapters,
            from_stage=Stage.PAPER_OUTLINE,
            auto_approve_gates=auto_approve_gates,
            kb_root=kb_root,
        )
        all_results.append(results)
        score = _read_quality_score(run_dir)
        iteration_scores.append(score)
        logger.info("Iteration %d score: %s", iteration, score)
    # --- Build iterative summary ---
    # Recompute convergence over the final window (the loop may have exited
    # on max_iterations without evaluating it).
    converged = False
    if len(iteration_scores) >= convergence_rounds:
        recent_window = iteration_scores[-convergence_rounds:]
        if all(s is not None for s in recent_window):
            recent_scores = [float(s) for s in recent_window if s is not None]
            converged = max(recent_scores) - min(recent_scores) < 0.5
    summary: dict[str, object] = {
        "run_id": run_id,
        "total_iterations": len(iteration_scores),
        "iteration_scores": iteration_scores,
        "quality_threshold": quality_threshold,
        "converged": converged,
        "final_score": iteration_scores[-1] if iteration_scores else None,
        "met_threshold": score is not None and score >= quality_threshold,
        "stages_per_iteration": [len(r) for r in all_results],
        "generated": _utcnow_iso(),
    }
    (run_dir / "iteration_summary.json").write_text(
        json.dumps(summary, indent=2, default=str), encoding="utf-8"
    )
    # --- Package deliverables into a single folder ---
    try:
        deliverables_dir = _package_deliverables(run_dir, run_id, config)
        if deliverables_dir is not None:
            print(f"[{run_id}] Deliverables packaged → {deliverables_dir}")
    except Exception:  # noqa: BLE001
        logger.warning("Deliverables packaging failed (non-blocking)")
    return summary
def _metaclaw_post_pipeline(
config: RCConfig,
results: list[StageResult],
lessons: list[object],
run_id: str,
run_dir: Path,
) -> None:
"""MetaClaw bridge: post-pipeline hook.
1. Convert high-severity lessons into MetaClaw skills.
2. Record skill effectiveness feedback.
3. Signal session end to MetaClaw proxy.
"""
bridge = getattr(config, "metaclaw_bridge", None)
if not bridge or not getattr(bridge, "enabled", False):
return
from researchclaw.llm.client import LLMClient
# 1. Lesson-to-skill conversion
l2s = getattr(bridge, "lesson_to_skill", None)
if l2s and getattr(l2s, "enabled", False) and lessons:
try:
from researchclaw.metaclaw_bridge.lesson_to_skill import (
convert_lessons_to_skills,
)
min_sev = getattr(l2s, "min_severity", "warning")
llm = LLMClient.from_rc_config(config)
new_skills = convert_lessons_to_skills(
lessons,
llm,
getattr(bridge, "skills_dir", "~/.metaclaw/skills"),
min_severity=min_sev,
max_skills=getattr(l2s, "max_skills_per_run", 3),
)
if new_skills:
logger.info(
"MetaClaw: generated %d new skills from lessons: %s",
len(new_skills),
new_skills,
)
except Exception: # noqa: BLE001
logger.warning("MetaClaw lesson-to-skill conversion failed", exc_info=True)
# 2. Skill effectiveness feedback
try:
from researchclaw.metaclaw_bridge.skill_feedback import (
SkillFeedbackStore,
record_stage_skills,
)
from researchclaw.metaclaw_bridge.stage_skill_map import get_stage_config
feedback_store = SkillFeedbackStore(run_dir / "evolution" / "skill_effectiveness.jsonl")
for result in results:
stage_num = int(getattr(result, "stage", 0))
stage_name = {
1: "topic_init", 2: "problem_decompose", 3: "search_strategy",
4: "literature_collect", 5: "literature_screen", 6: "knowledge_extract",
7: "synthesis", 8: "hypothesis_gen", 9: "experiment_design",
10: "code_generation", 11: "resource_planning", 12: "experiment_run",
13: "iterative_refine", 14: "result_analysis", 15: "research_decision",
16: "paper_outline", 17: "paper_draft", 18: "peer_review",
19: "paper_revision", 20: "quality_gate", 21: "knowledge_archive",
22: "export_publish", 23: "citation_verify",
}.get(stage_num, "")
if not stage_name:
continue
stage_config = get_stage_config(stage_name)
active_skills = stage_config.get("skills", [])
status = str(getattr(result, "status", ""))
success = "done" in status.lower()
if active_skills:
record_stage_skills(
feedback_store,
stage_name,
run_id,
success,
active_skills,
)
except Exception: # noqa: BLE001
logger.warning("MetaClaw skill feedback recording failed")
# 3. Signal session end (fire-and-forget)
try:
from researchclaw.metaclaw_bridge.session import MetaClawSession
import json as _json
import urllib.request as _urllib_req
session = MetaClawSession(run_id)
end_headers = session.end()
# Send a minimal request to signal session end
proxy_url = getattr(bridge, "proxy_url", "http://localhost:30000")
url = f"{proxy_url.rstrip('/')}/v1/chat/completions"
body = _json.dumps({
"model": "session-end",
"messages": [{"role": "user", "content": "session complete"}],
"max_tokens": 1,
}).encode("utf-8")
headers = {"Content-Type": "application/json"}
headers.update(end_headers)
req = _urllib_req.Request(url, data=body, headers=headers)
try:
_urllib_req.urlopen(req, timeout=5)
except Exception: # noqa: BLE001
pass # Best-effort signal
except Exception: # noqa: BLE001
pass
================================================
FILE: researchclaw/pipeline/stage_impls/__init__.py
================================================
"""Stage implementation modules for the research pipeline executor."""
================================================
FILE: researchclaw/pipeline/stage_impls/_analysis.py
================================================
"""Stages 14-15: Result analysis and research decision."""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain, _is_ml_domain
from researchclaw.pipeline._helpers import (
StageResult,
_build_context_preamble,
_chat_with_prompt,
_collect_experiment_results,
_collect_json_context,
_get_evolution_overlay,
_multi_perspective_generate,
_read_prior_artifact,
_safe_json_loads,
_synthesize_perspectives,
_utcnow_iso,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
def _execute_result_analysis(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 14: aggregate experiment results and write analysis artifacts.

    Flow (each step carries its bug-tracker tag in the inline comments):
      1. Collect Stage 12 run metrics; merge Stage 13 refinement metrics
         when they are actually better (R13-1 / BUG-165 / BUG-205).
      2. Extract paired comparisons from run and refinement stdout (R19-2),
         recomputing them from per-seed data when suspicious (R33).
      3. Build per-condition summaries with CIs (R19-3, R33) and flag
         broken ablations / insufficient seeds / zero variance (P8, B).
      4. Write ``experiment_summary.json`` (+ ``results_table.tex``), then
         produce ``analysis.md`` via multi-perspective LLM debate, or a
         data-filled template when ``llm`` is None.
      5. Generate charts early via FigureAgent, falling back to the legacy
         ``visualize.py`` path (IMP-6 + FA).

    Returns:
        StageResult for RESULT_ANALYSIS with the artifact list produced.
    """
    # --- Collect experiment data ---
    exp_data = _collect_experiment_results(
        run_dir,
        metric_key=config.experiment.metric_key,
        metric_direction=config.experiment.metric_direction,
    )
    runs_dir = _read_prior_artifact(run_dir, "runs/") or ""
    context = ""
    if runs_dir:
        context = _collect_json_context(Path(runs_dir), max_files=30)
    # --- R13-1: Merge Stage 13 (ITERATIVE_REFINE) results if available ---
    # Stage 13 stores richer per-condition metrics in refinement_log.json
    # that _collect_experiment_results() misses (it only scans runs/ dirs).
    _refine_log_text = _read_prior_artifact(run_dir, "refinement_log.json")
    if _refine_log_text:
        try:
            _refine_data = json.loads(_refine_log_text)
            _best_iter = None
            _best_ver = _refine_data.get("best_version", "")

            def _get_best_sandbox(it: dict) -> dict:
                """BUG-181: Metrics may be in sandbox or sandbox_after_fix."""
                sbx = it.get("sandbox", {})
                if sbx.get("metrics"):
                    return sbx
                sbx_fix = it.get("sandbox_after_fix", {})
                if sbx_fix.get("metrics"):
                    return sbx_fix
                # Fall back to the (metric-less) primary sandbox record.
                return sbx

            # Prefer the iteration matching best_version, if it has metrics.
            for _it in _refine_data.get("iterations", []):
                _sbx = _get_best_sandbox(_it)
                _it_metrics = _sbx.get("metrics", {})
                if _it.get("version_dir", "") == _best_ver and _it_metrics:
                    _best_iter = _it
                    break
            # If no version match, take the first iteration with metrics
            if _best_iter is None:
                for _it in _refine_data.get("iterations", []):
                    _sbx = _get_best_sandbox(_it)
                    if _sbx.get("metrics"):
                        _best_iter = _it
                        break
            if _best_iter is not None:
                _sbx = _get_best_sandbox(_best_iter)
                _refine_metrics = _sbx.get("metrics", {})
                # BUG-165 fix: Prefer Stage 13 refinement data when it is
                # actually better. The old `or True` unconditionally
                # replaced existing data, causing catastrophic regressions
                # (BUG-205: v1=78.93% destroyed by v3=8.65%).
                _refine_is_better = not exp_data["metrics_summary"]
                if not _refine_is_better and _refine_metrics:
                    # Compare primary_metric values to decide
                    _mkey = config.experiment.metric_key or "primary_metric"
                    _mdir = config.experiment.metric_direction or "maximize"
                    _existing_pm: float | None = None
                    _refine_pm: float | None = None
                    # BUG-214: Use exact match first, then substring fallback
                    # to avoid "accuracy" matching "balanced_accuracy".
                    _ms_items = list((exp_data.get("metrics_summary") or {}).items())
                    for _k, _v in _ms_items:
                        if _k == _mkey:
                            try:
                                _existing_pm = float(_v["mean"] if isinstance(_v, dict) else _v)
                            except (TypeError, ValueError, KeyError):
                                pass
                            break
                    else:
                        # for-else: no exact key match — try substring match.
                        for _k, _v in _ms_items:
                            if _mkey in _k:
                                try:
                                    _existing_pm = float(_v["mean"] if isinstance(_v, dict) else _v)
                                except (TypeError, ValueError, KeyError):
                                    pass
                                break
                    _refine_items = list(_refine_metrics.items())
                    for _k, _v in _refine_items:
                        if _k == _mkey:
                            try:
                                _refine_pm = float(_v)
                            except (TypeError, ValueError):
                                pass
                            break
                    else:
                        for _k, _v in _refine_items:
                            if _mkey in _k:
                                try:
                                    _refine_pm = float(_v)
                                except (TypeError, ValueError):
                                    pass
                                break
                    if _existing_pm is None:
                        _refine_is_better = True  # no existing data
                    elif _refine_pm is not None:
                        if _mdir == "maximize":
                            _refine_is_better = _refine_pm > _existing_pm
                        else:
                            _refine_is_better = _refine_pm < _existing_pm
                    logger.info(
                        "Stage 14: Refine metric comparison: existing=%s, refine=%s, "
                        "direction=%s → refine_is_better=%s",
                        _existing_pm, _refine_pm, _mdir, _refine_is_better,
                    )
                if _refine_metrics and _refine_is_better:
                    # Refinement has richer data — rebuild metrics_summary from it
                    _new_summary: dict[str, dict[str, float | None]] = {}
                    for _mk, _mv in _refine_metrics.items():
                        try:
                            _fv = float(_mv)
                            # Single observation: min == max == mean, count 1.
                            _new_summary[_mk] = {
                                "min": round(_fv, 6),
                                "max": round(_fv, 6),
                                "mean": round(_fv, 6),
                                "count": 1,
                            }
                        except (ValueError, TypeError):
                            pass
                    if _new_summary:
                        exp_data["metrics_summary"] = _new_summary
                        # Also update best_run with refinement data
                        exp_data["best_run"] = {
                            "run_id": "iterative-refine-best",
                            "task_id": "sandbox-main",
                            "status": "completed",
                            "metrics": {
                                k: v for k, v in _refine_metrics.items()
                            },
                            "elapsed_sec": _sbx.get("elapsed_sec", 0),
                            "stdout": "",  # omit for brevity
                            "stderr": _sbx.get("stderr", ""),
                            "timed_out": _sbx.get("timed_out", False),
                        }
                        # Rebuild latex table
                        _ltx = [
                            r"\begin{table}[h]", r"\centering",
                            r"\caption{Experiment Results (Best Refinement Iteration)}",
                            r"\begin{tabular}{lrrrr}", r"\hline",
                            r"Metric & Min & Max & Mean & N \\", r"\hline",
                        ]
                        for _col in sorted(_new_summary.keys()):
                            _s = _new_summary[_col]
                            _ltx.append(
                                f"{_col} & {_s['min']:.4f} & {_s['max']:.4f} "
                                f"& {_s['mean']:.4f} & {_s['count']} \\\\"
                            )
                        _ltx.extend([r"\hline", r"\end{tabular}", r"\end{table}"])
                        exp_data["latex_table"] = "\n".join(_ltx)
                        # Count unique conditions (keys without 'seed' and not
                        # ending in '_std' — NOTE: '_mean' keys are NOT excluded
                        # here, so the count may include mean aggregates).
                        _conditions = {
                            k for k in _refine_metrics
                            if "seed" not in k and not k.endswith("_std")
                        }
                        exp_data["runs"] = [exp_data["best_run"]]
                        # Store condition count for accurate reporting
                        exp_data["best_run"]["condition_count"] = len(_conditions)
                        if not context:
                            context = json.dumps(
                                {"refinement_best_metrics": _refine_metrics},
                                indent=2, default=str,
                            )
                        _bm_val = _refine_data.get("best_metric")
                        logger.info(
                            "R13-1: Merged %d metrics from refinement_log (best_metric=%.4f)",
                            len(_refine_metrics),
                            float(_bm_val) if isinstance(_bm_val, (int, float)) else 0.0,
                        )
        except (json.JSONDecodeError, OSError, KeyError):
            logger.warning("R13-1: Failed to parse refinement_log.json, using Stage 12 data")
    # --- R19-2: Extract PAIRED comparisons from refinement stdout ---
    from researchclaw.experiment.sandbox import extract_paired_comparisons as _extract_paired
    _all_paired: list[dict[str, object]] = []
    # First: from _collect_experiment_results (Stage 12 runs/)
    if exp_data.get("paired_comparisons"):
        _all_paired.extend(exp_data["paired_comparisons"])
    # Second: from refinement_log iterations (Stage 13)
    if _refine_log_text:
        try:
            _rl = json.loads(_refine_log_text)
            for _it in _rl.get("iterations", []):
                for _sbx_key in ("sandbox", "sandbox_after_fix"):
                    _sbx_stdout = (_it.get(_sbx_key) or {}).get("stdout", "")
                    if _sbx_stdout:
                        _all_paired.extend(_extract_paired(_sbx_stdout))
        except (json.JSONDecodeError, OSError):
            pass
    # --- R19-3: Build structured condition_summaries from metrics ---
    _condition_summaries: dict[str, dict[str, Any]] = {}
    _ms = exp_data.get("metrics_summary", {})
    _best_metrics = {}
    if exp_data.get("best_run") and isinstance(exp_data["best_run"], dict):
        _best_metrics = exp_data["best_run"].get("metrics", {})
    # Group metrics by condition prefix (e.g., "ppo/primary_metric" → condition "ppo")
    for _mk, _mv in _best_metrics.items():
        parts = _mk.split("/")
        if len(parts) >= 2:
            cond = parts[0]
            metric_name = parts[-1]
            if cond not in _condition_summaries:
                _condition_summaries[cond] = {"metrics": {}}
            try:
                _condition_summaries[cond]["metrics"][metric_name] = float(_mv)
            except (ValueError, TypeError):
                pass
    # BUG-09 fix: If no condition summaries were built (metrics don't use
    # condition/metric format), try to extract from metrics_summary or
    # structured_results so FigureAgent has data to work with.
    if not _condition_summaries and _ms:
        # Try to parse condition data from metrics_summary keys
        for _mk, _mv in _ms.items():
            parts = _mk.split("/")
            if len(parts) >= 2:
                cond = parts[0]
                metric_name = parts[-1]
                if cond not in _condition_summaries:
                    _condition_summaries[cond] = {"metrics": {}}
                try:
                    # BUG-182: metrics_summary values are dicts {min,max,mean,count},
                    # not plain floats. Extract the mean value.
                    if isinstance(_mv, dict):
                        _val = float(_mv["mean"]) if "mean" in _mv else None
                    else:
                        _val = float(_mv)
                    if _val is not None:
                        _condition_summaries[cond]["metrics"][metric_name] = _val
                except (ValueError, TypeError, KeyError):
                    pass
    if not _condition_summaries:
        # Last resort: build from structured_results condition keys
        _sr = exp_data.get("structured_results", {})
        if isinstance(_sr, dict):
            for _sk, _sv in _sr.items():
                if isinstance(_sv, dict) and _sk not in ("metadata", "config"):
                    _condition_summaries[_sk] = {"metrics": {}}
                    for _smk, _smv in _sv.items():
                        try:
                            _condition_summaries[_sk]["metrics"][_smk] = float(_smv)
                        except (ValueError, TypeError):
                            pass
    # R33: Build per-seed data structure (needed for CIs and paired tests below)
    _seed_data: dict[str, dict[int, float]] = {}  # {condition: {seed: value}}
    for _mk, _mv in _best_metrics.items():
        parts = _mk.split("/")
        # Pattern: condition/regime/seed_id/primary_metric
        if len(parts) >= 4 and parts[-1] == config.experiment.metric_key:
            cond = parts[0]
            try:
                seed_id = int(parts[2])
                val = float(_mv)
                _seed_data.setdefault(cond, {})[seed_id] = val
            except (ValueError, TypeError):
                pass
    # Enrich condition summaries with seed counts, success rates, and CIs
    for _ck, _cv in _condition_summaries.items():
        # Look for success_rate in metrics
        sr_key = f"{_ck}/success_rate"
        if sr_key in _best_metrics:
            try:
                _cv["success_rate"] = float(_best_metrics[sr_key])
            except (ValueError, TypeError):
                pass
        # Count seed-level entries to estimate n_seeds
        _seed_count = 0
        for _mk in _best_metrics:
            if _mk.startswith(f"{_ck}/") and "seed" in _mk.lower():
                _seed_count += 1
        if _seed_count > 0:
            _cv["n_seed_metrics"] = _seed_count
        # R33: Compute mean ± std and bootstrap 95% CI from per-seed data
        # (requires at least 3 seeds for the statistics to be meaningful).
        if _ck in _seed_data and len(_seed_data[_ck]) >= 3:
            _vals = list(_seed_data[_ck].values())
            import statistics as _stats_mod
            _mean = _stats_mod.mean(_vals)
            _std = _stats_mod.stdev(_vals)
            _cv["metrics"][f"{config.experiment.metric_key}_mean"] = round(_mean, 6)
            _cv["metrics"][f"{config.experiment.metric_key}_std"] = round(_std, 6)
            _cv["n_seeds"] = len(_vals)
            # Bootstrap 95% CI (use local RNG to avoid corrupting global state)
            import random as _rng_mod
            _rng_local = _rng_mod.Random(42)
            _boot_means = []
            for _ in range(1000):
                _sample = [_rng_local.choice(_vals) for _ in range(len(_vals))]
                _boot_means.append(_stats_mod.mean(_sample))
            _boot_means.sort()
            # Percentile bootstrap: 2.5th and 97.5th percentile of the means.
            _ci_low = round(_boot_means[int(0.025 * len(_boot_means))], 6)
            _ci_high = round(_boot_means[int(0.975 * len(_boot_means))], 6)
            # IMP-16: Sanity check — CI must contain the mean
            if _ci_low > _mean or _ci_high < _mean:
                logger.warning(
                    "Bootstrap CI [%.4f, %.4f] does not contain mean %.4f "
                    "for condition %s — replacing CI with mean ± 1.96*SE",
                    _ci_low, _ci_high, _mean, _ck,
                )
                _se = _std / (len(_vals) ** 0.5)
                _ci_low = round(_mean - 1.96 * _se, 6)
                _ci_high = round(_mean + 1.96 * _se, 6)
            _cv["ci95_low"] = _ci_low
            _cv["ci95_high"] = _ci_high
    # Count totals
    _total_conditions = len(_condition_summaries) if _condition_summaries else None
    _total_metrics = len(_best_metrics) if _best_metrics else None
    # --- R33: Pipeline-level paired computation as fallback ---
    # If the experiment code's PAIRED lines are sparse or suspicious (e.g.,
    # all identical t-stats), compute fresh paired tests from per-seed data.
    # (_seed_data was built above before condition summary enrichment)
    if len(_seed_data) >= 2:
        # Find common seeds across conditions
        _all_seeds_sets = [set(v.keys()) for v in _seed_data.values()]
        _common_seeds = set.intersection(*_all_seeds_sets) if _all_seeds_sets else set()
        if len(_common_seeds) >= 3:
            _cond_names_sorted = sorted(_seed_data.keys())
            _pipeline_paired: list[dict[str, object]] = []
            # Compare each condition against the first baseline (alphabetically)
            _baseline_cond = _cond_names_sorted[0]
            for _other_cond in _cond_names_sorted[1:]:
                _diffs = []
                for _sid in sorted(_common_seeds):
                    _diffs.append(
                        _seed_data[_other_cond][_sid] - _seed_data[_baseline_cond][_sid]
                    )
                if _diffs:
                    import statistics
                    _n = len(_diffs)
                    _mean_d = statistics.mean(_diffs)
                    _std_d = statistics.stdev(_diffs) if _n > 1 else 0.0
                    # Paired t statistic: mean difference over its standard error.
                    _t = (_mean_d / (_std_d / (_n ** 0.5))) if _std_d > 0 else 0.0
                    _df = _n - 1
                    # Two-tailed p-value using t-distribution
                    import math
                    try:
                        from scipy.stats import t as _t_dist
                        _p = float(2 * _t_dist.sf(abs(_t), _df))
                    except ImportError:
                        # Normal approximation via erf, inflated for small df.
                        _p = 2 * (1 - 0.5 * (1 + math.erf(abs(_t) / (2 ** 0.5))))
                        if _df < 30:
                            _p = min(1.0, _p * (1 + 2.5 / max(_df, 1)))
                    _pipeline_paired.append({
                        "method": _other_cond,
                        "baseline": _baseline_cond,
                        "mean_diff": round(_mean_d, 6),
                        "std_diff": round(_std_d, 6),
                        "t_stat": round(_t, 4),
                        "p_value": round(_p, 6),
                        "n_seeds": _n,
                        "source": "pipeline_computed",
                    })
            # Use pipeline-computed if experiment code's are suspicious
            _exp_t_stats = {round(p.get("t_stat", 0), 4) for p in _all_paired}
            _all_identical = len(_exp_t_stats) <= 1 and len(_all_paired) > 1
            if _pipeline_paired and (_all_identical or len(_all_paired) < len(_pipeline_paired)):
                logger.info(
                    "R33: Using %d pipeline-computed paired tests (experiment code had %d, identical=%s)",
                    len(_pipeline_paired), len(_all_paired), _all_identical,
                )
                _all_paired = _pipeline_paired
    # --- P8: Detect identical conditions (broken ablations) ---
    _ablation_warnings: list[str] = []
    if _condition_summaries and len(_condition_summaries) >= 2:
        _cond_names = sorted(_condition_summaries.keys())
        # Compare every unordered pair of conditions.
        for _i in range(len(_cond_names)):
            for _j in range(_i + 1, len(_cond_names)):
                _c1, _c2 = _cond_names[_i], _cond_names[_j]
                _s1_raw = _condition_summaries[_c1]
                _s2_raw = _condition_summaries[_c2]
                # BUG-133 fix: compare inner metrics dicts, not top-level keys
                _s1_m = _s1_raw.get("metrics", {}) if isinstance(_s1_raw, dict) else {}
                _s2_m = _s2_raw.get("metrics", {}) if isinstance(_s2_raw, dict) else {}
                if not isinstance(_s1_m, dict):
                    _s1_m = {}
                if not isinstance(_s2_m, dict):
                    _s2_m = {}
                _shared_keys = set(_s1_m.keys()) & set(_s2_m.keys())
                if not _shared_keys:
                    continue
                _all_equal = True
                for _sk in _shared_keys:
                    _v1 = _s1_m[_sk]
                    _v2 = _s2_m[_sk]
                    if _v1 != _v2:
                        _all_equal = False
                        break
                if _all_equal and _shared_keys:
                    _warn = (
                        f"ABLATION FAILURE: Conditions '{_c1}' and '{_c2}' produce "
                        f"identical outputs across all {len(_shared_keys)} metrics. "
                        f"The ablation is invalid — the differentiating parameter "
                        f"is likely not used in the code."
                    )
                    _ablation_warnings.append(_warn)
                    logger.warning("P8: %s", _warn)
                elif _shared_keys:
                    # R5-BUG-03: Also flag near-identical conditions (< 1% relative diff)
                    _near_identical = True
                    for _sk in _shared_keys:
                        _v1 = _s1_m[_sk]
                        _v2 = _s2_m[_sk]
                        try:
                            _v1f, _v2f = float(_v1), float(_v2)
                            _denom = max(abs(_v1f), abs(_v2f), 1e-12)
                            if abs(_v1f - _v2f) / _denom > 0.01:
                                _near_identical = False
                                break
                        except (TypeError, ValueError):
                            _near_identical = False
                            break
                    if _near_identical:
                        _warn = (
                            f"ABLATION WARNING: Conditions '{_c1}' and '{_c2}' produce "
                            f"near-identical outputs (<1% relative difference) across "
                            f"all {len(_shared_keys)} metrics. The ablation may be trivial."
                        )
                        _ablation_warnings.append(_warn)
                        logger.warning("P8: %s", _warn)
    # --- Improvement B: Validate seed counts ---
    _seed_insufficiency_warnings: list[str] = []
    for _sc_name, _sc_seeds in _seed_data.items():
        _n_seeds = len(_sc_seeds)
        if 0 < _n_seeds < 3:
            _warn = (
                f"SEED_INSUFFICIENCY: Condition '{_sc_name}' has only "
                f"{_n_seeds} seed(s) (minimum 3 required for statistical validity)"
            )
            _seed_insufficiency_warnings.append(_warn)
            logger.warning("B: %s", _warn)
    # --- Write structured experiment summary ---
    summary_payload = {
        "metrics_summary": exp_data["metrics_summary"],
        "total_runs": len(exp_data["runs"]),
        "best_run": exp_data["best_run"],
        "latex_table": exp_data["latex_table"],
        "generated": _utcnow_iso(),
    }
    if _seed_insufficiency_warnings:
        summary_payload["seed_insufficiency_warnings"] = _seed_insufficiency_warnings
    # R13-1: Detect zero-variance across conditions (all conditions identical primary metric)
    if _condition_summaries and len(_condition_summaries) >= 2:
        _primary_vals = []
        for _cs in _condition_summaries.values():
            if isinstance(_cs, dict):
                # Try 'metrics' dict first (actual structure), then 'primary_metric' fallback
                _metrics = _cs.get("metrics", {})
                if isinstance(_metrics, dict) and _metrics:
                    _pv_candidate = next(iter(_metrics.values()), None)
                    if isinstance(_pv_candidate, dict):
                        _pv_candidate = _pv_candidate.get("mean")
                    if isinstance(_pv_candidate, (int, float)):
                        _primary_vals.append(_pv_candidate)
                    continue
                _pm = _cs.get("primary_metric", {})
                _pv = _pm.get("mean") if isinstance(_pm, dict) else _pm
                if isinstance(_pv, (int, float)):
                    _primary_vals.append(_pv)
        if len(_primary_vals) >= 2 and len(set(_primary_vals)) == 1:
            _zv_warn = (
                f"ZERO VARIANCE: All {len(_primary_vals)} conditions have "
                f"identical primary_metric ({_primary_vals[0]}). "
                f"Experiment condition wiring is likely broken."
            )
            _ablation_warnings.append(_zv_warn)
            logger.warning("R13-1: %s", _zv_warn)
    if _ablation_warnings:
        summary_payload["ablation_warnings"] = _ablation_warnings
    if _all_paired:
        summary_payload["paired_comparisons"] = _all_paired
    if _condition_summaries:
        summary_payload["condition_summaries"] = _condition_summaries
        summary_payload["condition_metrics"] = _condition_summaries  # alias for quality gate
        summary_payload["total_conditions"] = _total_conditions
    if _total_metrics:
        summary_payload["total_metric_keys"] = _total_metrics
    (stage_dir / "experiment_summary.json").write_text(
        json.dumps(summary_payload, indent=2, default=str), encoding="utf-8"
    )
    if exp_data["latex_table"]:
        (stage_dir / "results_table.tex").write_text(
            exp_data["latex_table"], encoding="utf-8"
        )
    # --- Build data-augmented prompt ---
    preamble = _build_context_preamble(
        config, run_dir, include_goal=True, include_hypotheses=True
    )
    data_context = ""
    if exp_data["metrics_summary"]:
        lines = ["\n## Quantitative Results"]
        for mk, mv in exp_data["metrics_summary"].items():
            if isinstance(mv, dict):
                lines.append(
                    f"- {mk}: mean={mv.get('mean', '?')}, min={mv.get('min', '?')}, "
                    f"max={mv.get('max', '?')}, n={mv.get('count', '?')}"
                )
        data_context = "\n".join(lines)
    # Append structured results if available
    if exp_data.get("structured_results"):
        structured_text = json.dumps(
            exp_data["structured_results"], indent=2, default=str
        )
        # Truncate to avoid blowing up context
        if len(structured_text) > 6000:
            structured_text = structured_text[:6000] + "\n... (truncated)"
        data_context += (
            f"\n\n## Structured Experiment Results (from results.json)\n"
            f"```json\n{structured_text}\n```"
        )
    # P8: Inject ablation warnings into data context
    if _ablation_warnings:
        data_context += "\n\nCRITICAL ABLATION WARNINGS:\n"
        for _aw in _ablation_warnings:
            data_context += f"- {_aw}\n"
        data_context += (
            "\nYou MUST address these in your analysis. Identical conditions "
            "mean the ablation design is broken and the comparison is meaningless.\n"
        )
    if llm is not None:
        _pm = prompts or PromptManager()
        from researchclaw.prompts import DEBATE_ROLES_ANALYSIS  # noqa: PLC0415
        # --- Multi-perspective debate ---
        perspectives_dir = stage_dir / "perspectives"
        variables = {
            "preamble": preamble,
            "data_context": data_context,
            "context": context,
        }
        perspectives = _multi_perspective_generate(
            llm, DEBATE_ROLES_ANALYSIS, variables, perspectives_dir
        )
        # --- Synthesize into unified analysis ---
        analysis = _synthesize_perspectives(
            llm, perspectives, "analysis_synthesize", _pm
        )
    else:
        # Template with real data if available
        ms = exp_data["metrics_summary"]
        metrics_block = ""
        if ms:
            for mk, mv in ms.items():
                if isinstance(mv, dict):
                    metrics_block += (
                        f"- **{mk}**: mean={mv.get('mean')}, "
                        f"min={mv.get('min')}, max={mv.get('max')}, n={mv.get('count')}\n"
                    )
        else:
            metrics_block = f"- Primary metric key: `{config.experiment.metric_key}`\n- No quantitative data yet.\n"
        analysis = f"""# Result Analysis
## Metrics Summary
{metrics_block}
## Comparative Findings
- Proposed approach results from {len(exp_data["runs"])} run(s) collected.
## Statistical Checks
- Recommend confidence interval and seed-wise variance reporting.
## Limitations
- Limited runs and synthetic constraints.
## Conclusion
- Proceed to decision stage with moderate confidence.
Generated: {_utcnow_iso()}
"""
    (stage_dir / "analysis.md").write_text(analysis, encoding="utf-8")
    artifacts = ["analysis.md", "experiment_summary.json"]
    if (stage_dir / "results_table.tex").exists():
        artifacts.append("results_table.tex")
    # IMP-6 + FA: Generate charts early (Stage 14) so paper draft can reference them
    # Try FigureAgent first (multi-agent intelligent charts), fall back to visualize.py
    _figure_plan_saved = False
    if config.experiment.figure_agent.enabled and llm is not None:
        try:
            from researchclaw.agents.figure_agent import FigureOrchestrator
            from researchclaw.agents.figure_agent.orchestrator import FigureAgentConfig as _FACfg
            _fa_cfg = _FACfg(
                enabled=True,
                min_figures=config.experiment.figure_agent.min_figures,
                max_figures=config.experiment.figure_agent.max_figures,
                max_iterations=config.experiment.figure_agent.max_iterations,
                render_timeout_sec=config.experiment.figure_agent.render_timeout_sec,
                use_docker=config.experiment.figure_agent.use_docker,
                docker_image=config.experiment.figure_agent.docker_image,
                output_format=config.experiment.figure_agent.output_format,
                gemini_api_key=config.experiment.figure_agent.gemini_api_key,
                gemini_model=config.experiment.figure_agent.gemini_model,
                nano_banana_enabled=config.experiment.figure_agent.nano_banana_enabled,
                strict_mode=config.experiment.figure_agent.strict_mode,
                dpi=config.experiment.figure_agent.dpi,
            )
            _fa = FigureOrchestrator(llm, _fa_cfg, stage_dir=stage_dir)
            # Build conditions list from condition_summaries
            _fa_conditions = list(_condition_summaries.keys()) if _condition_summaries else []
            # BUG-09 fix: pass best_run metrics as fallback data if
            # structured_results is empty, so Planner has some data to chart
            _fa_exp_results = exp_data.get("structured_results", {})
            if not _fa_exp_results and _best_metrics:
                _fa_exp_results = {"best_run_metrics": _best_metrics}
            # Read paper draft for Decision Agent analysis
            _paper_draft = (
                _read_prior_artifact(run_dir, "paper_draft.md")
                or _read_prior_artifact(run_dir, "outline.md")
                or ""
            )
            _fa_plan = _fa.orchestrate({
                "experiment_results": _fa_exp_results,
                "condition_summaries": _condition_summaries,
                "metrics_summary": exp_data.get("metrics_summary", {}),
                "metric_key": config.experiment.metric_key,
                "conditions": _fa_conditions,
                "topic": _read_prior_artifact(run_dir, "topic.md") or config.research.topic,
                "hypothesis": _read_prior_artifact(run_dir, "hypotheses.md") or "",
                "paper_draft": _paper_draft,
                "output_dir": str(stage_dir / "charts"),
            })
            if _fa_plan.figure_count > 0:
                # Save figure plan for Stage 17 to read
                (stage_dir / "figure_plan.json").write_text(
                    json.dumps(_fa_plan.to_dict(), indent=2, default=str),
                    encoding="utf-8",
                )
                _figure_plan_saved = True
                for _cf_name in _fa_plan.get_chart_files():
                    artifacts.append(f"charts/{_cf_name}")
                logger.info(
                    "Stage 14: FigureAgent generated %d charts (%d passed review, %.1fs)",
                    _fa_plan.figure_count,
                    _fa_plan.passed_count,
                    _fa_plan.elapsed_sec,
                )
            else:
                logger.warning("Stage 14: FigureAgent produced no charts, falling back")
        except Exception as _fa_exc:
            logger.warning("Stage 14: FigureAgent failed (%s), falling back to visualize.py", _fa_exc)
    # Fallback: legacy visualize.py chart generation
    if not _figure_plan_saved:
        try:
            from researchclaw.experiment.visualize import (
                generate_all_charts as _gen_charts_early,
            )
            _charts_dir = stage_dir / "charts"
            _early_charts = _gen_charts_early(
                run_dir,
                _charts_dir,
                metric_key=config.experiment.metric_key,
            )
            if _early_charts:
                for _cp in _early_charts:
                    artifacts.append(f"charts/{_cp.name}")
                logger.info(
                    "Stage 14: Generated %d early charts (legacy) for paper embedding",
                    len(_early_charts),
                )
        except Exception as _chart_exc:
            logger.warning("Stage 14: Early chart generation failed: %s", _chart_exc)
    return StageResult(
        stage=Stage.RESULT_ANALYSIS,
        status=StageStatus.DONE,
        artifacts=tuple(artifacts),
        evidence_refs=tuple(f"stage-14/{a}" for a in artifacts),
    )
def _parse_decision(text: str) -> str:
"""Extract PROCEED/PIVOT/REFINE from decision text.
Looks for the first standalone keyword on its own line after a
``## Decision`` heading. Falls back to a keyword scan of the first
few lines after the heading, but only matches the keyword itself
(not mentions inside explanatory prose like "PIVOT is not warranted").
Returns lowercase ``"proceed"`` / ``"pivot"`` / ``"refine"``.
Defaults to ``"proceed"`` if nothing matches.
"""
import re as _re
text_upper = text.upper()
# Look in the first occurrence after "## Decision" heading
decision_section = ""
for keyword in ("## DECISION", "## Decision", "## decision"):
if keyword.upper() in text_upper:
idx = text_upper.index(keyword.upper())
decision_section = text[idx : idx + 200]
break
search_text = decision_section or text[:500]
# First try: look for a line that is just the keyword (possibly with
# whitespace / markdown bold / trailing punctuation).
for line in search_text.splitlines():
stripped = line.strip().strip("*").strip("#").strip()
if stripped.upper() in ("PROCEED", "PIVOT", "REFINE"):
return stripped.lower()
# Fallback: regex for standalone word boundaries so that
# "PIVOT is not warranted" does NOT match as a decision.
for kw in ("PIVOT", "REFINE", "PROCEED"):
# Only match if the keyword appears as the FIRST keyword-class token
# on its own (not embedded in a sentence saying "not PIVOT").
pattern = _re.compile(
r"(?:^|##\s*Decision\s*\n\s*)" + kw, _re.IGNORECASE | _re.MULTILINE
)
if pattern.search(search_text):
return kw.lower()
# Last resort: position-based — prefer whichever keyword appears LAST
# (the final conclusion after deliberation is more reliable than early mentions)
# BUG-DA8-08: Old code always returned "refine" when both keywords present
search_upper = search_text.upper()
last_refine = search_upper.rfind("REFINE")
last_pivot = search_upper.rfind("PIVOT")
if last_refine >= 0 and (last_pivot < 0 or last_refine > last_pivot):
return "refine"
if last_pivot >= 0 and (last_refine < 0 or last_pivot > last_refine):
return "pivot"
return "proceed"
def _execute_research_decision(
stage_dir: Path,
run_dir: Path,
config: RCConfig,
adapters: AdapterBundle,
*,
llm: LLMClient | None = None,
prompts: PromptManager | None = None,
) -> StageResult:
analysis = _read_prior_artifact(run_dir, "analysis.md") or ""
# P6: Detect degenerate REFINE cycles — inject warning if metrics stagnate
_degenerate_hint = ""
_refine_log = _read_prior_artifact(run_dir, "refinement_log.json")
if _refine_log:
try:
_rl = json.loads(_refine_log)
_iters = _rl.get("iterations", [])
_metrics = [it.get("metric") for it in _iters if isinstance(it, dict)]
_valid = [m for m in _metrics if m is not None]
_all_saturated = _valid and all(m <= 0.001 or m >= 0.999 for m in _valid)
_all_identical = len(set(_valid)) <= 1 and len(_valid) >= 2
if _all_saturated or _all_identical:
_degenerate_hint = (
"\n\nSYSTEM WARNING — DEGENERATE REFINE CYCLE DETECTED:\n"
f"Metrics across {len(_valid)} iterations: {_valid}\n"
"All iterations produce identical/saturated results. Further REFINE "
"cycles CANNOT fix this — the underlying benchmark design is too "
"easy/hard. You SHOULD choose PROCEED with a quality caveat rather "
"than REFINE again.\n"
)
logger.warning("P6: Degenerate refine cycle detected, injecting PROCEED hint")
except (json.JSONDecodeError, OSError):
pass
# Phase 2: Inject experiment diagnosis into decision prompt
_diagnosis_hint = ""
_diag_path = run_dir / "experiment_diagnosis.json"
if _diag_path.exists():
try:
_diag_data = json.loads(_diag_path.read_text(encoding="utf-8"))
_qa = _diag_data.get("quality_assessment", {})
_mode = _qa.get("mode", "unknown")
_sufficient = _qa.get("sufficient", False)
_deficiency_types = _qa.get("deficiency_types", [])
if not _sufficient:
_diagnosis_hint = (
"\n\n## EXPERIMENT DIAGNOSIS (from automated analysis)\n"
f"Quality mode: {_mode}\n"
f"Sufficient for full paper: NO\n"
f"Issues found: {', '.join(_deficiency_types)}\n\n"
"IMPORTANT: The experiment has significant issues. "
"If REFINE is chosen, a structured repair prompt is available "
"at repair_prompt.txt with specific fixes for identified issues.\n"
"If the same issues persist after 2+ REFINE cycles, choose PROCEED "
"with appropriate quality caveats.\n"
)
logger.info(
"Stage 15: Injected experiment diagnosis — mode=%s, issues=%s",
_mode, _deficiency_types,
)
except (json.JSONDecodeError, OSError):
pass
# Improvement C: Check ablation quality — if >50% trivial, push REFINE
_ablation_refine_hint = ""
# BUG-DA8-16: Prefer experiment_summary_best.json (promoted best) over
# alphabetically-last stage-14* (which could be a stale versioned dir)
_exp_sum_path = run_dir / "experiment_summary_best.json"
if not _exp_sum_path.is_file():
_exp_sum_path = None
for _s14 in sorted(run_dir.glob("stage-14*/experiment_summary.json"), reverse=True):
_exp_sum_path = _s14
break
if _exp_sum_path and _exp_sum_path.is_file():
try:
from researchclaw.pipeline.stage_impls._paper_writing import _check_ablation_effectiveness
_abl_exp = json.loads(_exp_sum_path.read_text(encoding="utf-8"))
_abl_warnings = _check_ablation_effectiveness(_abl_exp, threshold=0.02)
if _abl_warnings:
_trivial_count = sum(1 for w in _abl_warnings if "ineffective" in w.lower() or "trivial" in w.lower())
_total_abl = max(1, len(_abl_warnings))
if _trivial_count / _total_abl > 0.5:
_ablation_refine_hint = (
"\n\n## ABLATION QUALITY ASSESSMENT (CRITICAL)\n"
f"STRONG RECOMMENDATION: Choose REFINE.\n"
f"{_trivial_count}/{_total_abl} ablations show <2% difference from baseline "
f"(trivially similar). This means the ablation design is broken.\n"
"Warnings:\n" + "\n".join(f"- {w}" for w in _abl_warnings) + "\n"
)
logger.warning("C: %d/%d ablations trivial → recommending REFINE", _trivial_count, _total_abl)
except Exception: # noqa: BLE001
pass
if llm is not None:
_pm = prompts or PromptManager()
_overlay = _get_evolution_overlay(run_dir, "research_decision")
sp = _pm.for_stage("research_decision", evolution_overlay=_overlay, analysis=analysis)
_user = sp.user + _degenerate_hint + _diagnosis_hint + _ablation_refine_hint
resp = _chat_with_prompt(llm, sp.system, _user)
decision_md = resp.content
else:
decision_md = f"""# Research Decision
## Decision
PROCEED
## Justification
Current evidence suggests measurable progress with actionable limitations.
## Next Actions
- Build detailed paper outline
- Expand ablation and uncertainty analysis in writing
Generated: {_utcnow_iso()}
"""
(stage_dir / "decision.md").write_text(decision_md, encoding="utf-8")
# --- Extract structured decision ---
decision = _parse_decision(decision_md)
# T3.1: Validate decision quality — check for minimum experiment rigor
_quality_warnings: list[str] = []
_dec_lower = decision_md.lower()
if "baseline" not in _dec_lower and "control" not in _dec_lower:
_quality_warnings.append("Decision text does not mention baselines")
if "seed" not in _dec_lower and "replicat" not in _dec_lower and "run" not in _dec_lower:
_quality_warnings.append("Decision text does not mention multi-seed/replicate runs")
if "metric" not in _dec_lower and "accuracy" not in _dec_lower and "loss" not in _dec_lower:
_quality_warnings.append("Decision text does not mention evaluation metrics")
if _quality_warnings:
logger.warning("T3.1: Decision quality warnings: %s", _quality_warnings)
decision_payload = {
"decision": decision,
"raw_text_excerpt": decision_md[:500],
"quality_warnings": _quality_warnings,
"generated": _utcnow_iso(),
}
(stage_dir / "decision_structured.json").write_text(
json.dumps(decision_payload, indent=2), encoding="utf-8"
)
logger.info("Research decision: %s", decision)
return StageResult(
stage=Stage.RESEARCH_DECISION,
status=StageStatus.DONE,
artifacts=("decision.md", "decision_structured.json"),
evidence_refs=("stage-15/decision.md",),
decision=decision,
)
================================================
FILE: researchclaw/pipeline/stage_impls/_code_generation.py
================================================
"""Stage 10: Code generation."""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.experiment.validator import (
CodeValidation,
format_issues_for_llm,
validate_code,
)
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain
from researchclaw.pipeline._helpers import (
StageResult,
_chat_with_prompt,
_ensure_sandbox_deps,
_extract_code_block,
_extract_multi_file_blocks,
_extract_yaml_block,
_get_evolution_overlay,
_load_hardware_profile,
_read_prior_artifact,
_safe_json_loads,
_utcnow_iso,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
# Module-level logger shared by all helpers in this stage implementation.
logger = logging.getLogger(__name__)
# Improvement G: Continuous-action environments that are incompatible with DQN
_CONTINUOUS_ENVS = {
"pendulum", "halfcheetah", "hopper", "walker2d", "ant", "humanoid",
"swimmer", "reacher", "invertedpendulum", "inverteddoublependulum",
"mountaincarcontinuous", "lunarlander-continuous",
}
def _check_rl_compatibility(code: str) -> list[str]:
"""Detect DQN + continuous-action environment mismatches.
Returns a list of error strings if incompatible combinations are found.
"""
errors: list[str] = []
code_lower = code.lower()
has_dqn = "dqn" in code_lower
if not has_dqn:
return errors
for env_name in _CONTINUOUS_ENVS:
if env_name in code_lower:
errors.append(
f"RL COMPATIBILITY ERROR: DQN is used with continuous-action "
f"environment '{env_name}'. DQN only works with DISCRETE action "
f"spaces. Use SAC, TD3, or PPO instead."
)
return errors
def _execute_code_generation(
stage_dir: Path,
run_dir: Path,
config: RCConfig,
adapters: AdapterBundle,
*,
llm: LLMClient | None = None,
prompts: PromptManager | None = None,
) -> StageResult:
exp_plan = _read_prior_artifact(run_dir, "exp_plan.yaml") or ""
metric = config.experiment.metric_key
max_repair = 5 # BUG-14: Increased from 3 to give more chances for critical bugs
files: dict[str, str] = {}
validation_log: list[str] = []
# --- Detect available packages for sandbox ---
_pm = prompts or PromptManager()
# --- Hardware-aware package hint ---
hw_profile = _load_hardware_profile(run_dir)
if config.experiment.mode in ("sandbox", "docker"):
if config.experiment.mode == "docker":
pkg_prefix = "docker mode"
_net_policy = config.experiment.docker.network_policy
_base_pkgs = (
", torchvision, torchaudio, matplotlib, seaborn, scipy, "
"tqdm, torchdiffeq, gymnasium, networkx, PyYAML, Pillow, "
"transformers, datasets, accelerate, peft, bitsandbytes, "
"timm, einops, torchmetrics, h5py"
)
if _net_policy == "none":
pkg_extras = _base_pkgs + " (ONLY pre-installed packages — NO pip install available)"
elif _net_policy in ("setup_only", "pip_only"):
pkg_extras = _base_pkgs + ", and additional pip-installable packages via requirements.txt"
else:
pkg_extras = _base_pkgs + ", and additional pip-installable packages (auto-detected from imports)"
else:
pkg_prefix = "sandbox mode"
pkg_extras = ""
if hw_profile and hw_profile.get("has_gpu"):
gpu_type = hw_profile.get("gpu_type", "cuda")
gpu_name = hw_profile.get("gpu_name", "GPU")
tier = hw_profile.get("tier", "limited")
if tier == "high":
device_hint = f"torch.device('{gpu_type}')"
pkg_hint = (
f"\nAVAILABLE PACKAGES ({pkg_prefix}): Python stdlib, numpy, torch, sklearn, scipy, pandas{pkg_extras}.\n"
f"GPU: {gpu_name} ({gpu_type}). You MAY use PyTorch with GPU acceleration.\n"
f"Use `device = {device_hint}` for tensor operations.\n"
)
else: # limited (low VRAM NVIDIA or MPS)
device_hint = f"torch.device('{gpu_type}')"
pkg_hint = (
f"\nAVAILABLE PACKAGES ({pkg_prefix}): Python stdlib, numpy, torch, sklearn, scipy, pandas{pkg_extras}.\n"
f"GPU: {gpu_name} ({gpu_type}) — LIMITED performance.\n"
f"Use `device = {device_hint}` but design LIGHTWEIGHT experiments:\n"
f"- Small models (<1M parameters)\n"
f"- Few epochs (<=20)\n"
f"- Small datasets (<=10K samples)\n"
f"- Avoid large batch sizes\n"
)
else:
pkg_hint = _pm.block("pkg_hint_sandbox")
else:
pkg_hint = ""
# --- Compute budget hint ---
time_budget_sec = config.experiment.time_budget_sec
try:
compute_budget = _pm.block("compute_budget").replace(
"{time_budget_sec}", str(time_budget_sec)
)
except Exception: # noqa: BLE001
compute_budget = (
f"\n## Compute Budget Constraint\n"
f"- Total execution time limit: {time_budget_sec} seconds\n"
f"- Design experiments that complete within this budget\n"
f"- Implement a time guard: stop gracefully at 80% of budget\n"
)
# --- Dataset guidance + setup script + HP reporting (docker/sandbox modes) ---
extra_guidance = ""
_net_policy = getattr(getattr(config, "docker", None), "network_policy", "setup_only")
if config.experiment.mode in ("sandbox", "docker"):
_net_policy = (
config.experiment.docker.network_policy
if config.experiment.mode == "docker"
else "none" # sandbox mode has no network
)
if _net_policy == "none":
# Network disabled: inject strict offline-only guidance
try:
extra_guidance += _pm.block("network_disabled_guidance")
except Exception: # noqa: BLE001
pass
elif _net_policy == "full":
try:
extra_guidance += _pm.block("dataset_guidance")
extra_guidance += _pm.block("network_full_guidance")
except Exception: # noqa: BLE001
pass
else:
# setup_only or pip_only — existing behavior
try:
extra_guidance += _pm.block("dataset_guidance")
except Exception: # noqa: BLE001
pass
if config.experiment.mode == "docker":
try:
extra_guidance += _pm.block("setup_script_guidance")
except Exception: # noqa: BLE001
pass
try:
extra_guidance += _pm.block("hp_reporting")
except Exception: # noqa: BLE001
pass
# I-06: Multi-seed enforcement for all experiments
try:
extra_guidance += _pm.block("multi_seed_enforcement")
except Exception: # noqa: BLE001
pass
# --- BA: Inject BenchmarkAgent plan from Stage 9 ---
_bp_path = None
for _s9_dir in sorted(run_dir.glob("stage-09*"), reverse=True):
_candidate = _s9_dir / "benchmark_plan.json"
if _candidate.exists():
_bp_path = _candidate
break
if _bp_path is not None:
try:
import json as _json_bp
_bp_data = _json_bp.loads(_bp_path.read_text(encoding="utf-8"))
# Reconstruct the prompt block
from researchclaw.agents.benchmark_agent.orchestrator import BenchmarkPlan
_bp = BenchmarkPlan(
selected_benchmarks=_bp_data.get("selected_benchmarks", []),
selected_baselines=_bp_data.get("selected_baselines", []),
data_loader_code=_bp_data.get("data_loader_code", ""),
baseline_code=_bp_data.get("baseline_code", ""),
experiment_notes=_bp_data.get("experiment_notes", ""),
)
_bp_block = _bp.to_prompt_block()
if _bp_block:
extra_guidance += (
"\n\n## BenchmarkAgent Selections (USE THESE)\n"
"The following datasets, baselines, and code snippets were "
"automatically selected and validated by the BenchmarkAgent. "
"You MUST use these selections in your experiment code.\n\n"
+ _bp_block
)
logger.info(
"BA: Injected benchmark plan (%d benchmarks, %d baselines)",
len(_bp.selected_benchmarks), len(_bp.selected_baselines),
)
except Exception as _bp_exc:
logger.debug("BA: Failed to load benchmark plan: %s", _bp_exc)
# --- P2.2+P2.3: LLM training topic detection and guidance ---
_llm_keywords = (
"language model", "llm", "fine-tun", "lora", "qlora", "peft",
"instruction tun", "rlhf", "dpo", "sft", "alignment",
"transformer train", "causal lm", "chat model", "qwen", "llama",
"mistral", "phi-", "gemma", "pretraining", "tokeniz",
)
topic_lower = config.research.topic.lower()
is_llm_topic = any(kw in topic_lower for kw in _llm_keywords)
# --- I-08: RL topic detection and step guidance ---
_rl_keywords = (
"reinforcement learning", "policy gradient", "ppo", "sac", "td3",
"ddpg", "dqn", "a2c", "a3c", "mujoco", "locomotion", "continuous control",
"reward shaping", "exploration", "multi-agent rl", "marl", "curriculum rl",
"imitation learning", "inverse rl", "offline rl", "model-based rl",
"actor-critic", "reinforce", "gym", "gymnasium",
)
is_rl_topic = any(kw in topic_lower for kw in _rl_keywords)
if is_rl_topic:
try:
extra_guidance += _pm.block("rl_step_guidance")
except Exception: # noqa: BLE001
pass
# --- F-01: Framework API doc injection (auto-detected) ---
try:
from researchclaw.data import detect_frameworks, load_framework_docs
_hypothesis_text = _read_prior_artifact(run_dir, "hypotheses.md") or ""
_fw_ids = detect_frameworks(
config.research.topic, _hypothesis_text, exp_plan or ""
)
if _fw_ids:
_fw_docs = load_framework_docs(_fw_ids, max_chars=8000)
if _fw_docs:
extra_guidance += _fw_docs
logger.info("F-01: Injected framework docs for: %s", _fw_ids)
except Exception: # noqa: BLE001
logger.debug("F-01: Framework doc injection skipped", exc_info=True)
if is_llm_topic and config.experiment.mode == "docker":
try:
extra_guidance += _pm.block("llm_training_guidance")
except Exception: # noqa: BLE001
pass
try:
extra_guidance += _pm.block("llm_eval_guidance")
except Exception: # noqa: BLE001
pass
# P2.3: Warn if time budget is too short for LLM training
if time_budget_sec < 3600:
extra_guidance += (
"\n## COMPUTE BUDGET WARNING\n"
f"Current time_budget_sec={time_budget_sec} is likely TOO SHORT "
f"for LLM fine-tuning. Typical LoRA training needs 1-4 hours. "
f"Design a LIGHTWEIGHT experiment:\n"
f"- Use a small dataset (<=5000 samples)\n"
f"- Train for 1-3 epochs only\n"
f"- Use small batch size (1-2) with gradient accumulation\n"
f"- Use 4-bit quantization (QLoRA) to minimize memory\n"
f"- Limit max_seq_length to 512-1024\n"
f"- If possible, use a smaller model (<=7B parameters)\n"
)
# --- Domain-specific guidance injection for non-ML domains ---
try:
from researchclaw.domains.detector import detect_domain as _dd_s10, is_ml_domain as _is_ml_s10
_dp = _dd_s10(topic=config.research.topic)
if not _is_ml_s10(_dp):
from researchclaw.domains.prompt_adapter import get_adapter as _ga
_adapter = _ga(_dp)
_blocks = _adapter.get_code_generation_blocks({})
if _blocks.compute_budget:
compute_budget = _blocks.compute_budget
if _blocks.dataset_guidance:
extra_guidance = _blocks.dataset_guidance + "\n" + extra_guidance
if _blocks.code_generation_hints:
extra_guidance += "\n" + _blocks.code_generation_hints
if _blocks.output_format_guidance:
extra_guidance += "\n" + _blocks.output_format_guidance
logger.info("Injected domain-specific guidance for %s", _dp.domain_id)
except Exception: # noqa: BLE001
logger.debug("Domain guidance injection skipped", exc_info=True)
# BUG-R6-01: Add explicit implementation constraints to prevent LLM
# from substituting unrelated DL models for lightweight algorithms.
extra_guidance += (
"\n\nIMPLEMENTATION CONSTRAINTS (MUST FOLLOW):\n"
"- Implement EXACTLY the algorithm/method described in the topic.\n"
"- Do NOT replace the stated method with a deep-learning proxy "
"(e.g. ResNet, BERT, GPT, Gymnasium+SB3) unless the topic "
"EXPLICITLY requires deep learning.\n"
"- Prefer lightweight CPU-friendly libraries (numpy, scipy, "
"sklearn, pandas) unless deep learning is inherent to the topic.\n"
"- The experiment MUST be self-contained and runnable without GPU.\n"
)
# --- Code generation: Beast Mode → CodeAgent → Legacy single-shot ---
_code_agent_active = False
_beast_mode_used = False
_code_max_tokens = 8192
# ── Beast Mode: OpenCode external agent (optional) ─────────────────
_oc_cfg = config.experiment.opencode
if _oc_cfg.enabled:
from researchclaw.pipeline.opencode_bridge import (
OpenCodeBridge,
OpenCodeResult,
count_historical_failures,
score_complexity,
)
_hist_failures = count_historical_failures(run_dir)
_cplx = score_complexity(
exp_plan=exp_plan,
topic=config.research.topic,
historical_failures=_hist_failures,
threshold=_oc_cfg.complexity_threshold,
)
# Persist complexity analysis
(stage_dir / "complexity_analysis.json").write_text(
json.dumps(
{
"score": _cplx.score,
"signals": _cplx.signals,
"recommendation": _cplx.recommendation,
"reason": _cplx.reason,
"threshold": _oc_cfg.complexity_threshold,
"historical_failures": _hist_failures,
},
indent=2,
),
encoding="utf-8",
)
if _cplx.recommendation == "beast_mode":
_proceed = _oc_cfg.auto
if not _proceed:
# Non-auto mode: check for HITL adapter
if adapters.hitl is not None:
try:
_proceed = adapters.hitl.confirm(
f"Beast Mode: complexity={_cplx.score:.2f} "
f"(threshold={_oc_cfg.complexity_threshold}). "
f"Route to OpenCode?"
)
except Exception: # noqa: BLE001
logger.info(
"Beast mode: HITL adapter unavailable, skipping "
"(set opencode.auto=true for non-interactive runs)"
)
else:
logger.info(
"Beast mode: no HITL adapter, skipping "
"(set opencode.auto=true for non-interactive runs)"
)
if _proceed:
_oc_model = _oc_cfg.model or config.llm.primary_model
_bridge = OpenCodeBridge(
model=_oc_model,
llm_base_url=config.llm.base_url,
api_key_env=config.llm.api_key_env,
llm_provider=config.llm.provider,
timeout_sec=_oc_cfg.timeout_sec,
max_retries=_oc_cfg.max_retries,
workspace_cleanup=_oc_cfg.workspace_cleanup,
)
logger.info(
"Beast mode: ENGAGED (complexity=%.2f, model=%s)",
_cplx.score,
_oc_model,
)
_oc_result: OpenCodeResult = _bridge.generate(
stage_dir=stage_dir,
topic=config.research.topic,
exp_plan=exp_plan,
metric=metric,
pkg_hint=pkg_hint + "\n" + compute_budget,
extra_guidance=extra_guidance,
time_budget_sec=config.experiment.time_budget_sec,
)
# Persist beast mode log
(stage_dir / "beast_mode_log.json").write_text(
json.dumps(
{
"success": _oc_result.success,
"elapsed_sec": _oc_result.elapsed_sec,
"files": list(_oc_result.files.keys()),
"error": _oc_result.error,
"complexity_score": _cplx.score,
"model": _oc_model,
},
indent=2,
),
encoding="utf-8",
)
if _oc_result.success and _oc_result.files:
files = _oc_result.files
_beast_mode_used = True
_code_agent_active = True # skip legacy path
logger.info(
"Beast mode: SUCCESS — %d files in %.1fs",
len(files),
_oc_result.elapsed_sec,
)
else:
logger.warning(
"Beast mode: FAILED (%s) — falling back to CodeAgent",
_oc_result.error or "unknown error",
)
else:
logger.info(
"Beast mode: complexity=%.2f (threshold=%.2f), not triggered",
_cplx.score,
_oc_cfg.complexity_threshold,
)
if not _beast_mode_used and config.experiment.code_agent.enabled and llm is not None:
# ── F-02: Advanced Code Agent path ────────────────────────────────
from researchclaw.pipeline.code_agent import CodeAgent as _CodeAgent
_ca_cfg = config.experiment.code_agent
# Ensure we have a proper config object
if not hasattr(_ca_cfg, "enabled"):
from researchclaw.pipeline.code_agent import (
CodeAgentConfig as _CAConfig,
)
_ca_cfg = _CAConfig()
# Sandbox factory (only for sandbox/docker modes)
_sandbox_factory = None
if config.experiment.mode in ("sandbox", "docker"):
from researchclaw.experiment.factory import (
create_sandbox as _csb,
)
_sandbox_factory = _csb
if any(
config.llm.primary_model.startswith(p)
for p in ("gpt-5", "o3", "o4")
):
_code_max_tokens = 16384
# ── Domain detection + Code Search for non-ML domains ──────────
_domain_profile = None
_code_search_result = None
try:
from researchclaw.domains.detector import detect_domain as _dd
from researchclaw.domains.detector import is_ml_domain as _is_ml
_domain_profile = _dd(topic=config.research.topic)
logger.info(
"CodeAgent: domain=%s (%s)",
_domain_profile.display_name,
_domain_profile.domain_id,
)
# Run code search for non-ML domains (ML has enough built-in knowledge)
if not _is_ml(_domain_profile):
try:
from researchclaw.agents.code_searcher import CodeSearchAgent
_cs_agent = CodeSearchAgent(llm=llm)
_code_search_result = _cs_agent.search(
topic=config.research.topic,
domain=_domain_profile,
)
if _code_search_result and _code_search_result.patterns.has_content:
logger.info(
"Code search: %d patterns, %d repos found",
len(_code_search_result.patterns.api_patterns),
len(_code_search_result.repos_found),
)
except Exception: # noqa: BLE001
logger.debug("Code search unavailable", exc_info=True)
except Exception: # noqa: BLE001
logger.debug("Domain detection unavailable", exc_info=True)
_agent = _CodeAgent(
llm=llm,
prompts=_pm,
config=_ca_cfg,
stage_dir=stage_dir,
sandbox_factory=_sandbox_factory,
experiment_config=config.experiment,
domain_profile=_domain_profile,
code_search_result=_code_search_result,
)
_agent_result = _agent.generate(
topic=config.research.topic,
exp_plan=exp_plan,
metric=metric,
pkg_hint=pkg_hint + "\n" + compute_budget + "\n" + extra_guidance,
max_tokens=_code_max_tokens,
)
files = _agent_result.files
_code_agent_active = True
# Write agent artifacts
(stage_dir / "code_agent_log.json").write_text(
json.dumps(
{
"log": _agent_result.validation_log,
"llm_calls": _agent_result.total_llm_calls,
"sandbox_runs": _agent_result.total_sandbox_runs,
"best_score": _agent_result.best_score,
"tree_nodes_explored": _agent_result.tree_nodes_explored,
"review_rounds": _agent_result.review_rounds,
},
indent=2,
),
encoding="utf-8",
)
if _agent_result.architecture_spec:
(stage_dir / "architecture_spec.yaml").write_text(
_agent_result.architecture_spec, encoding="utf-8",
)
logger.info(
"CodeAgent: %d LLM calls, %d sandbox runs, score=%.2f",
_agent_result.total_llm_calls,
_agent_result.total_sandbox_runs,
_agent_result.best_score,
)
elif not _beast_mode_used and llm is not None:
# ── Legacy single-shot generation ─────────────────────────────────
topic = config.research.topic
_md = config.experiment.metric_direction
_md_hint = (
f"`{_md}` — use direction={'lower' if _md == 'minimize' else 'higher'} "
f"in METRIC_DEF. You MUST NOT use the opposite direction."
)
_overlay = _get_evolution_overlay(run_dir, "code_generation")
sp = _pm.for_stage(
"code_generation",
evolution_overlay=_overlay,
topic=topic,
metric=metric,
pkg_hint=pkg_hint + "\n" + compute_budget + "\n" + extra_guidance,
exp_plan=exp_plan,
metric_direction_hint=_md_hint,
)
# R13-3: Use higher max_tokens for reasoning models (they consume tokens
# for internal chain-of-thought). Retry once with even higher limit on empty.
_code_max_tokens = sp.max_tokens or 8192
if any(config.llm.primary_model.startswith(p) for p in ("gpt-5", "o3", "o4")):
_code_max_tokens = max(_code_max_tokens, 16384)
resp = _chat_with_prompt(
llm,
sp.system,
sp.user,
json_mode=sp.json_mode,
max_tokens=_code_max_tokens,
)
files = _extract_multi_file_blocks(resp.content)
if not files and not resp.content.strip():
# Empty response — retry with higher token limit
logger.warning(
"R13-3: Empty LLM response for code_generation (len=%d, "
"finish_reason=%s, tokens=%d). Retrying with 32768 tokens.",
len(resp.content),
resp.finish_reason,
resp.total_tokens,
)
resp = _chat_with_prompt(
llm,
sp.system,
sp.user,
json_mode=sp.json_mode,
max_tokens=32768,
)
files = _extract_multi_file_blocks(resp.content)
if not files:
logger.warning(
"R13-2: _extract_multi_file_blocks returned empty. "
"LLM response length=%d, first 300 chars: %s",
len(resp.content),
resp.content[:300],
)
# --- Fallback: generic numerical experiment ---
if not files:
files = {
"main.py": (
"import numpy as np\n"
"\n"
"np.random.seed(42)\n"
"\n"
"# Fallback experiment: parameter sweep on a synthetic objective\n"
"# This runs when LLM code generation fails to produce valid code.\n"
"dim = 10\n"
"n_conditions = 3\n"
"results = {}\n"
"\n"
"for cond_idx in range(n_conditions):\n"
" cond_name = f'condition_{cond_idx}'\n"
" scores = []\n"
" for seed in range(3):\n"
" rng = np.random.RandomState(seed + cond_idx * 100)\n"
" x = rng.randn(dim)\n"
" score = float(1.0 / (1.0 + np.sum(x ** 2)))\n"
" scores.append(score)\n"
" mean_score = float(np.mean(scores))\n"
" results[cond_name] = mean_score\n"
f" print(f'condition={{cond_name}} {metric}: {{mean_score:.6f}}')\n"
"\n"
"best = max(results, key=results.get)\n"
f"print(f'{metric}: {{results[best]:.6f}}')\n"
)
}
# --- Validate each file + auto-repair loop ---
all_valid = True
attempt = 0
for fname, code in list(files.items()):
# Skip non-Python files (requirements.txt, setup.py, etc.)
if not fname.endswith(".py"):
continue
validation = validate_code(code)
repair_attempt = 0
while not validation.ok and llm is not None and repair_attempt < max_repair:
repair_attempt += 1
attempt += 1
# Only send errors to the LLM — warnings don't block validation
# and confuse the LLM into over-correcting (e.g. removing runtime imports)
errors_only = type(validation)(
issues=[i for i in validation.issues if i.severity == "error"]
)
issues_text = format_issues_for_llm(errors_only)
validation_log.append(
f"File {fname} attempt {repair_attempt}: {validation.summary()}"
)
logger.info(
"Code validation failed for %s (attempt %d/%d): %s",
fname,
repair_attempt,
max_repair,
validation.summary(),
)
all_files_ctx = "\n\n".join(
f"```filename:{f}\n{c}\n```" for f, c in files.items()
)
rp = _pm.sub_prompt(
"code_repair",
fname=fname,
issues_text=issues_text,
all_files_ctx=all_files_ctx,
)
resp = _chat_with_prompt(llm, rp.system, rp.user)
_repaired = _extract_code_block(resp.content)
if _repaired.strip():
files[fname] = _repaired
else:
logger.warning("Repair attempt returned empty code, keeping original")
validation = validate_code(files[fname])
if not validation.ok:
all_valid = False
# BUG-14: Log remaining issues prominently
logger.warning(
"Code validation FAILED for %s after %d repair attempts: %s",
fname, max_repair, validation.summary(),
)
# Improvement G: RL algorithm-environment compatibility check
for fname, code in list(files.items()):
if not fname.endswith(".py"):
continue
_rl_errors = _check_rl_compatibility(code)
if _rl_errors:
for _rl_err in _rl_errors:
logger.error("Stage 10: %s (in %s)", _rl_err, fname)
validation_log.append(f"RL_COMPAT: {fname}: {_rl_err}")
all_valid = False
# BUG-14: Block on critical validation failures (syntax/import errors)
if not all_valid:
_has_critical = False
for fname, code in files.items():
_v = validate_code(code)
if not _v.ok:
for issue in _v.issues:
if issue.severity == "error" and issue.category in (
"syntax", "import",
):
_has_critical = True
if _has_critical:
logger.error(
"Stage 10: CRITICAL validation issues remain after %d repair "
"attempts. Blocking stage.", max_repair,
)
(stage_dir / "validation_report.md").write_text(
"# Code Validation Report\n\n"
f"**Status**: BLOCKED — critical issues remain after {max_repair} repairs\n\n"
+ "\n".join(f"- {e}" for e in validation_log),
encoding="utf-8",
)
return StageResult(
stage=Stage.CODE_GENERATION,
status=StageStatus.FAILED,
artifacts=("validation_report.md",),
evidence_refs=(),
)
# --- BUG-184: Cross-import validation — warn if a .py file imports a
# local module that doesn't exist in the files dict. This catches the
# case where Beast Mode/CodeAgent produced an intermediate file that
# got lost during repair iterations.
_known_modules = {
f.replace(".py", "") for f in files if f.endswith(".py")
}
_stdlib_and_common = {
"os", "sys", "json", "math", "time", "copy", "re", "random",
"pathlib", "argparse", "logging", "collections", "functools",
"itertools", "abc", "typing", "dataclasses", "enum", "io",
"csv", "pickle", "glob", "shutil", "subprocess", "datetime",
"numpy", "np", "torch", "torchvision", "gymnasium", "gym",
"sklearn", "scipy", "pandas", "matplotlib", "PIL", "tqdm",
"einops", "timm", "transformers", "datasets", "peft",
"stable_baselines3",
}
for fname, code in list(files.items()):
if not fname.endswith(".py"):
continue
for _m in re.findall(
r"^(?:from|import)\s+([a-zA-Z_][a-zA-Z0-9_]*)",
code, re.MULTILINE,
):
if (_m not in _known_modules
and _m not in _stdlib_and_common
and not _m.startswith("_")):
logger.warning(
"BUG-184: %s imports '%s' which is not in generated "
"files — experiment may crash on import",
fname, _m,
)
# --- Write experiment directory ---
exp_dir = stage_dir / "experiment"
exp_dir.mkdir(parents=True, exist_ok=True)
for fname, code in files.items():
(exp_dir / fname).write_text(code, encoding="utf-8")
# --- Write validation report ---
if validation_log or not all_valid:
report_lines = ["# Code Validation Report\n"]
if all_valid:
report_lines.append(f"**Status**: PASSED after {attempt} total repair(s)\n")
else:
report_lines.append(
f"**Status**: FAILED after {attempt} total repair attempt(s)\n"
)
for entry in validation_log:
report_lines.append(f"- {entry}")
(stage_dir / "validation_report.md").write_text(
"\n".join(report_lines), encoding="utf-8"
)
# --- R10-Fix6: Code complexity and quality check ---
from researchclaw.experiment.validator import (
auto_fix_unbound_locals,
check_code_complexity,
deep_validate_files,
)
# --- BUG-3 fix: Programmatic auto-fix for UnboundLocalError patterns ---
_total_ub_fixes = 0
for fname, code in list(files.items()):
if fname.endswith(".py"):
fixed_code, n_fixes = auto_fix_unbound_locals(code)
if n_fixes > 0:
files[fname] = fixed_code
(exp_dir / fname).write_text(fixed_code, encoding="utf-8")
_total_ub_fixes += n_fixes
logger.info(
"Stage 10: auto-fixed %d UnboundLocalError risk(s) in %s",
n_fixes, fname,
)
if _total_ub_fixes:
logger.info(
"Stage 10: auto-fixed %d total UnboundLocalError risks", _total_ub_fixes
)
complexity_warnings: list[str] = []
for fname, code in files.items():
if fname.endswith(".py"):
cw = check_code_complexity(code)
for w in cw:
complexity_warnings.append(f"[{fname}] {w}")
logger.warning("Stage 10 code quality: [%s] %s", fname, w)
# --- P1.1+P1.2: Deep quality analysis (class quality, scoping, API) ---
deep_warnings = deep_validate_files(files)
for w in deep_warnings:
logger.warning("Stage 10 deep quality: %s", w)
complexity_warnings.extend(deep_warnings)
# --- P1.2: If critical deep issues found, attempt one repair cycle ---
critical_deep = [w for w in deep_warnings if any(
kw in w for kw in ("UnboundLocalError", "unregistered", "does not exist",
"empty or trivial subclass", "does NOT override",
"Import-usage mismatch", "NameError",
"was removed", "ptp()",
"copy-paste", "identical method signatures",
"identical AST", "NOT a real ablation",
"shadows stdlib/pip")
)]
if critical_deep and llm is not None:
logger.info(
"Stage 10: %d critical code issues found — triggering repair cycle",
len(critical_deep),
)
repair_issues = "\n".join(f"- {w}" for w in critical_deep)
all_code_ctx = "\n\n".join(
f"```filename:{f}\n{c}\n```" for f, c in files.items()
)
repair_prompt = (
f"CRITICAL CODE QUALITY ISSUES FOUND:\n{repair_issues}\n\n"
f"Fix ALL these issues in the code below. Return the complete "
f"corrected files using ```filename:xxx.py format.\n\n"
f"RULES:\n"
f"- nn.Linear/nn.Conv must be created in __init__(), not forward()\n"
f"- Variables used after if/else must be defined before the branch\n"
f"- Use scipy.special.erf, not np.erf\n"
f"- Ablation/variant classes must have genuinely different logic\n"
f"- Every class must have a real implementation, not just `pass`\n"
f"- Ablation classes MUST override the parent method that implements "
f"the component being ablated (e.g., if ablating attention, override "
f"the attention method with a simpler alternative like mean pooling)\n"
f"- IMPORT CONSISTENCY: if you write `from X import Y`, call `Y()` "
f"directly — NOT `X.Y()`. Mixing styles causes NameError.\n"
f"- NumPy 2.0: ndarray.ptp() was removed — use arr.max()-arr.min()\n"
f"- NumPy 2.0: np.bool/np.int/np.float removed — use builtins\n"
f"- Pretrained models (EfficientNet, ResNet, ViT) expect 224×224 input "
f"— add `transforms.Resize(224)` when using CIFAR (32×32) or similar\n"
f"- Copy-paste ablation: if two classes have identical bodies, REWRITE "
f"the ablation to genuinely remove/reduce a component (e.g., zero out "
f"attention weights, halve hidden dimensions, remove a loss term)\n"
f"- KD: teacher must be frozen, add projection layers if teacher_dim != "
f"student_dim, use temperature T=4 for soft targets\n"
f"- FILENAME COLLISIONS: If a file like config.py shadows a pip/stdlib "
f"package, rename it (e.g., config.py → experiment_config.py) and update "
f"ALL imports referencing it\n\n"
f"Current code:\n{all_code_ctx}\n"
)
try:
repair_resp = _chat_with_prompt(
llm,
_pm.system("code_generation"),
repair_prompt,
max_tokens=_code_max_tokens,
)
repaired = _extract_multi_file_blocks(repair_resp.content)
if repaired and "main.py" in repaired:
files = repaired
for fname, code in files.items():
(exp_dir / fname).write_text(code, encoding="utf-8")
# Re-check after repair
deep_warnings_after = deep_validate_files(files)
fixed = len(critical_deep) - len([
w for w in deep_warnings_after
if any(kw in w for kw in (
"UnboundLocalError", "unregistered", "does not exist",
"empty or trivial subclass", "does NOT override",
"Import-usage mismatch", "NameError",
"was removed", "ptp()",
"copy-paste", "identical method signatures",
"identical AST", "NOT a real ablation",
"shadows stdlib/pip",
))
])
logger.info(
"Stage 10: Deep repair fixed %d/%d critical issues",
fixed, len(critical_deep),
)
complexity_warnings.append(
f"[REPAIR] Deep repair fixed {fixed}/{len(critical_deep)} "
f"critical issues"
)
except Exception as exc:
logger.debug("Deep repair failed: %s", exc)
if complexity_warnings:
health: dict[str, Any] = {}
health["code_complexity_warnings"] = complexity_warnings
(stage_dir / "code_complexity.json").write_text(
json.dumps(health, indent=2), encoding="utf-8"
)
# --- P1.4: LLM Code Review (Stage 10.5) ---
# Skip when CodeAgent is active — Phase 4 review already covers this.
if llm is not None and not _code_agent_active:
all_code_review = "\n\n".join(
f"# --- {fname} ---\n{code}" for fname, code in files.items()
)
if len(all_code_review) > 12000:
all_code_review = all_code_review[:12000] + "\n... [truncated]"
review_prompt = (
f"You are a senior researcher reviewing experiment code for a "
f"research submission.\n\n"
f"TOPIC: {config.research.topic}\n"
f"EXPERIMENT PLAN:\n{exp_plan[:3000]}\n\n"
f"CODE:\n```python\n{all_code_review}\n```\n\n"
f"Review the code and return JSON with this EXACT structure:\n"
f'{{"score": <1-10>, "issues": ['
f'{{"severity": "critical|major|minor", '
f'"description": "...", "fix": "..."}}], '
f'"verdict": "pass|needs_fix"}}\n\n'
f"Check specifically:\n"
f"1. Does each algorithm/method have a DISTINCT implementation? "
f"(Not just renamed copies)\n"
f"2. Are ablation conditions genuinely different from the main method?\n"
f"3. Are loss functions / training loops mathematically correct?\n"
f"4. Will the code actually run without errors? Check variable scoping, "
f"API usage, tensor shape compatibility.\n"
f"5. Is the code complex enough for a research paper? (Not trivial)\n"
f"6. Are experimental conditions fairly compared (same seeds, data)?\n"
f"7. If using pretrained models (EfficientNet, ResNet, ViT), are input "
f"images resized to the model's expected size (e.g., 224x224)? CIFAR "
f"images are 32x32 and MUST be resized for pretrained models.\n"
f"8. Are imports consistent? `from X import Y` must use `Y()`, not `X.Y()`.\n"
)
try:
review_resp = llm.chat(
[{"role": "user", "content": review_prompt}],
system="You are a meticulous ML code reviewer. Be strict.",
max_tokens=2048,
)
# Extract JSON from LLM response (may be wrapped in markdown fences)
_review_text = review_resp.content if hasattr(review_resp, "content") else str(review_resp)
# Strip markdown JSON fences if present
_review_text = _review_text.strip()
if _review_text.startswith("```"):
_lines = _review_text.splitlines()
_start = 1 if _lines[0].strip().startswith("```") else 0
_end = len(_lines) - 1 if _lines[-1].strip() == "```" else len(_lines)
_review_text = "\n".join(_lines[_start:_end])
review_data = _safe_json_loads(_review_text, {})
if isinstance(review_data, dict):
review_score = review_data.get("score", 0)
review_verdict = review_data.get("verdict", "unknown")
review_issues = review_data.get("issues", [])
# Write review report
review_report = {
"score": review_score,
"verdict": review_verdict,
"issues": review_issues,
"timestamp": _utcnow_iso(),
}
(stage_dir / "code_review.json").write_text(
json.dumps(review_report, indent=2), encoding="utf-8"
)
# If critical issues found and score low, attempt fix
critical_issues = [
i for i in review_issues
if isinstance(i, dict)
and i.get("severity") == "critical"
]
if critical_issues and review_score <= 4:
logger.warning(
"Stage 10 code review: score=%d, %d critical issues — "
"attempting fix",
review_score, len(critical_issues),
)
fix_descriptions = "\n".join(
f"- [{i.get('severity', '?')}] {i.get('description', '?')}: "
f"{i.get('fix', 'no fix suggested')}"
for i in critical_issues
)
fix_prompt = (
f"Code review found {len(critical_issues)} CRITICAL issues "
f"(score: {review_score}/10):\n{fix_descriptions}\n\n"
f"Fix ALL critical issues. Return complete corrected files "
f"using ```filename:xxx.py format.\n\n"
f"Current code:\n"
+ "\n\n".join(
f"```filename:{f}\n{c}\n```" for f, c in files.items()
)
)
try:
fix_resp = _chat_with_prompt(
llm,
_pm.system("code_generation"),
fix_prompt,
max_tokens=_code_max_tokens,
)
fixed_files = _extract_multi_file_blocks(fix_resp.content)
if fixed_files and "main.py" in fixed_files:
files = fixed_files
for fname, code in files.items():
(exp_dir / fname).write_text(code, encoding="utf-8")
logger.info(
"Stage 10: Code fixed after review "
"(was %d/10, %d critical issues)",
review_score, len(critical_issues),
)
except Exception as exc:
logger.debug("Review-fix failed: %s", exc)
except Exception as exc:
logger.debug("Code review failed: %s", exc)
# --- FIX-3: Topic-experiment alignment check ---
# BUG-171: Previous 8000-char truncation caused false-positive misalignment
# for multi-file experiments (30-90K chars). LLM saw "[truncated]" and
# concluded code was incomplete. Fix: build a structured summary that
# includes file inventory + full main.py + per-file function/class headers.
alignment_ok = True
alignment_note = ""
if llm is not None:
# Build structured code summary for alignment check
_file_inventory = []
for _fn, _cd in files.items():
_lines = _cd.count("\n") + 1
_file_inventory.append(f" {_fn}: {_lines} lines, {len(_cd)} chars")
_inventory_block = "FILES GENERATED:\n" + "\n".join(_file_inventory)
# BUG-179: Beast Mode may use a different entry point (e.g.
# run_experiment.py). Detect the actual entry point by scanning
# for ``if __name__ == "__main__"`` in all files, preferring main.py.
_entry_file = "main.py"
if "main.py" not in files or not files.get("main.py", "").strip():
for _fn, _cd in files.items():
if 'if __name__' in _cd and '__main__' in _cd:
_entry_file = _fn
break
elif files.get("main.py", ""):
# main.py exists but may be a stub — if another file has the
# real orchestration (more lines + __main__ guard), prefer it
_main_lines = files["main.py"].count("\n")
for _fn, _cd in files.items():
if _fn == "main.py":
continue
if ('if __name__' in _cd and '__main__' in _cd
and _cd.count("\n") > _main_lines * 1.5):
_entry_file = _fn
break
_main_code = files.get(_entry_file, files.get("main.py", ""))
_main_block = f"# --- {_entry_file} (FULL — entry point) ---\n{_main_code}"
# Cap main.py at 12000 chars to stay within token budget
if len(_main_block) > 12000:
_main_block = _main_block[:12000] + "\n... [main.py truncated at 12000 chars]"
# For other files, include imports + function/class signatures
_other_summaries = []
for _fn, _cd in files.items():
if _fn == _entry_file:
continue
_sig_lines = []
for _line in _cd.split("\n"):
_stripped = _line.strip()
if (_stripped.startswith("def ") or _stripped.startswith("class ")
or _stripped.startswith("async def ")
# BUG-209: Include import lines — they reveal which
# techniques/libraries are used (e.g. CosineAnnealingLR)
or _stripped.startswith("import ")
or _stripped.startswith("from ")):
_sig_lines.append(_line)
if _sig_lines:
_other_summaries.append(
f"# --- {_fn} (imports + signatures) ---\n"
+ "\n".join(_sig_lines)
)
else:
# Small file — include first 800 chars
_preview = _cd[:800]
if len(_cd) > 800:
_preview += f"\n... [{len(_cd) - 800} more chars]"
_other_summaries.append(f"# --- {_fn} (preview) ---\n{_preview}")
_other_block = "\n\n".join(_other_summaries)
# Cap other summaries
if len(_other_block) > 6000:
_other_block = _other_block[:6000] + "\n... [other files truncated]"
all_code_for_check = (
f"{_inventory_block}\n\n{_main_block}\n\n{_other_block}"
)
align_prompt = (
f"Research topic: {config.research.topic}\n\n"
f"Experiment code:\n```python\n{all_code_for_check}\n```\n\n"
"TASK: Evaluate whether this experiment code actually tests the "
"stated research topic. Answer with JSON:\n"
'{"aligned": true/false, "reason": "...", "suggestions": "..."}\n\n'
"IMPORTANT: The code spans MULTIPLE files. The file inventory above "
"shows ALL generated files. Only main.py is shown in full; other "
"files show function/class signatures. Do NOT mark as misaligned "
"just because helper files are summarized — they contain full "
"implementations.\n\n"
"Check specifically:\n"
"- Does main.py orchestrate an experiment matching the topic?\n"
"- Do the helper file signatures indicate relevant models/methods?\n"
"- If the topic mentions a specific technique, is there evidence of "
"its implementation (function names, class names, imports)?\n"
"- Are the experimental conditions meaningfully different from each other?\n"
)
try:
align_resp = llm.chat(
[{"role": "user", "content": align_prompt}],
system="You are a scientific code reviewer checking topic-experiment alignment.",
max_tokens=1024,
)
align_data = _safe_json_loads(align_resp.content, {})
if isinstance(align_data, dict) and not align_data.get("aligned", True):
alignment_ok = False
alignment_note = align_data.get("reason", "Misaligned")
suggestions = align_data.get("suggestions", "")
logger.warning(
"Stage 10: Topic-experiment MISALIGNMENT detected: %s",
alignment_note,
)
# BUG-R6-01: Allow up to 2 regeneration attempts with re-check.
_max_regen = 2
for _regen_attempt in range(1, _max_regen + 1):
logger.info(
"Stage 10: Alignment regen attempt %d/%d",
_regen_attempt, _max_regen,
)
regen_prompt = (
f"The experiment code you previously generated does NOT align "
f"with the research topic.\n\n"
f"TOPIC: {config.research.topic}\n"
f"MISALIGNMENT: {alignment_note}\n"
f"SUGGESTIONS: {suggestions}\n\n"
f"REGENERATE the experiment code to DIRECTLY test the stated "
f"topic. The code MUST implement the core technique described "
f"in the topic, not a generic proxy.\n\n"
f"CRITICAL CONSTRAINTS:\n"
f"- You MUST implement the EXACT algorithm/method from the topic.\n"
f"- Do NOT substitute a deep-learning proxy (ResNet, BERT, etc.) "
f"when the topic describes a tabular, bandit, or game-theoretic method.\n"
f"- Use ONLY lightweight CPU-friendly libraries (numpy, scipy, "
f"sklearn) unless the topic EXPLICITLY requires deep learning.\n"
f"- The experiment must be self-contained and runnable without GPU.\n\n"
f"{pkg_hint}\n{compute_budget}\n"
f"PLAN:\n{exp_plan}\n\n"
f"Return multiple files using ```filename:xxx.py format."
)
regen_resp = _chat_with_prompt(
llm,
system=_pm.system("code_generation"),
user=regen_prompt,
max_tokens=_code_max_tokens,
)
regen_files = _extract_multi_file_blocks(regen_resp.content)
if not regen_files or "main.py" not in regen_files:
logger.warning(
"Stage 10: Regen attempt %d produced no main.py",
_regen_attempt,
)
continue
files = regen_files
for fname, code in files.items():
(exp_dir / fname).write_text(code, encoding="utf-8")
# Re-check alignment on regenerated code (BUG-171 fix)
_rc_inv = []
for _fn, _cd in files.items():
_rc_inv.append(f" {_fn}: {_cd.count(chr(10))+1} lines")
_rc_main = files.get("main.py", "")
if len(_rc_main) > 12000:
_rc_main = _rc_main[:12000] + "\n... [truncated]"
_rc_sigs = []
for _fn, _cd in files.items():
if _fn == "main.py":
continue
# BUG-209: Include imports alongside signatures
_slines = [l for l in _cd.split("\n")
if l.strip().startswith((
"def ", "class ", "async def ",
"import ", "from ",
))]
if _slines:
_rc_sigs.append(f"# {_fn} imports+signatures:\n" + "\n".join(_slines))
recheck_code = (
"FILES:\n" + "\n".join(_rc_inv) + "\n\n"
f"# main.py (FULL):\n{_rc_main}\n\n"
+ "\n".join(_rc_sigs)
)
recheck_resp = llm.chat(
[{"role": "user", "content": (
f"Research topic: {config.research.topic}\n\n"
f"Experiment code:\n```python\n{recheck_code}\n```\n\n"
"TASK: Evaluate whether this experiment code actually tests "
"the stated research topic. Only main.py is shown in full; "
"other files show signatures only. Answer with JSON:\n"
'{"aligned": true/false, "reason": "...", "suggestions": "..."}\n'
)}],
system="You are a scientific code reviewer checking topic-experiment alignment.",
max_tokens=1024,
)
recheck_data = _safe_json_loads(recheck_resp.content, {})
if isinstance(recheck_data, dict) and recheck_data.get("aligned", False):
alignment_ok = True
alignment_note = f"Regenerated after alignment check (attempt {_regen_attempt})"
logger.info(
"Stage 10: Code aligned after regen attempt %d",
_regen_attempt,
)
break
else:
alignment_note = recheck_data.get("reason", alignment_note)
suggestions = recheck_data.get("suggestions", suggestions)
logger.warning(
"Stage 10: Regen attempt %d still misaligned: %s",
_regen_attempt, alignment_note,
)
except Exception as exc:
logger.debug("Alignment check failed: %s", exc)
# --- FIX-7: Ablation distinctness check ---
main_code = files.get("main.py", "")
if llm is not None and main_code and "condition" in main_code.lower():
try:
ablation_prompt = (
f"Examine this experiment code:\n```python\n{main_code[:6000]}\n```\n\n"
"Check if any experimental conditions (methods/ablations) have "
"IDENTICAL configurations (same hyperparameters, same code paths). "
"Answer JSON: "
'{"has_duplicates": true/false, "details": "which conditions are identical"}'
)
abl_resp = llm.chat(
[{"role": "user", "content": ablation_prompt}],
system="You are a code reviewer checking experimental conditions.",
max_tokens=512,
)
abl_data = _safe_json_loads(abl_resp.content, {})
if isinstance(abl_data, dict) and abl_data.get("has_duplicates"):
logger.warning(
"Stage 10: Duplicate ablation conditions detected: %s",
abl_data.get("details", ""),
)
(stage_dir / "ablation_warning.json").write_text(
json.dumps(abl_data, indent=2), encoding="utf-8"
)
# --- Attempt ablation repair ---
all_code_ctx = "\n\n".join(
f"```filename:{f}\n{c}\n```" for f, c in files.items()
)
dup_details = abl_data.get("details", "unknown")
abl_repair_prompt = (
f"ABLATION REPAIR REQUIRED — duplicate conditions detected:\n"
f"{dup_details}\n\n"
f"Rewrite the ablation/variant conditions so each one is "
f"GENUINELY DIFFERENT. Concrete strategies:\n"
f"- 'no_': REMOVE the component entirely "
f"(e.g., replace attention with mean pooling, remove a loss term)\n"
f"- 'reduced_capacity': HALVE hidden dimensions or layers\n"
f"- Different conditions MUST produce different outputs on the "
f"same input. Add a startup assertion that runs one forward pass "
f"per condition on identical input and prints:\n"
f" ABLATION_CHECK: vs outputs_differ=True\n\n"
f"Return ALL files using ```filename:xxx.py format.\n\n"
f"Current code:\n{all_code_ctx}\n"
)
try:
abl_repair_resp = _chat_with_prompt(
llm,
_pm.system("code_generation"),
abl_repair_prompt,
max_tokens=_code_max_tokens,
)
repaired_files = _extract_multi_file_blocks(
abl_repair_resp.content
)
if repaired_files and "main.py" in repaired_files:
files = repaired_files
for fname, code in files.items():
(exp_dir / fname).write_text(code, encoding="utf-8")
logger.info(
"Stage 10: Ablation repair applied — "
"rewrote duplicate conditions"
)
except Exception as exc:
logger.debug("Ablation repair failed: %s", exc)
except Exception as exc:
logger.debug("Ablation validation skipped: %s", exc)
# --- Write spec ---
file_list = ", ".join(f"`{f}`" for f in sorted(files.keys()))
main_validation = validate_code(files.get("main.py", ""))
_align_status = "ALIGNED" if alignment_ok else f"MISALIGNED: {alignment_note}"
spec = f"""# Experiment Specification
## Topic
{config.research.topic}
## Project Structure
Multi-file experiment project with {len(files)} file(s): {file_list}
## Entry Point
`main.py` \u2014 executed directly via sandbox
## Outputs
- `main.py` emits metric lines in `name: value` format
- Primary metric key: `{metric}`
## Topic-Experiment Alignment
{_align_status}
## Constraints
- Time budget per run: {config.experiment.time_budget_sec}s
- Max iterations: {config.experiment.max_iterations}
- Self-contained execution (no external data, no network)
- Validated: {main_validation.summary()}
## Generated
{_utcnow_iso()}
"""
(stage_dir / "experiment_spec.md").write_text(spec, encoding="utf-8")
artifacts = ["experiment/", "experiment_spec.md"]
if (stage_dir / "validation_report.md").exists():
artifacts.append("validation_report.md")
# BUG-R6-01: Fail stage if alignment check detected persistent mismatch
# after all regen attempts, instead of silently proceeding.
if not alignment_ok:
logger.error(
"Stage 10: Persistent topic-experiment misalignment after all "
"regen attempts. Failing stage. Reason: %s",
alignment_note,
)
return StageResult(
stage=Stage.CODE_GENERATION,
status=StageStatus.FAILED,
artifacts=tuple(artifacts),
evidence_refs=tuple(f"stage-10/{a}" for a in artifacts),
error=f"Topic-experiment misalignment: {alignment_note}",
)
return StageResult(
stage=Stage.CODE_GENERATION,
status=StageStatus.DONE,
artifacts=tuple(artifacts),
evidence_refs=tuple(f"stage-10/{a}" for a in artifacts),
)
================================================
FILE: researchclaw/pipeline/stage_impls/_execution.py
================================================
"""Stages 11-13: Resource planning, experiment execution, and iterative refinement."""
from __future__ import annotations
import json
import logging
import math
import re
import time as _time
from pathlib import Path
from typing import Any
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.experiment.validator import (
CodeValidation,
format_issues_for_llm,
validate_code,
)
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain
from researchclaw.pipeline._helpers import (
StageResult,
_chat_with_prompt,
_detect_runtime_issues,
_ensure_sandbox_deps,
_extract_code_block,
_extract_multi_file_blocks,
_get_evolution_overlay,
_load_hardware_profile,
_parse_metrics_from_stdout,
_read_prior_artifact,
_safe_filename,
_safe_json_loads,
_utcnow_iso,
_write_stage_meta,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
def _execute_resource_planning(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 11: Produce a task/GPU schedule for the experiment plan.

    When an LLM is available, asks it for a schedule derived from the
    Stage-9 ``exp_plan.yaml`` artifact. Falls back to a minimal
    deterministic two-task (baseline -> proposed) schedule when no LLM is
    configured or the LLM response cannot be parsed into a usable dict.

    Writes ``schedule.json`` into ``stage_dir`` and always returns a DONE
    StageResult pointing at that artifact.
    """
    exp_plan = _read_prior_artifact(run_dir, "exp_plan.yaml") or ""
    schedule: dict[str, Any] | None = None
    if llm is not None:
        _pm = prompts or PromptManager()
        _overlay = _get_evolution_overlay(run_dir, "resource_planning")
        sp = _pm.for_stage(
            "resource_planning", evolution_overlay=_overlay, exp_plan=exp_plan
        )
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        parsed = _safe_json_loads(resp.content, {})
        # Bug fix: _safe_json_loads returns the {} default on parse failure.
        # Accepting any dict (including that empty default) silently skipped
        # the fallback schedule below and wrote an empty schedule.json.
        # Require a *non-empty* dict before trusting the LLM output.
        if isinstance(parsed, dict) and parsed:
            schedule = parsed
    if schedule is None:
        # No LLM, or unusable LLM output: minimal deterministic schedule.
        schedule = {
            "tasks": [
                {
                    "id": "baseline",
                    "name": "Run baseline",
                    "depends_on": [],
                    "gpu_count": 1,
                    "estimated_minutes": 20,
                    "priority": "high",
                },
                {
                    "id": "proposed",
                    "name": "Run proposed method",
                    "depends_on": ["baseline"],
                    "gpu_count": 1,
                    "estimated_minutes": 30,
                    "priority": "high",
                },
            ],
            "total_gpu_budget": 1,
            "generated": _utcnow_iso(),
        }
    # LLM-produced schedules may lack a timestamp; add one if missing.
    schedule.setdefault("generated", _utcnow_iso())
    (stage_dir / "schedule.json").write_text(
        json.dumps(schedule, indent=2), encoding="utf-8"
    )
    return StageResult(
        stage=Stage.RESOURCE_PLANNING,
        status=StageStatus.DONE,
        artifacts=("schedule.json",),
        evidence_refs=("stage-11/schedule.json",),
    )
def _execute_experiment_run(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 12: Execute the generated experiment and record per-run results.

    Dispatches on ``config.experiment.mode``:

    - ``sandbox`` / ``docker``: runs the multi-file experiment project (or a
      single ``experiment.py``) in a sandbox; parses metrics from stdout when
      the sandbox reports none; classifies the run as completed / partial /
      failed; writes warnings for time-budget exhaustion (R11-6) and for
      conditions with fewer than 3 seeds (FIX-8).
    - ``simulated``: fabricates one result JSON per scheduled task.
    - anything else: delegates to :class:`ExperimentRunner`'s improve-loop.

    Run payloads are written under ``stage_dir / "runs"``. Always returns a
    DONE StageResult — failure details live inside the run JSON files.
    """
    from researchclaw.experiment.factory import create_sandbox
    from researchclaw.experiment.runner import ExperimentRunner

    schedule_text = _read_prior_artifact(run_dir, "schedule.json") or "{}"
    # Try multi-file experiment directory first, fall back to single file
    exp_dir_path = _read_prior_artifact(run_dir, "experiment/")
    code_text = ""
    if exp_dir_path and Path(exp_dir_path).is_dir():
        main_path = Path(exp_dir_path) / "main.py"
        if main_path.exists():
            code_text = main_path.read_text(encoding="utf-8")
    if not code_text:
        code_text = _read_prior_artifact(run_dir, "experiment.py") or ""
    runs_dir = stage_dir / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    mode = config.experiment.mode
    if mode in ("sandbox", "docker"):
        # P7: Auto-install missing dependencies before subprocess sandbox
        if mode == "sandbox":
            _all_code = code_text
            if exp_dir_path and Path(exp_dir_path).is_dir():
                for _pyf in Path(exp_dir_path).glob("*.py"):
                    try:
                        _all_code += "\n" + _pyf.read_text(encoding="utf-8")
                    except (OSError, UnicodeDecodeError):
                        # Best-effort scan; skip unreadable files.
                        pass
            _ensure_sandbox_deps(_all_code, config.experiment.sandbox.python_path)
        sandbox = create_sandbox(config.experiment, runs_dir / "sandbox")
        # Use run_project for multi-file, run for single-file
        if exp_dir_path and Path(exp_dir_path).is_dir():
            result = sandbox.run_project(
                Path(exp_dir_path), timeout_sec=config.experiment.time_budget_sec
            )
        else:
            result = sandbox.run(
                code_text, timeout_sec=config.experiment.time_budget_sec
            )
        # Try to read structured results.json from sandbox working dir
        structured_results: dict[str, Any] | None = None
        sandbox_project = runs_dir / "sandbox" / "_project"
        results_json_path = sandbox_project / "results.json"
        if results_json_path.exists():
            try:
                structured_results = json.loads(
                    results_json_path.read_text(encoding="utf-8")
                )
                # Copy results.json to runs dir for easy access
                (runs_dir / "results.json").write_text(
                    results_json_path.read_text(encoding="utf-8"),
                    encoding="utf-8",
                )
            except (json.JSONDecodeError, OSError):
                structured_results = None
        # If sandbox metrics are empty, try to parse from stdout
        effective_metrics = result.metrics
        if not effective_metrics and result.stdout:
            effective_metrics = _parse_metrics_from_stdout(result.stdout)
        # Determine run status: completed / partial (timed out with data) / failed
        # R6-2: Detect stdout failure signals even when exit code is 0
        _stdout_has_failure = bool(
            result.stdout
            and not effective_metrics
            and any(
                sig in result.stdout
                for sig in ("FAIL:", "NaN/divergence", "Traceback (most recent")
            )
        )
        if result.returncode == 0 and not result.timed_out and not _stdout_has_failure:
            run_status = "completed"
        elif result.timed_out and effective_metrics:
            run_status = "partial"
            logger.warning(
                "Experiment timed out but captured %d partial metrics",
                len(effective_metrics),
            )
        else:
            run_status = "failed"
            if _stdout_has_failure:
                logger.warning(
                    "Experiment exited cleanly but stdout contains failure signals"
                )
        # P1: Warn if experiment completed suspiciously fast (trivially easy benchmark)
        if run_status == "completed" and result.elapsed_sec and result.elapsed_sec < 5.0:
            logger.warning(
                "Stage 12: Experiment completed in %.2fs — benchmark may be trivially easy. "
                "Consider increasing task difficulty.",
                result.elapsed_sec,
            )
        run_payload: dict[str, Any] = {
            "run_id": "run-1",
            "task_id": "sandbox-main",
            "status": run_status,
            "metrics": effective_metrics,
            "elapsed_sec": result.elapsed_sec,
            "stdout": result.stdout,
            "stderr": result.stderr,
            "timed_out": result.timed_out,
            "completed_at": _utcnow_iso(),
        }
        if structured_results is not None:
            run_payload["structured_results"] = structured_results
        # Auto-generate results.json from parsed metrics if sandbox didn't produce one
        if structured_results is None and effective_metrics:
            auto_results = {"source": "stdout_parsed", "metrics": effective_metrics}
            (runs_dir / "results.json").write_text(
                json.dumps(auto_results, indent=2), encoding="utf-8"
            )
            logger.info("Stage 12: Auto-generated results.json from stdout metrics (%d keys)", len(effective_metrics))
        (runs_dir / "run-1.json").write_text(
            json.dumps(run_payload, indent=2), encoding="utf-8"
        )
        # R11-6: Time budget adequacy check
        if result.timed_out or (result.elapsed_sec and result.elapsed_sec > config.experiment.time_budget_sec * 0.9):
            # Parse stdout to estimate how many conditions/seeds completed
            _stdout = result.stdout or ""
            _completed_conditions = set()
            _completed_seeds = 0
            for _line in _stdout.splitlines():
                if "condition=" in _line and "seed=" in _line:
                    _completed_seeds += 1
                    _cond_match = re.match(r".*condition=(\S+)", _line)
                    if _cond_match:
                        _completed_conditions.add(_cond_match.group(1))
            # Bug fix: result.elapsed_sec may be None when the run timed out
            # before timing data was captured; formatting None with ":.0f"
            # raised TypeError and crashed the stage. Format a None-safe
            # value instead (the stored "elapsed_sec" keeps the raw value).
            _elapsed_for_msg = result.elapsed_sec or 0.0
            _time_budget_warning = {
                "timed_out": result.timed_out,
                "elapsed_sec": result.elapsed_sec,
                "budget_sec": config.experiment.time_budget_sec,
                "conditions_completed": sorted(_completed_conditions),
                "total_seed_runs": _completed_seeds,
                "warning": (
                    f"Experiment used {_elapsed_for_msg:.0f}s of "
                    f"{config.experiment.time_budget_sec}s budget. "
                    f"Only {len(_completed_conditions)} conditions completed "
                    f"({_completed_seeds} seed-runs). Consider increasing "
                    f"time_budget_sec for more complete results."
                ),
            }
            logger.warning(
                "Stage 12: %s", _time_budget_warning["warning"]
            )
            (stage_dir / "time_budget_warning.json").write_text(
                json.dumps(_time_budget_warning, indent=2), encoding="utf-8"
            )
        # FIX-8: Validate seed count from structured results
        if structured_results and isinstance(structured_results, dict):
            _sr_conditions = structured_results.get("conditions", structured_results.get("per_condition", {}))
            if isinstance(_sr_conditions, dict):
                for _cname, _cdata in _sr_conditions.items():
                    if isinstance(_cdata, dict):
                        _seeds_run = _cdata.get("seeds_run", _cdata.get("n_seeds", 0))
                        if isinstance(_seeds_run, (int, float)) and 0 < _seeds_run < 3:
                            logger.warning(
                                "Stage 12: Condition '%s' ran only %d seed(s) — "
                                "minimum 3 required for statistical validity",
                                _cname, int(_seeds_run),
                            )
    elif mode == "simulated":
        # Fabricate deterministic per-task results from the Stage-11 schedule.
        schedule = _safe_json_loads(schedule_text, {})
        tasks = schedule.get("tasks", []) if isinstance(schedule, dict) else []
        if not isinstance(tasks, list):
            tasks = []
        for idx, task in enumerate(tasks or [{"id": "task-1", "name": "simulated"}]):
            task_id = (
                str(task.get("id", f"task-{idx + 1}"))
                if isinstance(task, dict)
                else f"task-{idx + 1}"
            )
            payload = {
                "run_id": f"run-{idx + 1}",
                "task_id": task_id,
                "status": "simulated",
                "key_metrics": {
                    config.experiment.metric_key: round(0.3 + idx * 0.03, 4),
                    "secondary_metric": round(0.6 - idx * 0.04, 4),
                },
                "notes": "Simulated run result",
                "completed_at": _utcnow_iso(),
            }
            run_id = str(payload["run_id"])
            (runs_dir / f"{_safe_filename(run_id)}.json").write_text(
                json.dumps(payload, indent=2), encoding="utf-8"
            )
    else:
        # Any other mode: delegate to the iterative ExperimentRunner loop.
        runner = ExperimentRunner(config.experiment, runs_dir / "workspace")
        history = runner.run_loop(code_text, run_id=f"exp-{run_dir.name}", llm=llm)
        runner.save_history(stage_dir / "experiment_history.json")
        for item in history.results:
            payload = {
                "run_id": f"run-{item.iteration}",
                "task_id": item.run_id,
                "status": "completed" if item.error is None else "failed",
                "metrics": item.metrics,
                "primary_metric": item.primary_metric,
                "improved": item.improved,
                "kept": item.kept,
                "elapsed_sec": item.elapsed_sec,
                "error": item.error,
                "completed_at": _utcnow_iso(),
            }
            run_id = str(payload["run_id"])
            (runs_dir / f"{_safe_filename(run_id)}.json").write_text(
                json.dumps(payload, indent=2), encoding="utf-8"
            )
    return StageResult(
        stage=Stage.EXPERIMENT_RUN,
        status=StageStatus.DONE,
        artifacts=("runs/",),
        evidence_refs=("stage-12/runs/",),
    )
def _execute_iterative_refine(
stage_dir: Path,
run_dir: Path,
config: RCConfig,
adapters: AdapterBundle,
*,
llm: LLMClient | None = None,
prompts: PromptManager | None = None,
) -> StageResult:
from researchclaw.experiment.factory import create_sandbox
from researchclaw.experiment.validator import format_issues_for_llm, validate_code
def _to_float(value: Any) -> float | None:
try:
if value is None:
return None
f = float(value)
# BUG-EX-01: NaN/Inf block all future improvement detection
if math.isnan(f) or math.isinf(f):
return None
return f
except (TypeError, ValueError):
return None
# R10-Fix3: Skip iterative refinement in simulated mode (no real execution)
if config.experiment.mode == "simulated":
logger.info(
"Stage 13: Skipping iterative refinement in simulated mode "
"(no real code execution available)"
)
import shutil
final_dir = stage_dir / "experiment_final"
# Copy latest experiment code as final (directory or single file)
copied = False
for stage_num in (12, 10):
src_dir = run_dir / f"stage-{stage_num:02d}" / "experiment"
if src_dir.is_dir():
if final_dir.exists():
shutil.rmtree(final_dir)
shutil.copytree(src_dir, final_dir)
copied = True
break
# Also check for single experiment.py
src_file = run_dir / f"stage-{stage_num:02d}" / "experiment.py"
if src_file.is_file():
(stage_dir / "experiment_final.py").write_text(
src_file.read_text(encoding="utf-8"), encoding="utf-8"
)
copied = True
break
log: dict[str, Any] = {
"generated": _utcnow_iso(),
"mode": "simulated",
"skipped": True,
"skip_reason": "Iterative refinement not meaningful in simulated mode",
"metric_key": config.experiment.metric_key,
}
(stage_dir / "refinement_log.json").write_text(
json.dumps(log, indent=2), encoding="utf-8"
)
return StageResult(
stage=Stage.ITERATIVE_REFINE,
status=StageStatus.DONE,
artifacts=("refinement_log.json",),
evidence_refs=(),
)
metric_key = config.experiment.metric_key
metric_direction = config.experiment.metric_direction
# P9: Detect metric direction mismatch between config and experiment code.
# The code-gen stage instructs experiments to print a line like:
# METRIC_DEF: primary_metric | direction=higher | desc=...
# Log a warning if mismatch is detected, but trust the config value
# (BUG-06 fix: no longer auto-override, since Stage 9 and 12 now
# explicitly enforce config.metric_direction in prompts).
_runs_dir_detect = _read_prior_artifact(run_dir, "runs/")
if _runs_dir_detect and Path(_runs_dir_detect).is_dir():
import re as _re_detect
for _rf in sorted(Path(_runs_dir_detect).glob("*.json"))[:5]:
try:
_rp = _safe_json_loads(_rf.read_text(encoding="utf-8"), {})
_stdout = _rp.get("stdout", "") if isinstance(_rp, dict) else ""
_match = _re_detect.search(
r"METRIC_DEF:.*direction\s*=\s*(higher|lower)", _stdout
)
if _match:
_detected = _match.group(1)
_detected_dir = "maximize" if _detected == "higher" else "minimize"
if _detected_dir != metric_direction:
logger.warning(
"P9: Metric direction mismatch — config says '%s' but "
"experiment code declares 'direction=%s'. "
"Keeping config value '%s'. Code will be "
"corrected in next refinement cycle.",
metric_direction,
_detected,
metric_direction,
)
break
except OSError:
pass
maximize = metric_direction == "maximize"
def _is_better(candidate: float | None, current: float | None) -> bool:
    """Whether *candidate* beats *current* under the configured direction.

    A missing candidate never wins; any real candidate beats a missing
    current best. Otherwise compares using the enclosing scope's
    ``maximize`` flag (derived from ``metric_direction``).
    """
    if candidate is None:
        return False
    if current is None:
        return True
    if maximize:
        return candidate > current
    return candidate < current
def _find_metric(metrics: dict[str, object], key: str) -> float | None:
    """R13-4: Find a metric value with fuzzy key matching.

    Resolution order:
      1. exact ``key`` lookup;
      2. condition-prefixed exact match (``cond/<key>``) — returned
         immediately on first hit;
      3. aggregate keys mentioning the metric (``*mean*``/``*avg*`` or
         ``*_<key>`` / ``*/<key>_mean``), preferring the first one whose
         name contains ``mean``;
      4. any root-level (no ``/``), non-seed key containing the name.
    """
    # 1) Plain exact lookup.
    exact = _to_float(metrics.get(key))
    if exact is not None:
        return exact
    # 2)+3) Single pass in insertion order: immediate return on a
    # condition-prefixed exact match, otherwise collect aggregates.
    aggregates: list[tuple[str, float]] = []
    for name, raw in metrics.items():
        value = _to_float(raw)
        if value is None:
            continue
        if name == key or name.endswith(f"/{key}"):
            return value  # exact match behind a condition prefix
        is_named_aggregate = key in name and ("mean" in name or "avg" in name)
        if is_named_aggregate or name.endswith(f"_{key}") or name.endswith(f"/{key}_mean"):
            aggregates.append((name, value))
    if aggregates:
        # Prefer the first aggregate whose key mentions "mean".
        for name, value in aggregates:
            if "mean" in name:
                return value
        return aggregates[0][1]
    # 4) Last resort: root-level aggregate that isn't per-seed.
    for name, raw in metrics.items():
        value = _to_float(raw)
        if value is not None and key in name and "/" not in name and "seed" not in name:
            return value
    return None
requested_iterations = int(getattr(config.experiment, "max_iterations", 10) or 10)
max_iterations = max(1, min(requested_iterations, 10))
# BUG-57: Wall-clock time cap for the entire refinement stage.
# Default: 3× the per-iteration time budget (e.g., 2400s → 7200s = 2h).
import time as _time_bug57
_refine_start_time = _time_bug57.monotonic()
_per_iter_budget = int(getattr(config.experiment, "time_budget_sec", 2400) or 2400)
_max_refine_wall_sec = int(
getattr(config.experiment, "max_refine_duration_sec", 0) or 0
) or int(_per_iter_budget * 1.5)
# --- Collect baseline metrics from prior runs ---
runs_dir_path: Path | None = None
runs_dir_text = _read_prior_artifact(run_dir, "runs/")
if runs_dir_text:
runs_dir_path = Path(runs_dir_text)
run_summaries: list[str] = []
baseline_metric: float | None = None
if runs_dir_path is not None:
for run_file in sorted(runs_dir_path.glob("*.json"))[:40]:
payload = _safe_json_loads(run_file.read_text(encoding="utf-8"), {})
if not isinstance(payload, dict):
continue
# R5-5: Truncate stdout/stderr for context efficiency
summary = dict(payload)
if "stdout" in summary and isinstance(summary["stdout"], str):
lines = summary["stdout"].splitlines()
if len(lines) > 30:
summary["stdout"] = (
f"[...truncated {len(lines) - 30} lines...]\n"
+ "\n".join(lines[-30:])
)
if len(summary["stdout"]) > 2000:
summary["stdout"] = summary["stdout"][-2000:]
if "stderr" in summary and isinstance(summary["stderr"], str):
lines = summary["stderr"].splitlines()
if len(lines) > 50:
summary["stderr"] = "\n".join(lines[-50:])
if len(summary["stderr"]) > 2000:
summary["stderr"] = summary["stderr"][-2000:]
run_summaries.append(json.dumps(summary, ensure_ascii=False))
metrics = payload.get("metrics")
if not isinstance(metrics, dict):
metrics = (
payload.get("key_metrics")
if isinstance(payload.get("key_metrics"), dict)
else {}
)
metric_val = (
_find_metric(metrics, metric_key)
if isinstance(metrics, dict)
else None
)
if metric_val is None:
metric_val = _to_float(payload.get("primary_metric"))
if _is_better(metric_val, baseline_metric):
baseline_metric = metric_val
# --- Read experiment project (multi-file or single-file) ---
# BUG-58: When PIVOT rolls back to Stage 13, prefer the best refined code
# from a previous cycle (stage-13_vX/experiment_final/) over the original
# unrefined code (stage-12/experiment/ or stage-10/experiment/).
# Enhanced: try ALL versioned directories (latest first) with fallback chain.
exp_dir_text: str | None = None
_prev_refine_dirs = sorted(
run_dir.glob("stage-13_v*/experiment_final"),
key=lambda p: p.parent.name,
reverse=True, # latest version first
)
# BUG-58 fix: Find the best version across ALL cycles (not just latest)
_best_prev_metric: float | None = None
_best_prev_dir: Path | None = None
for _prd in _prev_refine_dirs:
if not _prd.is_dir():
continue
_prd_log = _prd.parent / "refinement_log.json"
if _prd_log.is_file():
_prd_data = _safe_json_loads(
_prd_log.read_text(encoding="utf-8"), {}
)
_prd_metric = _prd_data.get("best_metric") if isinstance(_prd_data, dict) else None
if isinstance(_prd_metric, (int, float)) and _is_better(_prd_metric, _best_prev_metric):
_best_prev_metric = _prd_metric
_best_prev_dir = _prd
elif _best_prev_dir is None:
# No log but directory exists — use as fallback
_best_prev_dir = _prd
if _best_prev_dir is not None:
exp_dir_text = str(_best_prev_dir)
logger.info(
"BUG-58: Recovered best refined code from PIVOT cycle: %s (metric=%s)",
_best_prev_dir.parent.name,
f"{_best_prev_metric:.4f}" if _best_prev_metric is not None else "N/A",
)
if not exp_dir_text:
exp_dir_text = _read_prior_artifact(run_dir, "experiment/")
best_files: dict[str, str] = {}
if exp_dir_text and Path(exp_dir_text).is_dir():
# BUG-EX-02: Load ALL text files (not just .py) — requirements.txt,
# setup.py, config files are needed for Docker sandbox phases.
for src_file in sorted(Path(exp_dir_text).iterdir()):
if src_file.is_file() and src_file.suffix in (
".py", ".txt", ".yaml", ".yml", ".json", ".cfg", ".ini", ".sh",
):
try:
best_files[src_file.name] = src_file.read_text(encoding="utf-8")
except UnicodeDecodeError:
pass # skip binary files
if not best_files:
# Backward compat: single experiment.py
original_code = _read_prior_artifact(run_dir, "experiment.py") or ""
if original_code:
best_files = {"main.py": original_code}
# --- Detect if prior experiment timed out ---
prior_timed_out = False
prior_time_budget = config.experiment.time_budget_sec
if runs_dir_path is not None:
for run_file in sorted(runs_dir_path.glob("*.json"))[:5]:
try:
payload = _safe_json_loads(run_file.read_text(encoding="utf-8"), {})
if isinstance(payload, dict) and payload.get("timed_out"):
prior_timed_out = True
break
except OSError:
pass
best_metric = baseline_metric
best_version = "experiment/"
# BUG-58: Recover best_metric from best previous PIVOT cycle
if _best_prev_metric is not None and _is_better(_best_prev_metric, best_metric):
best_metric = _best_prev_metric
logger.info(
"BUG-58: Recovered best_metric=%.4f from previous PIVOT",
best_metric,
)
no_improve_streak = 0
consecutive_no_metrics = 0
log: dict[str, Any] = {
"generated": _utcnow_iso(),
"mode": config.experiment.mode,
"metric_key": metric_key,
"metric_direction": metric_direction,
"max_iterations_requested": requested_iterations,
"max_iterations_executed": max_iterations,
"baseline_metric": baseline_metric,
"project_files": list(best_files.keys()),
"iterations": [],
"converged": False,
"stop_reason": "max_iterations_reached",
}
# --- Helper: write files to a directory ---
def _write_project(target_dir: Path, project_files: dict[str, str]) -> None:
target_dir.mkdir(parents=True, exist_ok=True)
for fname, code in project_files.items():
(target_dir / fname).write_text(code, encoding="utf-8")
# --- Helper: format all files for LLM context ---
def _files_to_context(project_files: dict[str, str]) -> str:
parts = []
for fname, code in sorted(project_files.items()):
parts.append(f"```filename:{fname}\n{code}\n```")
return "\n\n".join(parts)
if llm is None:
logger.info("Stage 13: LLM unavailable, saving original experiment as final")
final_dir = stage_dir / "experiment_final"
_write_project(final_dir, best_files)
# Backward compat
if "main.py" in best_files:
(stage_dir / "experiment_final.py").write_text(
best_files["main.py"], encoding="utf-8"
)
log.update(
{
"converged": True,
"stop_reason": "llm_unavailable",
"best_metric": best_metric,
"best_version": "experiment_final/",
"iterations": [
{
"iteration": 0,
"version_dir": "experiment_final/",
"source": "fallback_original",
"metric": best_metric,
}
],
}
)
(stage_dir / "refinement_log.json").write_text(
json.dumps(log, indent=2), encoding="utf-8"
)
artifacts = ("refinement_log.json", "experiment_final/")
return StageResult(
stage=Stage.ITERATIVE_REFINE,
status=StageStatus.DONE,
artifacts=artifacts,
evidence_refs=tuple(f"stage-13/{a}" for a in artifacts),
)
_pm = prompts or PromptManager()
timeout_refine_attempts = 0
# R7-3: Read experiment plan to detect condition coverage gaps
_exp_plan_text = _read_prior_artifact(run_dir, "exp_plan.yaml") or ""
_condition_coverage_hint = ""
if _exp_plan_text and run_summaries:
# Check if stdout contains condition labels
_all_stdout = " ".join(run_summaries)
_has_condition_labels = "condition=" in _all_stdout
if not _has_condition_labels and _exp_plan_text.strip():
_condition_coverage_hint = (
"\nCONDITION COVERAGE GAP DETECTED:\n"
"The experiment plan specifies multiple conditions/treatments, "
"but the output contains NO condition labels (no 'condition=...' in stdout).\n"
"You MUST:\n"
"1. Run ALL conditions/treatments from the experiment plan independently\n"
"2. Label each metric output: `condition= {metric_key}: `\n"
"3. Print a SUMMARY line comparing all conditions after completion\n"
"This is the MOST IMPORTANT improvement — a single unlabeled metric stream "
"cannot support any comparative conclusions.\n\n"
)
logger.info(
"Stage 13: condition coverage gap detected, injecting multi-condition hint"
)
# P1: Track metrics history for saturation detection
_metrics_history: list[float | None] = [baseline_metric]
for iteration in range(1, max_iterations + 1):
# BUG-57: Check wall-clock time before starting a new iteration
_elapsed = _time_bug57.monotonic() - _refine_start_time
if _elapsed > _max_refine_wall_sec:
logger.warning(
"Stage 13: Wall-clock time cap reached (%.0fs > %ds). "
"Stopping refinement after %d iterations.",
_elapsed, _max_refine_wall_sec, iteration - 1,
)
log["stop_reason"] = "wall_clock_time_cap"
break
logger.info("Stage 13: refinement iteration %d/%d (%.0fs elapsed, cap %ds)",
iteration, max_iterations, _elapsed, _max_refine_wall_sec)
# P1: Detect metric saturation and inject difficulty upgrade hint
_saturation_hint = ""
_valid_metrics = [m for m in _metrics_history if m is not None]
if len(_valid_metrics) >= 2:
_last_two = _valid_metrics[-2:]
_saturated = False
# Use relative change rate instead of hard-coded thresholds
_change_rate = abs(_last_two[-1] - _last_two[-2]) / max(abs(_last_two[-2]), 1e-8)
if metric_direction == "minimize":
_saturated = all(m <= 0.001 for m in _last_two) or (
_change_rate < 0.001 and _last_two[-1] < 0.01
)
else:
_saturated = all(m >= 0.999 for m in _last_two) or (
_change_rate < 0.001 and _last_two[-1] > 0.99
)
if _saturated:
_saturation_hint = (
"\n\nWARNING — BENCHMARK SATURATION DETECTED:\n"
"All methods achieve near-perfect scores, making the task too easy "
"to discriminate between methods.\n"
"YOU MUST increase benchmark difficulty in this iteration:\n"
"1. Increase the number of actions/decisions from 8 to at least 20\n"
"2. Increase the horizon from 12-18 to at least 50-100 steps\n"
"3. Increase noise level to at least 0.3-0.5\n"
"4. Add partial observability (agent cannot see full state)\n"
"5. Add delayed rewards (reward only at episode end)\n"
"6. Ensure random search achieves < 50% success rate\n"
"Without this change, the experiment produces meaningless results.\n"
)
logger.warning("Stage 13: metric saturation detected, injecting difficulty upgrade hint")
files_context = _files_to_context(best_files)
# BUG-10 fix: anchor refinement to original experiment plan
_exp_plan_anchor = ""
if _exp_plan_text.strip():
_exp_plan_anchor = (
"Original experiment plan (exp_plan.yaml):\n"
"```yaml\n" + _exp_plan_text[:4000] + "\n```\n"
"You MUST preserve ALL condition names from this plan.\n\n"
)
ip = _pm.sub_prompt(
"iterative_improve",
metric_key=metric_key,
metric_direction=metric_direction,
files_context=files_context,
run_summaries=chr(10).join(run_summaries[:20]),
condition_coverage_hint=_condition_coverage_hint,
topic=config.research.topic,
exp_plan_anchor=_exp_plan_anchor,
)
# --- Timeout-aware prompt injection ---
user_prompt = ip.user + _saturation_hint
if prior_timed_out and baseline_metric is None:
timeout_refine_attempts += 1
timeout_hint = (
f"\n\nCRITICAL: The experiment TIMED OUT after {prior_time_budget}s "
f"with NO results. You MUST drastically reduce the experiment scale:\n"
f"- Reduce total runs to ≤50\n"
f"- Reduce steps per run to ≤2000\n"
f"- Remove conditions that are not essential\n"
f"- Add time.time() checks to stop gracefully before timeout\n"
f"- Print intermediate metrics frequently so partial data is captured\n"
f"- Time budget is {prior_time_budget}s — design for ≤{int(prior_time_budget * 0.7)}s\n"
)
user_prompt = user_prompt + timeout_hint
logger.warning(
"Stage 13: injecting timeout-aware prompt (attempt %d)",
timeout_refine_attempts,
)
response = _chat_with_prompt(
llm,
ip.system,
user_prompt,
max_tokens=ip.max_tokens or 8192,
)
extracted_files = _extract_multi_file_blocks(response.content)
# If LLM returns only single block, treat as main.py update
if not extracted_files:
single_code = _extract_code_block(response.content)
if single_code.strip():
extracted_files = {"main.py": single_code}
# R8-2: Merge with best_files to preserve supporting modules
# (e.g., graphs.py, game.py) that the LLM didn't rewrite
candidate_files = dict(best_files)
if extracted_files:
candidate_files.update(extracted_files)
# If LLM returned nothing at all, candidate_files == best_files (unchanged)
# BUG-R6-02: Preserve entry point when LLM strips main() function.
# The LLM often returns only class/function improvements without the
# main() entry point, causing the script to exit with no output.
_new_main = candidate_files.get("main.py", "")
_old_main = best_files.get("main.py", "")
if (
_new_main
and _old_main
and "if __name__" not in _new_main
and "if __name__" in _old_main
):
# Extract the entry-point block from original main.py
_ep_idx = _old_main.rfind("\ndef main(")
if _ep_idx == -1:
_ep_idx = _old_main.rfind("\nif __name__")
if _ep_idx != -1:
_entry_block = _old_main[_ep_idx:]
candidate_files["main.py"] = _new_main.rstrip() + "\n\n" + _entry_block
logger.info(
"Stage 13 iter %d: restored entry point stripped by LLM "
"(%d chars appended from original main.py)",
iteration,
len(_entry_block),
)
# Validate main.py
main_code = candidate_files.get("main.py", "")
validation = validate_code(main_code)
issue_text = ""
repaired = False
if not validation.ok:
issue_text = format_issues_for_llm(validation)
logger.info(
"Stage 13 iteration %d validation failed: %s",
iteration,
validation.summary(),
)
irp = _pm.sub_prompt(
"iterative_repair",
issue_text=issue_text,
all_files_ctx=_files_to_context(candidate_files),
)
repair_response = _chat_with_prompt(llm, irp.system, irp.user)
candidate_files["main.py"] = _extract_code_block(repair_response.content)
validation = validate_code(candidate_files["main.py"])
repaired = True
# Save version directory
version_dir = stage_dir / f"experiment_v{iteration}"
_write_project(version_dir, candidate_files)
iter_record: dict[str, Any] = {
"iteration": iteration,
"version_dir": f"experiment_v{iteration}/",
"files": list(candidate_files.keys()),
"validation_ok": validation.ok,
"validation_summary": validation.summary(),
"repaired": repaired,
"metric": None,
"improved": False,
}
if issue_text:
iter_record["validation_issues"] = issue_text
metric_val = None # R6-3: initialize before conditional block
if validation.ok and config.experiment.mode in ("sandbox", "docker"):
# P7: Ensure deps for refined code (subprocess sandbox only)
if config.experiment.mode == "sandbox":
_refine_code = "\n".join(candidate_files.values())
_ensure_sandbox_deps(_refine_code, config.experiment.sandbox.python_path)
sandbox = create_sandbox(
config.experiment,
stage_dir / f"refine_sandbox_v{iteration}",
)
rerun = sandbox.run_project(
version_dir,
timeout_sec=config.experiment.time_budget_sec,
)
metric_val = _find_metric(rerun.metrics, metric_key)
# R19-1: Store stdout (capped) so PAIRED lines survive for Stage 14
_stdout_cap = rerun.stdout[:50000] if rerun.stdout else ""
iter_record["sandbox"] = {
"returncode": rerun.returncode,
"metrics": rerun.metrics,
"elapsed_sec": rerun.elapsed_sec,
"timed_out": rerun.timed_out,
"stderr": rerun.stderr[:2000] if rerun.stderr else "",
"stdout": _stdout_cap,
}
iter_record["metric"] = metric_val
# BUG-110: Parse ABLATION_CHECK lines from stdout
if rerun.stdout:
import re as _re_ablation
_ablation_checks = _re_ablation.findall(
r"ABLATION_CHECK:\s*(\S+)\s+vs\s+(\S+)\s+outputs_differ=(True|False)",
rerun.stdout,
)
if _ablation_checks:
_identical_pairs = [
(c1, c2) for c1, c2, diff in _ablation_checks if diff == "False"
]
iter_record["ablation_checks"] = [
{"cond1": c1, "cond2": c2, "differ": diff == "True"}
for c1, c2, diff in _ablation_checks
]
if _identical_pairs:
_pairs_str = ", ".join(f"{c1} vs {c2}" for c1, c2 in _identical_pairs)
logger.warning(
"BUG-110: Identical ablation outputs detected: %s. "
"Ablation conditions may not be wired correctly.",
_pairs_str,
)
iter_record["ablation_identical"] = True
# --- Track timeout in refine sandbox ---
if rerun.timed_out:
prior_timed_out = True
timeout_refine_attempts += 1
logger.warning(
"Stage 13 iteration %d: sandbox timed out after %.1fs",
iteration,
rerun.elapsed_sec,
)
# If still no metrics after timeout, use partial stdout metrics
if not rerun.metrics and rerun.stdout:
from researchclaw.experiment.sandbox import parse_metrics as _parse_sb_metrics
partial = _parse_sb_metrics(rerun.stdout)
if partial:
iter_record["sandbox"]["metrics"] = partial
metric_val = _find_metric(partial, metric_key)
iter_record["metric"] = metric_val
logger.info(
"Stage 13 iteration %d: recovered %d partial metrics from timeout stdout",
iteration,
len(partial),
)
# --- Detect runtime issues (NaN/Inf, stderr warnings) ---
runtime_issues = _detect_runtime_issues(rerun)
if runtime_issues:
iter_record["runtime_issues"] = runtime_issues
logger.info(
"Stage 13 iteration %d: runtime issues detected: %s",
iteration,
runtime_issues[:200],
)
# Attempt LLM repair with runtime context
rrp = _pm.sub_prompt(
"iterative_repair",
issue_text=runtime_issues,
all_files_ctx=_files_to_context(candidate_files),
)
repair_resp = _chat_with_prompt(llm, rrp.system, rrp.user)
repaired_files = _extract_multi_file_blocks(repair_resp.content)
if not repaired_files:
single = _extract_code_block(repair_resp.content)
if single.strip():
repaired_files = dict(candidate_files)
repaired_files["main.py"] = single
if repaired_files:
# BUG-106 fix: merge instead of replace to preserve
# supporting modules (trainers.py, utils.py, etc.)
merged = dict(candidate_files)
merged.update(repaired_files)
candidate_files = merged
_write_project(version_dir, candidate_files)
# Re-run after runtime fix
sandbox2 = create_sandbox(
config.experiment,
stage_dir / f"refine_sandbox_v{iteration}_fix",
)
rerun2 = sandbox2.run_project(
version_dir,
timeout_sec=config.experiment.time_budget_sec,
)
metric_val = _find_metric(rerun2.metrics, metric_key)
iter_record["sandbox_after_fix"] = {
"returncode": rerun2.returncode,
"metrics": rerun2.metrics,
"elapsed_sec": rerun2.elapsed_sec,
"timed_out": rerun2.timed_out,
}
iter_record["metric"] = metric_val
iter_record["runtime_repaired"] = True
if metric_val is not None:
consecutive_no_metrics = 0
# R6-1: Only count toward no_improve_streak when we have real metrics
if _is_better(metric_val, best_metric):
best_metric = metric_val
best_files = dict(candidate_files)
best_version = f"experiment_v{iteration}/"
iter_record["improved"] = True
no_improve_streak = 0
else:
no_improve_streak += 1
else:
consecutive_no_metrics += 1
elif validation.ok and best_version == "experiment/":
best_files = dict(candidate_files)
best_version = f"experiment_v{iteration}/"
# P1: Track metric for saturation detection
_metrics_history.append(metric_val)
log["iterations"].append(iter_record)
if consecutive_no_metrics >= 3:
log["stop_reason"] = "consecutive_no_metrics"
logger.warning("Stage 13: Aborting after %d consecutive iterations without metrics", consecutive_no_metrics)
break
if no_improve_streak >= 2:
log["converged"] = True
log["stop_reason"] = "no_improvement_for_2_iterations"
logger.info(
"Stage 13 converged after %d iterations (no improvement streak=%d)",
iteration,
no_improve_streak,
)
break
# Write final experiment directory
final_dir = stage_dir / "experiment_final"
_write_project(final_dir, best_files)
# Backward compat: also write experiment_final.py (copy of main.py)
if "main.py" in best_files:
(stage_dir / "experiment_final.py").write_text(
best_files["main.py"], encoding="utf-8"
)
log["best_metric"] = best_metric
log["best_version"] = best_version
log["final_version"] = "experiment_final/"
# BUG-110: Aggregate ablation check results across iterations
_all_ablation_identical = any(
iter_rec.get("ablation_identical", False)
for iter_rec in log.get("iterations", [])
if isinstance(iter_rec, dict)
)
if _all_ablation_identical:
log["ablation_identical_warning"] = True
(stage_dir / "refinement_log.json").write_text(
json.dumps(log, indent=2), encoding="utf-8"
)
artifacts = ["refinement_log.json", "experiment_final/"]
artifacts.extend(
entry["version_dir"]
for entry in log["iterations"]
if isinstance(entry, dict) and isinstance(entry.get("version_dir"), str)
)
return StageResult(
stage=Stage.ITERATIVE_REFINE,
status=StageStatus.DONE,
artifacts=tuple(artifacts),
evidence_refs=tuple(f"stage-13/{a}" for a in artifacts),
)
================================================
FILE: researchclaw/pipeline/stage_impls/_experiment_design.py
================================================
"""Stage 9: Experiment design."""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any
import yaml
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain
from researchclaw.pipeline._helpers import (
StageResult,
_build_context_preamble,
_chat_with_prompt,
_extract_yaml_block,
_get_evolution_overlay,
_load_hardware_profile,
_read_prior_artifact,
_safe_json_loads,
_utcnow_iso,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
def _execute_experiment_design(
stage_dir: Path,
run_dir: Path,
config: RCConfig,
adapters: AdapterBundle,
*,
llm: LLMClient | None = None,
prompts: PromptManager | None = None,
) -> StageResult:
hypotheses = _read_prior_artifact(run_dir, "hypotheses.md") or ""
preamble = _build_context_preamble(
config, run_dir, include_goal=True, include_hypotheses=True
)
plan: dict[str, Any] | None = None
# ── Domain detection ──────────────────────────────────────────────────
# Detect the research domain early so we can adapt experiment design
# and code generation. For ML domains, existing behavior is unchanged.
_domain_profile = None
try:
from researchclaw.domains.detector import detect_domain as _detect_domain_adv
_domain_profile = _detect_domain_adv(
topic=config.research.topic,
hypotheses=hypotheses,
)
logger.info(
"Domain detected: %s (%s)",
_domain_profile.display_name,
_domain_profile.domain_id,
)
# Persist domain profile for Stage 10
import json as _json_dd
(stage_dir / "domain_profile.json").write_text(
_json_dd.dumps({
"domain_id": _domain_profile.domain_id,
"display_name": _domain_profile.display_name,
"experiment_paradigm": _domain_profile.experiment_paradigm,
"core_libraries": _domain_profile.core_libraries,
"gpu_required": _domain_profile.gpu_required,
}, indent=2),
encoding="utf-8",
)
except Exception: # noqa: BLE001
logger.debug("Domain detection unavailable", exc_info=True)
if llm is not None:
_pm = prompts or PromptManager()
# Pass dataset_guidance block for experiment design
try:
_dg_block = _pm.block("dataset_guidance")
except (KeyError, Exception): # noqa: BLE001
_dg_block = ""
# I-08: Inject RL step guidance for RL topics
_rl_kws = ("reinforcement learning", "ppo", "sac", "td3", "ddpg",
"dqn", "mujoco", "continuous control", "actor-critic",
"policy gradient", "exploration bonus")
_is_rl_topic = any(kw in config.research.topic.lower() for kw in _rl_kws)
if _is_rl_topic:
try:
_dg_block += _pm.block("rl_step_guidance")
except Exception: # noqa: BLE001
pass
# Improvement G: For RL with short budget, constrain to classic control
if config.experiment.time_budget_sec <= 3600:
_dg_block += (
"\n\n## RL TIME CONSTRAINT (MANDATORY):\n"
f"Your time budget is {config.experiment.time_budget_sec}s (≤ 3600s).\n"
"You MUST use ONLY classic control environments: "
"CartPole-v1, Pendulum-v1, MountainCar-v0, Acrobot-v1, LunarLander-v3.\n"
"Do NOT use MuJoCo (HalfCheetah, Hopper, Walker2d, Ant, Humanoid) — "
"they require >5000s for meaningful training.\n"
)
if config.experiment.time_budget_sec <= 1800:
_dg_block += (
"Time budget ≤ 1800s: use ONLY CartPole-v1 or Pendulum-v1 "
"(the simplest environments).\n"
)
# F-01: Inject framework docs for experiment design
try:
from researchclaw.data import detect_frameworks, load_framework_docs
_fw_ids = detect_frameworks(config.research.topic, hypotheses)
if _fw_ids:
_fw_docs = load_framework_docs(_fw_ids, max_chars=4000)
if _fw_docs:
_dg_block += _fw_docs
except Exception: # noqa: BLE001
pass
# Improvement A: Compute hardware profile + per-condition budget
_hw_profile_str = (
"- GPU: NVIDIA RTX 6000 Ada (49140 MB VRAM)\n"
"- GPU count: 1\n"
"- CPU: shared server"
)
_per_condition_sec = int(config.experiment.time_budget_sec * 0.7 / 6)
_tier1 = "CIFAR-10, CIFAR-100, MNIST, FashionMNIST, STL-10, SVHN"
_overlay = _get_evolution_overlay(run_dir, "experiment_design")
sp = _pm.for_stage(
"experiment_design",
evolution_overlay=_overlay,
preamble=preamble,
hypotheses=hypotheses,
dataset_guidance=_dg_block,
time_budget_sec=config.experiment.time_budget_sec,
metric_key=config.experiment.metric_key,
metric_direction=config.experiment.metric_direction,
hardware_profile=_hw_profile_str,
per_condition_budget_sec=_per_condition_sec,
available_tier1_datasets=_tier1,
)
resp = _chat_with_prompt(
llm,
sp.system,
sp.user,
json_mode=sp.json_mode,
max_tokens=sp.max_tokens,
)
raw_yaml = _extract_yaml_block(resp.content)
try:
parsed = yaml.safe_load(raw_yaml)
except yaml.YAMLError:
parsed = None
# Fallback: reasoning models sometimes emit the YAML without fences
# or wrapped in prose. Try parsing the whole response as YAML.
if not isinstance(parsed, dict):
try:
parsed = yaml.safe_load(resp.content)
except yaml.YAMLError:
pass
# Last fallback: try to find any YAML-like dict in the response
if not isinstance(parsed, dict):
import re as _re_yaml
# Look for lines starting with known keys
_yaml_lines = []
_capturing = False
for line in resp.content.splitlines():
if _re_yaml.match(
r"^(baselines|proposed_methods|ablations|datasets|"
r"metrics|objectives|risks|compute_budget)\s*:",
line,
):
_capturing = True
if _capturing:
if line.strip() == "" or line.startswith("```"):
continue
if line.startswith("#") or line.startswith("**"):
continue
_yaml_lines.append(line)
if _yaml_lines:
try:
parsed = yaml.safe_load("\n".join(_yaml_lines))
except yaml.YAMLError:
pass
if isinstance(parsed, dict):
plan = parsed
else:
logger.warning(
"Stage 09: LLM response could not be parsed as YAML "
"(len=%d, first 200 chars: %s). Content extraction method "
"returned: %s",
len(resp.content),
resp.content[:200],
raw_yaml[:200] if raw_yaml else "",
)
# BUG-12: Retry with a stricter, shorter prompt
if llm is not None:
logger.info("Stage 09: Retrying with strict YAML-only prompt...")
_retry_prompt = (
"Output ONLY valid YAML. No prose, no markdown fences, no explanation.\n"
f"Topic: {config.research.topic}\n"
"Required keys: baselines, proposed_methods, ablations, "
"datasets, metrics, objectives, risks, compute_budget.\n"
"Each key maps to a list of strings."
)
_retry_resp = _chat_with_prompt(
llm,
"You output ONLY valid YAML. Nothing else.",
_retry_prompt,
max_tokens=4096,
)
try:
_retry_parsed = yaml.safe_load(_retry_resp.content)
if isinstance(_retry_parsed, dict):
plan = _retry_parsed
logger.info("Stage 09: Strict YAML retry succeeded.")
except yaml.YAMLError:
pass
# BUG-12: Fallback 4 — extract method/baseline names from Stage 8 hypotheses
if plan is None:
_hyp_text = _read_prior_artifact(run_dir, "hypotheses.md") or ""
if _hyp_text:
import re as _re_hyp
# Extract method-like names from hypothesis text
_method_candidates = _re_hyp.findall(
r"(?:proposed|our|novel|new)\s+(?:method|approach|algorithm|framework|model)[:\s]+[\"']?([A-Za-z][\w-]+)",
_hyp_text, _re_hyp.IGNORECASE,
)
_baseline_candidates = _re_hyp.findall(
r"(?:baseline|compare|existing|standard|traditional)\s+(?:method|approach|model)?[:\s]+[\"']?([A-Za-z][\w-]+)",
_hyp_text, _re_hyp.IGNORECASE,
)
if _method_candidates or _baseline_candidates:
logger.info(
"Stage 09: Extracted names from hypotheses: methods=%s, baselines=%s",
_method_candidates[:3], _baseline_candidates[:3],
)
plan = {
"topic": config.research.topic,
"generated": _utcnow_iso(),
"objectives": ["Evaluate hypotheses with controlled experiments"],
"datasets": ["primary_dataset"],
"baselines": _baseline_candidates[:3] or ["baseline_1", "baseline_2"],
"proposed_methods": _method_candidates[:3] or ["proposed_method"],
"ablations": ["without_key_component", "simplified_version"],
"metrics": [config.experiment.metric_key, "secondary_metric"],
"risks": ["validity threats", "confounding variables"],
"compute_budget": {"max_gpu": 1, "max_hours": 4},
}
if plan is None:
# BUG-12: Use domain-aware names instead of fully generic placeholders
_topic_prefix = config.research.topic.split()[0] if config.research.topic else "method"
logger.warning(
"Stage 09: LLM failed to produce valid experiment plan YAML. "
"Using topic-derived fallback."
)
plan = {
"topic": config.research.topic,
"generated": _utcnow_iso(),
"objectives": ["Evaluate hypotheses with controlled experiments"],
"datasets": ["primary_dataset", "secondary_dataset"],
"baselines": [f"{_topic_prefix}_baseline_1", f"{_topic_prefix}_baseline_2"],
"proposed_methods": [f"{_topic_prefix}_proposed", f"{_topic_prefix}_variant"],
"ablations": ["without_key_component", "simplified_version"],
"metrics": [config.experiment.metric_key, "secondary_metric"],
"risks": ["validity threats", "confounding variables"],
"compute_budget": {"max_gpu": 1, "max_hours": 4},
}
# ── BA: BenchmarkAgent — intelligent dataset/baseline selection ──────
_benchmark_plan = None
# BUG-40: Skip BenchmarkAgent for non-ML domains — it has no relevant
# benchmarks for physics/chemistry/mathematics/etc. and would inject
# wrong datasets (e.g., CIFAR-10 for PDE topics).
_ba_domain_id, _, _ = _detect_domain(
config.research.topic,
tuple(config.research.domains) if config.research.domains else (),
)
_ba_domain_ok = _ba_domain_id == "ml"
if not _ba_domain_ok:
logger.info(
"BenchmarkAgent skipped: domain '%s' is not ML (topic: %s)",
_ba_domain_id, config.research.topic[:80],
)
if (
_ba_domain_ok
and config.experiment.benchmark_agent.enabled
and config.experiment.mode in ("sandbox", "docker")
and llm is not None
):
try:
from researchclaw.agents.benchmark_agent import BenchmarkOrchestrator
from researchclaw.agents.benchmark_agent.orchestrator import (
BenchmarkAgentConfig as _BACfg,
)
_ba_cfg_raw = config.experiment.benchmark_agent
_ba_cfg = _BACfg(
enabled=_ba_cfg_raw.enabled,
enable_hf_search=_ba_cfg_raw.enable_hf_search,
max_hf_results=_ba_cfg_raw.max_hf_results,
enable_web_search=_ba_cfg_raw.enable_web_search,
max_web_results=_ba_cfg_raw.max_web_results,
web_search_min_local=_ba_cfg_raw.web_search_min_local,
tier_limit=_ba_cfg_raw.tier_limit,
min_benchmarks=_ba_cfg_raw.min_benchmarks,
min_baselines=_ba_cfg_raw.min_baselines,
prefer_cached=_ba_cfg_raw.prefer_cached,
max_iterations=_ba_cfg_raw.max_iterations,
)
_hw = _load_hardware_profile(run_dir)
_ba = BenchmarkOrchestrator(
llm,
config=_ba_cfg,
gpu_memory_mb=(
_hw.get("gpu_memory_mb", 49000) if _hw else 49000
),
time_budget_sec=config.experiment.time_budget_sec,
network_policy=(
config.experiment.docker.network_policy
if config.experiment.mode == "docker"
else "full"
),
stage_dir=stage_dir / "benchmark_agent",
)
_benchmark_plan = _ba.orchestrate({
"topic": config.research.topic,
"hypothesis": hypotheses,
"experiment_plan": plan.get("objectives", "") if isinstance(plan, dict) else "",
})
# Inject BenchmarkAgent selections into experiment plan
if isinstance(plan, dict) and _benchmark_plan.selected_benchmarks:
plan["datasets"] = [
b.get("name", "Unknown") for b in _benchmark_plan.selected_benchmarks
]
# Normalize existing baselines to list of strings
# BUG-35: LLM may emit baselines as dict, list of dicts,
# or list of strings — normalize all to list[str].
_baselines_from_plan = plan.get("baselines", [])
if isinstance(_baselines_from_plan, dict):
_baselines_from_plan = list(_baselines_from_plan.keys())
elif isinstance(_baselines_from_plan, list):
_baselines_from_plan = [
item["name"] if isinstance(item, dict) else str(item)
for item in _baselines_from_plan
]
else:
_baselines_from_plan = []
plan["baselines"] = [
bl.get("name", "Unknown") for bl in _benchmark_plan.selected_baselines
] + _baselines_from_plan
# Deduplicate baselines
plan["baselines"] = list(dict.fromkeys(plan["baselines"]))
logger.info(
"BenchmarkAgent: %d benchmarks, %d baselines selected (%d LLM calls, %.1fs)",
len(_benchmark_plan.selected_benchmarks),
len(_benchmark_plan.selected_baselines),
_benchmark_plan.total_llm_calls,
_benchmark_plan.elapsed_sec,
)
except Exception as _ba_exc:
logger.warning("BenchmarkAgent failed (non-fatal): %s", _ba_exc)
# Save benchmark plan for code_generation stage
if _benchmark_plan is not None:
try:
(stage_dir / "benchmark_plan.json").write_text(
json.dumps(_benchmark_plan.to_dict(), indent=2, ensure_ascii=False),
encoding="utf-8",
)
except Exception: # noqa: BLE001
pass
plan.setdefault("topic", config.research.topic)
# BUG-R41-09: Enforce condition count limit based on time budget.
# Too many conditions (30+) guarantee timeouts and wasted compute.
_time_budget = getattr(
getattr(config, "experiment", None), "time_budget_sec", 3600
)
_max_conditions = 8 # default for budgets ≤ 3600s
if _time_budget > 3600:
_max_conditions = 12
if _time_budget > 7200:
_max_conditions = 20
_baselines = plan.get("baselines", [])
if isinstance(_baselines, dict):
_baselines = list(_baselines.values())
_proposed = plan.get("proposed_methods", [])
if isinstance(_proposed, dict):
_proposed = list(_proposed.values())
_ablations = plan.get("ablations", [])
if isinstance(_ablations, dict):
_ablations = list(_ablations.values())
_total = len(_baselines) + len(_proposed) + len(_ablations)
if _total > _max_conditions:
logger.warning(
"Stage 9: Plan has %d conditions (limit %d for %ds budget). "
"Trimming to fit.",
_total, _max_conditions, _time_budget,
)
# Keep all proposed methods (up to max), trim baselines and ablations
_proposed_count = min(len(_proposed), max(1, _max_conditions - 4))
_remaining = max(0, _max_conditions - _proposed_count)
_baseline_budget = max(1, _remaining // 2)
_ablation_budget = max(0, _remaining - _baseline_budget)
if len(_proposed) > _proposed_count:
plan["proposed_methods"] = _proposed[:_proposed_count]
logger.info(
"Stage 9: Trimmed proposed methods %d → %d",
len(_proposed), _proposed_count,
)
if len(_baselines) > _baseline_budget:
plan["baselines"] = _baselines[:_baseline_budget]
logger.info(
"Stage 9: Trimmed baselines %d → %d",
len(_baselines), _baseline_budget,
)
if len(_ablations) > _ablation_budget:
plan["ablations"] = _ablations[:_ablation_budget]
logger.info(
"Stage 9: Trimmed ablations %d → %d",
len(_ablations), _ablation_budget,
)
(stage_dir / "exp_plan.yaml").write_text(
yaml.dump(plan, default_flow_style=False, allow_unicode=True),
encoding="utf-8",
)
return StageResult(
stage=Stage.EXPERIMENT_DESIGN,
status=StageStatus.DONE,
artifacts=("exp_plan.yaml",),
evidence_refs=("stage-09/exp_plan.yaml",),
)
================================================
FILE: researchclaw/pipeline/stage_impls/_literature.py
================================================
"""Stages 3-6: Search strategy, literature collection, screening, and knowledge extraction."""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any
import yaml
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._helpers import (
StageResult,
_build_fallback_queries,
_chat_with_prompt,
_extract_topic_keywords,
_extract_yaml_block,
_get_evolution_overlay,
_parse_jsonl_rows,
_read_prior_artifact,
_safe_filename,
_safe_json_loads,
_utcnow_iso,
_write_jsonl,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Local helpers
# ---------------------------------------------------------------------------
def _expand_search_queries(queries: list[str], topic: str) -> list[str]:
"""Expand search queries for broader literature coverage.
Generates additional queries by extracting key phrases from the topic
and creating focused sub-queries. This ensures we find papers even when
the original queries are too narrow or specific for arXiv.
"""
expanded = list(queries) # keep originals
seen = {q.lower().strip() for q in queries}
# Extract key phrases from topic by splitting on common delimiters
# e.g. "Comparing A, B, and C on X with Y" → ["A", "B", "C", "X", "Y"]
topic_words = topic.split()
# Generate shorter, broader queries from the topic
if len(topic_words) > 5:
# First 5 words as a broader query
broad = " ".join(topic_words[:5])
if broad.lower().strip() not in seen:
expanded.append(broad)
seen.add(broad.lower().strip())
# Last 5 words as another perspective
tail = " ".join(topic_words[-5:])
if tail.lower().strip() not in seen:
expanded.append(tail)
seen.add(tail.lower().strip())
# Add "survey" and "benchmark" variants of the topic
for suffix in ("survey", "benchmark", "comparison"):
# Take first 4 content words + suffix
short_topic = " ".join(topic_words[:4])
variant = f"{short_topic} {suffix}"
if variant.lower().strip() not in seen:
expanded.append(variant)
seen.add(variant.lower().strip())
return expanded
# ---------------------------------------------------------------------------
# Stage executors
# ---------------------------------------------------------------------------
def _execute_search_strategy(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 3: Build the literature-search plan and query list.

    Asks the LLM for a search plan (YAML) plus a source list; falls back to
    a template plan built from topic keywords when the LLM is unavailable or
    its output cannot be parsed.  Optionally verifies each source URL via
    the web-fetch adapter.  Extracts, shortens, and de-duplicates the
    queries, then writes three artifacts into *stage_dir*:
    ``search_plan.yaml``, ``sources.json``, and ``queries.json`` (consumed
    by Stage 4).

    Returns:
        StageResult for ``Stage.SEARCH_STRATEGY`` with the three artifacts.
    """
    # Context from Stage 2 (empty string if the artifact is missing).
    problem_tree = _read_prior_artifact(run_dir, "problem_tree.md") or ""
    topic = config.research.topic
    plan: dict[str, Any] | None = None
    sources: list[dict[str, Any]] | None = None

    # --- 1) LLM-generated plan (preferred path) ---
    if llm is not None:
        _pm = prompts or PromptManager()
        _overlay = _get_evolution_overlay(run_dir, "search_strategy")
        sp = _pm.for_stage("search_strategy", evolution_overlay=_overlay, topic=topic, problem_tree=problem_tree)
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        # The response is expected to be JSON carrying a YAML plan string
        # under "search_plan_yaml" and a list of source dicts under "sources".
        payload = _safe_json_loads(resp.content, {})
        if isinstance(payload, dict):
            yaml_text = str(payload.get("search_plan_yaml", "")).strip()
            if yaml_text:
                try:
                    # _extract_yaml_block strips markdown fences around the YAML.
                    parsed = yaml.safe_load(_extract_yaml_block(yaml_text))
                except yaml.YAMLError:
                    parsed = None
                if isinstance(parsed, dict):
                    plan = parsed
            src = payload.get("sources", [])
            if isinstance(src, list):
                # Keep only well-formed (dict) source entries.
                sources = [item for item in src if isinstance(item, dict)]

    # --- 2) Template fallback plan when the LLM path produced nothing ---
    if plan is None:
        # Build smart fallback queries by extracting key terms from topic
        # instead of using the raw (often very long) topic string.
        _fallback_queries = _build_fallback_queries(topic)
        plan = {
            "topic": topic,
            "generated": _utcnow_iso(),
            "search_strategies": [
                {
                    "name": "keyword_core",
                    "queries": _fallback_queries[:5],
                    "sources": ["arxiv", "semantic_scholar", "openreview"],
                    "max_results_per_query": 60,
                },
                {
                    "name": "backward_forward_citation",
                    # Use queries 6-10 if available, else reuse the first 3.
                    "queries": _fallback_queries[5:10] or _fallback_queries[:3],
                    "sources": ["semantic_scholar", "google_scholar"],
                    "depth": 1,
                },
            ],
            "filters": {
                "min_year": 2020,
                "language": ["en"],
                "peer_review_preferred": True,
            },
            "deduplication": {"method": "title_doi_hash", "fuzzy_threshold": 0.9},
        }

    # --- 3) Default sources when the LLM supplied none ---
    if not sources:
        sources = [
            {
                "id": "arxiv",
                "name": "arXiv",
                "type": "api",
                "url": "https://export.arxiv.org/api/query",
                "status": "available",
                "query": topic,
                "verified_at": _utcnow_iso(),
            },
            {
                "id": "semantic_scholar",
                "name": "Semantic Scholar",
                "type": "api",
                "url": "https://api.semanticscholar.org/graph/v1/paper/search",
                "status": "available",
                "query": topic,
                "verified_at": _utcnow_iso(),
            },
        ]

    # --- 4) Optionally verify each source URL is reachable ---
    if config.openclaw_bridge.use_web_fetch:
        for src in sources:
            try:
                response = adapters.web_fetch.fetch(str(src.get("url", "")))
                # 405 counts as "verified": some APIs reject bare GETs but
                # are clearly alive.
                src["status"] = (
                    "verified"
                    if response.status_code in (200, 301, 302, 405)
                    else "unreachable"
                )
                src["http_status"] = response.status_code
            except Exception:  # noqa: BLE001
                # Network failure is non-fatal; mark status unknown.
                src["status"] = "unknown"

    # --- 5) Persist plan and sources artifacts ---
    (stage_dir / "search_plan.yaml").write_text(
        yaml.dump(plan, default_flow_style=False, allow_unicode=True),
        encoding="utf-8",
    )
    (stage_dir / "sources.json").write_text(
        json.dumps(
            {"sources": sources, "count": len(sources), "generated": _utcnow_iso()},
            indent=2,
        ),
        encoding="utf-8",
    )

    # F1.5: Extract queries from plan for Stage 4 real literature search
    queries_list: list[str] = []
    year_min = 2020
    if isinstance(plan, dict):
        strategies = plan.get("search_strategies", [])
        if isinstance(strategies, list):
            for strat in strategies:
                if isinstance(strat, dict):
                    qs = strat.get("queries", [])
                    if isinstance(qs, list):
                        queries_list.extend(str(q) for q in qs if q)
        filters = plan.get("filters", {})
        if isinstance(filters, dict) and filters.get("min_year"):
            try:
                year_min = int(filters["min_year"])
            except (ValueError, TypeError):
                pass

    # --- Sanitize queries: shorten overly long queries ---
    # LLMs often produce the full topic title as a query, which is too long for
    # arXiv and Semantic Scholar (they work best with 3-8 keyword queries).
    _stop = {
        "a", "an", "the", "of", "for", "in", "on", "and", "or", "with",
        "to", "by", "from", "its", "is", "are", "was", "be", "as", "at",
        "via", "using", "based", "study", "analysis", "empirical",
        "towards", "toward", "into", "exploring", "comparison", "tasks",
        "effectiveness", "investigation", "comprehensive", "novel",
    }

    def _extract_keywords(text: str) -> list[str]:
        """Extract meaningful keywords from text, removing stop words."""
        return [
            w for w in re.split(r"[^a-zA-Z0-9]+", text)
            if w.lower() not in _stop and len(w) > 1
        ]

    _MAX_QUERY_LEN = 60  # characters — beyond this, shorten to keywords
    _SEARCH_SUFFIXES = ["benchmark", "survey", "seminal", "state of the art"]

    def _shorten_query(q: str, max_kw: int = 6) -> str:
        """Shorten a query to *max_kw* keywords, preserving any trailing suffix."""
        q_stripped = q.strip()
        # Check if query ends with a known search suffix
        suffix = ""
        q_core = q_stripped
        for sfx in _SEARCH_SUFFIXES:
            if q_stripped.lower().endswith(sfx):
                suffix = sfx
                q_core = q_stripped[: -len(sfx)].strip()
                break
        # Extract keywords from the core part
        kws = _extract_keywords(q_core)
        shortened = " ".join(kws[:max_kw])
        if suffix:
            shortened = f"{shortened} {suffix}"
        return shortened

    # Replace over-long queries with their keyword-shortened form; drop a
    # query only if shortening leaves nothing.
    if queries_list:
        sanitized: list[str] = []
        for q in queries_list:
            if len(q) > _MAX_QUERY_LEN:
                shortened = _shorten_query(q)
                if shortened.strip():
                    sanitized.append(shortened)
            else:
                sanitized.append(q)
        queries_list = sanitized

    if not queries_list:
        # Build diverse keyword queries from the topic
        _words = _extract_keywords(topic)
        kw_primary = " ".join(_words[:6])
        kw_short = " ".join(_words[:4])
        queries_list = [
            kw_primary,
            f"{kw_short} benchmark",
            f"{kw_short} survey",
        ]

    # Ensure minimum query diversity — if dedup leaves too few, add variants
    _all_kw = _extract_keywords(topic)
    _seen_q: set[str] = set()
    unique_queries: list[str] = []
    for q in queries_list:
        q_lower = q.strip().lower()
        if q_lower and q_lower not in _seen_q:
            _seen_q.add(q_lower)
            unique_queries.append(q.strip())
    # If we have fewer than 5 unique queries, generate supplemental keyword variants
    if len(unique_queries) < 5 and len(_all_kw) >= 3:
        supplements = [
            " ".join(_all_kw[:4]) + " survey",
            " ".join(_all_kw[:4]) + " benchmark",
            " ".join(_all_kw[1:5]),  # shifted window for diversity
            " ".join(_all_kw[:3]) + " comparison",
            " ".join(_all_kw[:3]) + " deep learning",
            " ".join(_all_kw[2:6]),  # another shifted window
        ]
        for s in supplements:
            s_lower = s.strip().lower()
            if s_lower not in _seen_q:
                _seen_q.add(s_lower)
                unique_queries.append(s.strip())
            if len(unique_queries) >= 8:
                break
    queries_list = unique_queries

    # --- 6) Persist the final query list for Stage 4 ---
    (stage_dir / "queries.json").write_text(
        json.dumps({"queries": queries_list, "year_min": year_min}, indent=2),
        encoding="utf-8",
    )
    return StageResult(
        stage=Stage.SEARCH_STRATEGY,
        status=StageStatus.DONE,
        artifacts=("search_plan.yaml", "sources.json", "queries.json"),
        evidence_refs=(
            "stage-03/search_plan.yaml",
            "stage-03/sources.json",
            "stage-03/queries.json",
        ),
    )
def _execute_literature_collect(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 4: Collect literature — prefer real APIs, fallback to LLM.

    Collection order (each step is best-effort, non-fatal):
      1. Real API search (OpenAlex/S2/arXiv) over expanded Stage-3 queries.
      2. Injection of seminal papers from the local seed library.
      3. LLM-generated candidates if nothing real was found.
      4. Web-search augmentation (Tavily/DDG, Google Scholar, crawling).
      5. Placeholder candidates as the ultimate fallback.

    Writes ``candidates.jsonl``, ``search_meta.json``, and — when available —
    ``references.bib``, ``web_context.md``, and ``web_search_result.json``.

    Returns:
        StageResult for ``Stage.LITERATURE_COLLECT`` listing the artifacts
        that were actually produced.
    """
    topic = config.research.topic
    # Read queries.json from Stage 3 (F1.5 output)
    queries_text = _read_prior_artifact(run_dir, "queries.json")
    queries_data = _safe_json_loads(queries_text or "{}", {})
    queries: list[str] = queries_data.get("queries", [topic])
    year_min: int = queries_data.get("year_min", 2020)

    # --- Try real API search first ---
    candidates: list[dict[str, Any]] = []
    bibtex_entries: list[str] = []
    real_search_succeeded = False
    try:
        # Imported lazily so a broken literature backend cannot crash the
        # stage; papers_to_bibtex is imported alongside (not used below).
        from researchclaw.literature.search import (
            search_papers_multi_query,
            papers_to_bibtex,
        )
        # Expand queries for broader coverage
        expanded_queries = _expand_search_queries(queries, config.research.topic)
        logger.info(
            "[literature] Searching %d queries (expanded from %d) "
            "across OpenAlex → S2 → arXiv…",
            len(expanded_queries),
            len(queries),
        )
        papers = search_papers_multi_query(
            expanded_queries,
            limit_per_query=40,
            year_min=year_min,
            s2_api_key=config.llm.s2_api_key,
        )
        if papers:
            real_search_succeeded = True
            # Count by source
            src_counts: dict[str, int] = {}
            for p in papers:
                src_counts[p.source] = src_counts.get(p.source, 0) + 1
                d = p.to_dict()
                d["collected_at"] = _utcnow_iso()
                candidates.append(d)
                bibtex_entries.append(p.to_bibtex())
            src_str = ", ".join(f"{s}: {n}" for s, n in src_counts.items())
            logger.info(
                "[literature] Found %d papers (%s)", len(papers), src_str
            )
    except Exception:  # noqa: BLE001
        # Deliberate broad catch: any search failure falls through to the
        # LLM/placeholder fallbacks below.
        logger.warning(
            "[rate-limit] Literature search failed — falling back to LLM",
            exc_info=True,
        )

    # --- Inject foundational/seminal papers ---
    try:
        from researchclaw.data import load_seminal_papers
        seminal = load_seminal_papers(topic)
        if seminal:
            # Title-based (lower-cased) dedup against already-collected papers.
            _existing_titles = {c.get("title", "").lower() for c in candidates}
            _injected = 0
            for sp in seminal:
                if sp.get("title", "").lower() not in _existing_titles:
                    candidates.append({
                        "id": f"seminal-{sp.get('cite_key', '')}",
                        "title": sp.get("title", ""),
                        "source": "seminal_library",
                        "url": "",
                        "year": sp.get("year", 2020),
                        "abstract": f"Foundational paper on {', '.join(sp.get('keywords', [])[:3])}.",
                        "authors": [{"name": sp.get("authors", "")}],
                        "cite_key": sp.get("cite_key", ""),
                        "venue": sp.get("venue", ""),
                        "collected_at": _utcnow_iso(),
                    })
                    _injected += 1
            if _injected:
                logger.info("Stage 4: Injected %d seminal papers from seed library", _injected)
    except Exception:  # noqa: BLE001
        # Seed library is optional; skip quietly if unavailable.
        logger.debug("Seminal paper injection skipped", exc_info=True)

    # --- Fallback: LLM-generated candidates ---
    if not candidates and llm is not None:
        plan_text = _read_prior_artifact(run_dir, "search_plan.yaml") or ""
        _pm = prompts or PromptManager()
        _overlay = _get_evolution_overlay(run_dir, "literature_collect")
        sp = _pm.for_stage("literature_collect", evolution_overlay=_overlay, topic=topic, plan_text=plan_text)
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        payload = _safe_json_loads(resp.content, {})
        if isinstance(payload, dict) and isinstance(payload.get("candidates"), list):
            candidates = [row for row in payload["candidates"] if isinstance(row, dict)]

    # --- Web search augmentation (Tavily/DDG + Google Scholar + Crawl4AI) ---
    web_context_parts: list[str] = []
    if config.web_search.enabled:
        try:
            from researchclaw.web.agent import WebSearchAgent
            import os
            # Config value wins; otherwise read the key from the environment
            # variable named in config.
            tavily_key = config.web_search.tavily_api_key or os.environ.get(
                config.web_search.tavily_api_key_env, ""
            )
            web_agent = WebSearchAgent(
                tavily_api_key=tavily_key,
                enable_scholar=config.web_search.enable_scholar,
                enable_crawling=config.web_search.enable_crawling,
                enable_pdf=config.web_search.enable_pdf_extraction,
                max_web_results=config.web_search.max_web_results,
                max_scholar_results=config.web_search.max_scholar_results,
                max_crawl_urls=config.web_search.max_crawl_urls,
            )
            web_result = web_agent.search_and_extract(
                topic, search_queries=queries,
            )
            # Convert Google Scholar papers into candidates
            # (title set recomputed each iteration so newly appended papers
            # also participate in dedup).
            for sp in web_result.scholar_papers:
                _existing_titles = {
                    str(c.get("title", "")).lower().strip() for c in candidates
                }
                if sp.title.lower().strip() not in _existing_titles:
                    lit_paper = sp.to_literature_paper()
                    d = lit_paper.to_dict()
                    d["collected_at"] = _utcnow_iso()
                    candidates.append(d)
                    bibtex_entries.append(lit_paper.to_bibtex())
            # Save web search context for downstream stages
            web_context = web_result.to_context_string(max_length=20_000)
            if web_context.strip():
                (stage_dir / "web_context.md").write_text(
                    web_context, encoding="utf-8"
                )
                web_context_parts.append(web_context)
            # Save full web search metadata
            (stage_dir / "web_search_result.json").write_text(
                json.dumps(web_result.to_dict(), indent=2, default=str),
                encoding="utf-8",
            )
            logger.info(
                "[web-search] Added %d scholar papers, %d web results, %d crawled pages",
                len(web_result.scholar_papers),
                len(web_result.web_results),
                len(web_result.crawled_pages),
            )
        except Exception:  # noqa: BLE001
            logger.warning(
                "[web-search] Web search augmentation failed — continuing with academic APIs only",
                exc_info=True,
            )

    # --- Ultimate fallback: placeholder data ---
    # BUG-L2: Do NOT overwrite real_search_succeeded here — it was already
    # set correctly in the search block above. Overwriting would mislabel
    # LLM-hallucinated or seminal papers as "real search" results.
    if not candidates:
        logger.warning("Stage 4: All literature searches failed — using placeholder papers")
        candidates = [
            {
                "id": f"candidate-{idx + 1}",
                "title": f"[Placeholder] Study {idx + 1} on {topic}",
                "source": "arxiv" if idx % 2 == 0 else "semantic_scholar",
                "url": f"https://example.org/{_safe_filename(topic.lower())}/{idx + 1}",
                "year": 2024,
                "abstract": f"This candidate investigates {topic} and reports preliminary findings.",
                "collected_at": _utcnow_iso(),
                "is_placeholder": True,
            }
            # At least 20 placeholders, more if daily_paper_count asks for it.
            for idx in range(max(20, config.research.daily_paper_count or 20))
        ]

    # Write candidates
    out = stage_dir / "candidates.jsonl"
    _write_jsonl(out, candidates)

    # BUG-50 fix: Generate BibTeX from candidates when real search failed
    # (LLM/placeholder fallback paths don't populate bibtex_entries)
    if not bibtex_entries and candidates:
        for c in candidates:
            # Placeholders are synthetic — never cite them.
            if c.get("is_placeholder"):
                continue
            _ck = c.get("cite_key", "")
            if not _ck:
                # Derive cite_key from first author surname + year
                _authors = c.get("authors", [])
                _surname = "unknown"
                if isinstance(_authors, list) and _authors:
                    # Authors may be plain strings or {"name": ...} dicts.
                    _a0 = _authors[0] if isinstance(_authors[0], str) else (_authors[0].get("name", "") if isinstance(_authors[0], dict) else "")
                    _surname = _a0.split()[-1].lower() if _a0.strip() else "unknown"
                _yr = c.get("year", 2024)
                # Initials of the first three title words, e.g. "Deep Q Net" → "dqn".
                _title_word = "".join(
                    w[0] for w in str(c.get("title", "study")).split()[:3]
                ).lower()
                _ck = f"{_surname}{_yr}{_title_word}"
            _title = c.get("title", "Untitled")
            _year = c.get("year", 2024)
            _author_str = ""
            _raw_authors = c.get("authors", [])
            if isinstance(_raw_authors, list):
                _names = []
                for _a in _raw_authors:
                    if isinstance(_a, str):
                        _names.append(_a)
                    elif isinstance(_a, dict):
                        _names.append(_a.get("name", ""))
                _author_str = " and ".join(n for n in _names if n)
            bibtex_entries.append(
                f"@article{{{_ck},\n"
                f" title={{{_title}}},\n"
                f" author={{{_author_str or 'Unknown'}}},\n"
                f" year={{{_year}}},\n"
                f" url={{{c.get('url', '')}}},\n"
                f"}}"
            )
        logger.info(
            "Stage 4: Generated %d BibTeX entries from candidates (fallback)",
            len(bibtex_entries),
        )

    # Write references.bib (F2.4)
    artifacts = ["candidates.jsonl"]
    if web_context_parts:
        artifacts.append("web_context.md")
        if (stage_dir / "web_search_result.json").exists():
            artifacts.append("web_search_result.json")
    if bibtex_entries:
        bib_content = "\n\n".join(bibtex_entries) + "\n"
        (stage_dir / "references.bib").write_text(bib_content, encoding="utf-8")
        artifacts.append("references.bib")
        logger.info(
            "Stage 4: Wrote %d BibTeX entries to references.bib", len(bibtex_entries)
        )

    # Write search metadata so downstream stages can tell real results from
    # fallback data.
    (stage_dir / "search_meta.json").write_text(
        json.dumps(
            {
                "real_search": real_search_succeeded,
                "queries_used": queries,
                "year_min": year_min,
                "total_candidates": len(candidates),
                "bibtex_entries": len(bibtex_entries),
                "ts": _utcnow_iso(),
            },
            indent=2,
        ),
        encoding="utf-8",
    )
    artifacts.append("search_meta.json")
    return StageResult(
        stage=Stage.LITERATURE_COLLECT,
        status=StageStatus.DONE,
        artifacts=tuple(artifacts),
        evidence_refs=tuple(f"stage-04/{a}" for a in artifacts),
    )
def _execute_literature_screen(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 5: Screen Stage-4 candidates down to a relevance shortlist.

    Applies a cheap keyword pre-filter, then asks the LLM to score and
    select papers; guarantees a minimum shortlist size by padding with
    template-scored or supplemental candidates.  Writes ``shortlist.jsonl``.

    Returns:
        StageResult for ``Stage.LITERATURE_SCREEN``.
    """
    candidates_text = _read_prior_artifact(run_dir, "candidates.jsonl") or ""

    # --- P1-1: keyword relevance pre-filter ---
    # Before LLM screening, drop papers whose title+abstract share no keywords
    # with the research topic. This catches cross-domain noise cheaply.
    topic_keywords = _extract_topic_keywords(
        config.research.topic, config.research.domains
    )
    filtered_rows: list[dict[str, Any]] = []
    dropped_count = 0
    for raw_line in candidates_text.strip().splitlines():
        row = _safe_json_loads(raw_line, {})
        if not isinstance(row, dict):
            # Skip malformed JSONL lines.
            continue
        title = str(row.get("title", "")).lower()
        abstract = str(row.get("abstract", "")).lower()
        text_blob = f"{title} {abstract}"
        # Number of topic keywords appearing (substring match) in the paper.
        overlap = sum(1 for kw in topic_keywords if kw in text_blob)
        # T2.2: Relaxed from ≥2 to ≥1 keyword hit — previous threshold was
        # too aggressive (94% rejection rate). Single-keyword matches are
        # still screened by the LLM in the next step.
        if overlap >= 1:
            row["keyword_overlap"] = overlap
            filtered_rows.append(row)
        else:
            dropped_count += 1
    # If pre-filter dropped everything, fall back to original (safety valve)
    if not filtered_rows:
        filtered_rows = _parse_jsonl_rows(candidates_text)
    # Rebuild candidates_text from filtered rows (this is what the LLM sees).
    candidates_text = "\n".join(
        json.dumps(r, ensure_ascii=False) for r in filtered_rows
    )
    logger.info(
        "Domain pre-filter: kept %d, dropped %d (keywords: %s)",
        len(filtered_rows),
        dropped_count,
        topic_keywords[:8],
    )

    # --- LLM screening pass ---
    shortlist: list[dict[str, Any]] = []
    if llm is not None:
        _pm = prompts or PromptManager()
        _overlay = _get_evolution_overlay(run_dir, "literature_screen")
        sp = _pm.for_stage(
            "literature_screen",
            evolution_overlay=_overlay,
            topic=config.research.topic,
            domains=", ".join(config.research.domains)
            if config.research.domains
            else "general",
            quality_threshold=config.research.quality_threshold,
            candidates_text=candidates_text,
        )
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        payload = _safe_json_loads(resp.content, {})
        if isinstance(payload, dict) and isinstance(payload.get("shortlist"), list):
            shortlist = [row for row in payload["shortlist"] if isinstance(row, dict)]

    # T2.2: Ensure minimum shortlist size of 15 for adequate related work
    _MIN_SHORTLIST = 15
    if not shortlist:
        # No LLM result at all: take the first N filtered rows and tag them
        # with monotonically decreasing template scores.
        rows = (
            filtered_rows[:_MIN_SHORTLIST]
            if filtered_rows
            else _parse_jsonl_rows(candidates_text)[:_MIN_SHORTLIST]
        )
        for idx, item in enumerate(rows):
            item["relevance_score"] = round(0.75 - idx * 0.02, 3)
            item["quality_score"] = round(0.72 - idx * 0.015, 3)
            item["keep_reason"] = "Template screened entry"
            shortlist.append(item)
    elif len(shortlist) < _MIN_SHORTLIST:
        # T2.2: LLM returned too few — supplement from filtered candidates
        existing_titles = {
            str(s.get("title", "")).lower().strip() for s in shortlist
        }
        for row in filtered_rows:
            if len(shortlist) >= _MIN_SHORTLIST:
                break
            title_lower = str(row.get("title", "")).lower().strip()
            if title_lower and title_lower not in existing_titles:
                # setdefault: keep any scores the LLM/pre-filter already set.
                row.setdefault("relevance_score", 0.5)
                row.setdefault("quality_score", 0.5)
                row.setdefault("keep_reason", "Supplemented to meet minimum shortlist")
                shortlist.append(row)
                existing_titles.add(title_lower)
        logger.info(
            "Stage 5: Supplemented shortlist to %d papers (minimum: %d)",
            len(shortlist), _MIN_SHORTLIST,
        )

    out = stage_dir / "shortlist.jsonl"
    _write_jsonl(out, shortlist)
    return StageResult(
        stage=Stage.LITERATURE_SCREEN,
        status=StageStatus.DONE,
        artifacts=("shortlist.jsonl",),
        evidence_refs=("stage-05/shortlist.jsonl",),
    )
def _execute_knowledge_extract(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 6: Extract structured knowledge cards from the shortlist.

    Feeds the Stage-5 shortlist (plus any Stage-4 web context) to the LLM
    and expects a ``cards`` list back; falls back to template cards derived
    from the first six shortlist rows.  Each card is written as a markdown
    file under ``stage_dir/cards/``.

    Returns:
        StageResult for ``Stage.KNOWLEDGE_EXTRACT`` with the ``cards/``
        directory as its artifact.
    """
    shortlist = _read_prior_artifact(run_dir, "shortlist.jsonl") or ""
    # Inject web context from Stage 4 if available
    web_context = _read_prior_artifact(run_dir, "web_context.md") or ""
    if web_context:
        # Cap injected web context at 10k chars to bound prompt size.
        shortlist = shortlist + "\n\n--- Web Search Context ---\n" + web_context[:10_000]
    cards_dir = stage_dir / "cards"
    cards_dir.mkdir(parents=True, exist_ok=True)

    # --- LLM extraction (preferred) ---
    cards: list[dict[str, Any]] = []
    if llm is not None:
        _pm = prompts or PromptManager()
        _overlay = _get_evolution_overlay(run_dir, "knowledge_extract")
        sp = _pm.for_stage("knowledge_extract", evolution_overlay=_overlay, shortlist=shortlist)
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        payload = _safe_json_loads(resp.content, {})
        if isinstance(payload, dict) and isinstance(payload.get("cards"), list):
            # Keep only well-formed (dict) card entries.
            cards = [item for item in payload["cards"] if isinstance(item, dict)]

    # --- Template fallback: build cards from the first 6 shortlist rows ---
    if not cards:
        rows = _parse_jsonl_rows(shortlist)
        for idx, paper in enumerate(rows[:6]):
            title = str(paper.get("title", f"Paper {idx + 1}"))
            cards.append(
                {
                    "card_id": f"card-{idx + 1}",
                    "title": title,
                    "problem": f"How to improve {config.research.topic}",
                    "method": "Template method summary",
                    "data": "Template dataset",
                    "metrics": "Template metric",
                    "findings": "Template key finding",
                    "limitations": "Template limitation",
                    "citation": str(paper.get("url", "")),
                    "cite_key": str(paper.get("cite_key", "")),
                }
            )

    # --- Render each card as a markdown file with fixed section headings ---
    for idx, card in enumerate(cards):
        # _safe_filename sanitizes the card id for use as a file name.
        card_id = _safe_filename(str(card.get("card_id", f"card-{idx + 1}")))
        parts = [f"# {card.get('title', card_id)}", ""]
        for key in (
            "cite_key",
            "problem",
            "method",
            "data",
            "metrics",
            "findings",
            "limitations",
            "citation",
        ):
            parts.append(f"## {key.title()}")
            parts.append(str(card.get(key, "")))
            parts.append("")
        (cards_dir / f"{card_id}.md").write_text("\n".join(parts), encoding="utf-8")
    return StageResult(
        stage=Stage.KNOWLEDGE_EXTRACT,
        status=StageStatus.DONE,
        artifacts=("cards/",),
        evidence_refs=("stage-06/cards/",),
    )
================================================
FILE: researchclaw/pipeline/stage_impls/_paper_writing.py
================================================
"""Stages 16-17: Paper outline and paper draft generation."""
from __future__ import annotations
import json
import logging
import math
import re
from pathlib import Path
from typing import Any
import yaml
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain, _is_ml_domain
from researchclaw.pipeline._helpers import (
StageResult,
_build_context_preamble,
_chat_with_prompt,
_collect_experiment_results,
_default_paper_outline,
_extract_paper_title,
_generate_framework_diagram_prompt,
_generate_neurips_checklist,
_get_evolution_overlay,
_read_best_analysis,
_read_prior_artifact,
_safe_json_loads,
_topic_constraint_block,
_utcnow_iso,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
def _execute_paper_outline(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 16: generate the paper outline and write it to ``outline.md``.

    Builds a context preamble from the best analysis / prior decision /
    experiment data, folds in reviewer feedback when this is a multi-round
    iteration, and asks the LLM for an outline. Falls back to the template
    outline when no LLM is configured or the LLM returns an empty result.

    Args:
        stage_dir: Directory where this stage's artifacts are written.
        run_dir: Root directory of the current pipeline run.
        config: Pipeline configuration (used for the research topic).
        adapters: Adapter bundle (unused here; part of the stage interface).
        llm: Optional LLM client; ``None`` means use the template outline.
        prompts: Optional prompt manager; a fresh one is created if omitted.

    Returns:
        A ``StageResult`` marking the stage DONE with ``outline.md`` as artifact.
    """
    analysis = _read_best_analysis(run_dir)
    decision = _read_prior_artifact(run_dir, "decision.md") or ""
    preamble = _build_context_preamble(
        config,
        run_dir,
        include_analysis=True,
        include_decision=True,
        include_experiment_data=True,
    )
    # WS-5.2: Read iteration feedback if available (multi-round iteration)
    feedback = ""
    iter_ctx_path = run_dir / "iteration_context.json"
    if iter_ctx_path.exists():
        try:
            ctx = json.loads(iter_ctx_path.read_text(encoding="utf-8"))
            iteration = ctx.get("iteration", 1)
            prev_score = ctx.get("quality_score")
            reviews_excerpt = ctx.get("reviews_excerpt", "")
            # Only inject feedback from round 2 onward, and only when the
            # reviewers actually left something to address.
            if iteration > 1 and reviews_excerpt:
                feedback = (
                    f"\n\n## Iteration {iteration} Feedback\n"
                    f"Previous quality score: {prev_score}/10\n"
                    f"Reviewer feedback to address:\n{reviews_excerpt[:2000]}\n"
                    f"\nYou MUST address these reviewer concerns in this revision.\n"
                )
        except (json.JSONDecodeError, KeyError):
            pass  # best-effort: a corrupt context file must not abort the stage
    if llm is not None:
        _pm = prompts or PromptManager()
        # IMP-20: Pass academic style guide block for outline stage.
        # Fix: the original `except (KeyError, Exception)` tuple was
        # redundant — KeyError is a subclass of Exception.
        try:
            _asg = _pm.block("academic_style_guide")
        except Exception:  # noqa: BLE001 — missing prompt block is non-fatal
            _asg = ""
        _overlay = _get_evolution_overlay(run_dir, "paper_outline")
        sp = _pm.for_stage(
            "paper_outline",
            evolution_overlay=_overlay,
            preamble=preamble,
            topic_constraint=_pm.block("topic_constraint", topic=config.research.topic),
            feedback=feedback,
            analysis=analysis,
            decision=decision,
            academic_style_guide=_asg,
        )
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        outline = resp.content
        # Reasoning models may consume all tokens on CoT — retry with more
        if not outline.strip() and sp.max_tokens:
            logger.warning("Empty outline from LLM — retrying with 2x tokens")
            resp = _chat_with_prompt(
                llm,
                sp.system,
                sp.user,
                json_mode=sp.json_mode,
                max_tokens=sp.max_tokens * 2,
            )
            outline = resp.content
        if not outline.strip():
            logger.warning("LLM returned empty outline — using default")
            outline = _default_paper_outline(config.research.topic)
    else:
        outline = _default_paper_outline(config.research.topic)
    (stage_dir / "outline.md").write_text(outline, encoding="utf-8")
    return StageResult(
        stage=Stage.PAPER_OUTLINE,
        status=StageStatus.DONE,
        artifacts=("outline.md",),
        evidence_refs=("stage-16/outline.md",),
    )
def _collect_raw_experiment_metrics(run_dir: Path) -> tuple[str, bool]:
"""Collect raw experiment metric lines from stdout for paper writing.
Returns a tuple of (formatted block, has_parsed_metrics).
``has_parsed_metrics`` is True when at least one run had a non-empty
``metrics`` dict in its JSON payload — a reliable signal of real data.
"""
metric_lines: list[str] = []
run_count = 0
has_parsed_metrics = False
for stage_subdir in sorted(run_dir.glob("stage-*/runs")):
for run_file in sorted(stage_subdir.glob("*.json")):
if run_file.name == "results.json":
continue
try:
payload = json.loads(run_file.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
continue
if not isinstance(payload, dict):
continue
# R10: Skip simulated data — only collect real experiment results
if payload.get("status") == "simulated":
continue
run_count += 1
# Extract from parsed metrics (check both 'metrics' and 'key_metrics')
metrics = payload.get("metrics", {}) or payload.get("key_metrics", {})
if isinstance(metrics, dict) and metrics:
has_parsed_metrics = True
for k, v in metrics.items():
metric_lines.append(f" {k}: {v}")
# Also extract from stdout for full detail
# BUG-23: Filter out infrastructure lines that are NOT experiment results
_INFRA_KEYS = {
"SEED_COUNT", "TIME_ESTIMATE", "TRAINING_STEPS",
"REGISTERED_CONDITIONS", "METRIC_DEF", "GPU_MEMORY",
"BATCH_SIZE", "NUM_WORKERS", "TOTAL_PARAMS",
"time_budget_sec", "max_epochs", "num_seeds",
}
stdout = payload.get("stdout", "")
if stdout:
for line in stdout.splitlines():
line = line.strip()
if ":" in line:
parts = line.rsplit(":", 1)
try:
float(parts[1].strip())
key_part = parts[0].strip().split("/")[-1] # last segment
if key_part in _INFRA_KEYS:
continue # skip infrastructure lines
metric_lines.append(f" {line}")
except (ValueError, TypeError, IndexError):
pass
# R19-4 + R23-1: Collect metrics from refinement_log.json (Stage 13).
# If refinement has richer data than Stage 12 runs/, REPLACE Stage 12 data
# to avoid confusing the paper writer with conflicting sources.
_refine_lines: list[str] = []
_refine_run_count = 0
# Scan ALL refinement logs across versions, pick by quality (primary
# metric) then richness (metric count). BUG-207: Previous logic picked
# the sandbox entry with the most metric keys regardless of whether it
# represented a regression (e.g. sandbox_after_fix with 1.29% accuracy
# winning over sandbox with 78.93% because it had 6 more keys).
_best_refine_metrics: dict[str, Any] = {}
_best_refine_stdout = ""
_best_refine_primary: float | None = None
for _rl_path in sorted(run_dir.glob("stage-13*/refinement_log.json")):
try:
_rlog = json.loads(_rl_path.read_text(encoding="utf-8"))
for _it in _rlog.get("iterations", []):
for _sbx_key in ("sandbox", "sandbox_after_fix"):
_sbx = _it.get(_sbx_key, {})
if not isinstance(_sbx, dict):
continue
_sbx_metrics = _sbx.get("metrics", {})
if not isinstance(_sbx_metrics, dict) or not _sbx_metrics:
continue
# Extract primary metric value for quality comparison
_sbx_primary: float | None = None
for _pm_key in ("primary_metric", "best_metric"):
if _pm_key in _sbx_metrics:
try:
_sbx_primary = float(_sbx_metrics[_pm_key])
except (ValueError, TypeError):
pass
break
# Prefer higher primary metric; fall back to count
_dominated = False
if _best_refine_primary is not None and _sbx_primary is not None:
if _sbx_primary > _best_refine_primary:
_dominated = True # new is better
elif _sbx_primary < _best_refine_primary * 0.5:
continue # skip: regression (>50% worse)
# Accept if quality-dominant or richer-with-no-regression
if _dominated or len(_sbx_metrics) > len(_best_refine_metrics):
_best_refine_metrics = _sbx_metrics
_best_refine_stdout = _sbx.get("stdout", "")
_best_refine_primary = _sbx_primary
except (json.JSONDecodeError, OSError):
pass
if _best_refine_metrics and len(_best_refine_metrics) > len(metric_lines) // 2:
# Refinement has richer data — REPLACE Stage 12 data to avoid conflicts
metric_lines = []
run_count = 1
for k, v in _best_refine_metrics.items():
metric_lines.append(f" {k}: {v}")
# Also extract PAIRED and metric lines from stdout
if _best_refine_stdout:
for _line in _best_refine_stdout.splitlines():
_line = _line.strip()
if _line.startswith("PAIRED:"):
metric_lines.append(f" {_line}")
elif ":" in _line:
parts = _line.rsplit(":", 1)
try:
float(parts[1].strip())
metric_lines.append(f" {_line}")
except (ValueError, TypeError, IndexError):
pass
elif _best_refine_metrics:
# Refinement has some data but not richer — append to existing
run_count += 1
for k, v in _best_refine_metrics.items():
metric_lines.append(f" {k}: {v}")
if _best_refine_stdout:
for _line in _best_refine_stdout.splitlines():
_line = _line.strip()
if _line.startswith("PAIRED:"):
metric_lines.append(f" {_line}")
if not metric_lines:
return "", has_parsed_metrics
# Deduplicate while preserving order
seen: set[str] = set()
unique: list[str] = []
for line in metric_lines:
if line not in seen:
seen.add(line)
unique.append(line)
# BUG-29: Reformat raw metric lines into human-readable condition summaries
# to prevent LLM from pasting raw path-style lines into the paper
_grouped: dict[str, list[str]] = {}
_ungrouped: list[str] = []
for line in unique[:200]:
stripped = line.strip()
# Match pattern: condition/env/step/metric: value
parts = stripped.split("/")
if len(parts) >= 3 and ":" in parts[-1]:
cond = parts[0]
detail = "/".join(parts[1:])
_grouped.setdefault(cond, []).append(f" - {detail}")
else:
_ungrouped.append(stripped)
formatted_lines: list[str] = []
if _grouped:
for cond, details in sorted(_grouped.items()):
formatted_lines.append(f"## Condition: {cond}")
formatted_lines.extend(details[:30])
if _ungrouped:
formatted_lines.extend(_ungrouped)
return (
f"\n\nACTUAL EXPERIMENT DATA (from {run_count} run(s) — use ONLY these numbers):\n"
"```\n"
+ "\n".join(formatted_lines[:200])
+ "\n```\n"
"CRITICAL: Every number in the Results table MUST come from the data above. "
"Do NOT round excessively, do NOT invent numbers, do NOT change values. "
f"The experiment ran {run_count} time(s) — state this accurately in the methodology.\n"
"NEVER paste raw metric paths (like 'condition/env/step/metric: value') "
"into the paper. Always convert to formatted LaTeX tables or inline prose.\n"
), has_parsed_metrics
def _write_paper_sections(
    *,
    llm: LLMClient,
    pm: PromptManager,
    run_dir: Path | None = None,
    preamble: str,
    topic_constraint: str,
    exp_metrics_instruction: str,
    citation_instruction: str,
    outline: str,
    model_name: str = "",
) -> str:
    """Write a conference-grade paper in 3 sequential LLM calls.

    Call 1: Title + Abstract + Introduction + Related Work
    Call 2: Method + Experiments (with full experiment data)
    Call 3: Results + Discussion + Limitations + Conclusion

    Each call receives prior sections for coherence. A failed call is
    retried once (T3.5) and replaced by an explicit placeholder section if
    it still fails, so the pipeline can continue.

    Returns:
        The combined markdown draft (no References section).
    """

    def _safe_block(name: str) -> str:
        # Prompt blocks are optional — a missing one must not abort writing.
        # Fix: deduplicates six copies of an identical try/except whose
        # `except (KeyError, Exception)` tuple was redundant (KeyError is a
        # subclass of Exception).
        try:
            return pm.block(name)
        except Exception:  # noqa: BLE001
            return ""

    # Render writing_structure block for injection
    _writing_structure = _safe_block("writing_structure")
    _overlay = _get_evolution_overlay(run_dir, "paper_draft")
    system = pm.for_stage(
        "paper_draft",
        evolution_overlay=_overlay,
        preamble=preamble,
        topic_constraint=topic_constraint,
        exp_metrics_instruction=exp_metrics_instruction,
        citation_instruction=citation_instruction,
        writing_structure=_writing_structure,
        outline=outline,
    ).system
    sections: list[str] = []
    # R4-3: Title guidelines and abstract structure
    title_guidelines = _safe_block("title_guidelines")
    abstract_structure = _safe_block("abstract_structure")
    # IMP-20/25/31/24: Academic style, narrative, anti-hedging, anti-repetition
    academic_style_guide = _safe_block("academic_style_guide")
    narrative_writing_rules = _safe_block("narrative_writing_rules")
    anti_hedging_rules = _safe_block("anti_hedging_rules")
    anti_repetition_rules = _safe_block("anti_repetition_rules")
    # --- Call 1: Title + Abstract + Introduction + Related Work ---
    call1_user = (
        f"{preamble}\n\n"
        f"{topic_constraint}"
        f"{citation_instruction}\n\n"
        f"{title_guidelines}\n\n"
        f"{academic_style_guide}\n"
        f"{narrative_writing_rules}\n"
        f"{anti_hedging_rules}\n"
        f"{anti_repetition_rules}\n\n"
        "Write the following sections of a NeurIPS/ICML-quality paper in markdown. "
        "Follow the LENGTH REQUIREMENTS strictly:\n\n"
        "1. **Title** (HARD RULE: MUST be 14 words or fewer. Create a catchy method name "
        "first, then build the title: 'MethodName: Subtitle'. If your title exceeds 14 words, "
        "it will be automatically rejected. NEVER use 'Untitled Paper'.)\n"
        f"2. **Abstract** (150-220 words — HARD LIMIT. Do NOT exceed 220 words. "
        f"Do NOT include raw metric paths or 16-digit decimals.){abstract_structure}\n"
        "3. **Introduction** (800-1000 words): real-world motivation, problem statement, "
        "research gap analysis with citations, method overview, 3-4 contributions as bullet points, "
        "paper organization paragraph. MUST cite 8-12 references.\n"
        "4. **Related Work** (600-800 words): organized into 3-4 thematic subsections, each discussing "
        "4-5 papers with proper citations. Compare approaches, identify limitations, position this work.\n\n"
        f"Outline:\n{outline}\n\n"
        "Output markdown with ## headers. Do NOT include a References section.\n"
        "IMPORTANT: Start DIRECTLY with '## Title'. Do NOT include any preamble, "
        "data verification, condition listing, or metric enumeration before the title. "
        "The paper should read like a published manuscript, not a data report."
    )
    # R14-1: Higher token limit for reasoning models
    _paper_max_tokens = 12000
    if any(model_name.startswith(p) for p in ("gpt-5", "o3", "o4")):
        _paper_max_tokens = 24000
    # T3.5: Retry once on failure, use placeholder if still fails
    try:
        resp1 = _chat_with_prompt(llm, system, call1_user, max_tokens=_paper_max_tokens, retries=1)
        part1 = resp1.content.strip()
    except Exception:  # noqa: BLE001
        logger.error("Stage 17: Part 1 LLM call failed after retry — using placeholder")
        part1 = (
            "## Title\n[PLACEHOLDER — LLM call failed]\n\n"
            "## Abstract\n[This section could not be generated due to an LLM error. "
            "Please regenerate this stage.]\n\n"
            "## Introduction\n[PLACEHOLDER]\n\n"
            "## Related Work\n[PLACEHOLDER]"
        )
    sections.append(part1)
    logger.info("Stage 17: Part 1 (Title+Abstract+Intro+Related Work) — %d chars", len(part1))
    # --- Call 2: Method + Experiments ---
    call2_user = (
        f"{preamble}\n\n"
        f"{topic_constraint}"
        f"{exp_metrics_instruction}\n\n"
        f"{narrative_writing_rules}\n"
        f"{anti_hedging_rules}\n\n"
        # IMP-21: Citation instruction for Method + Experiments
        "CITATION REQUIREMENT: The Method section MUST cite at least 3-5 related "
        "technical papers (foundations your method builds on). The Experiments section "
        "MUST cite baseline method papers. Use [cite_key] syntax.\n"
        f"{citation_instruction}\n\n"
        "You are continuing a paper. The sections written so far are:\n\n"
        f"---\n{part1}\n---\n\n"
        "Now write the next sections, maintaining consistency with the above:\n\n"
        "5. **Method** (1000-1500 words): formal problem definition with mathematical notation "
        "($x$, $\\theta$, etc.), detailed algorithm description with equations, step-by-step procedure, "
        "complexity analysis, design rationale for key choices. Include algorithm pseudocode if applicable. "
        "Write as FLOWING PROSE — do NOT use bullet-point lists for method components.\n"
        "6. **Experiments** (800-1200 words): detailed experimental setup, datasets with statistics "
        "(size, splits, features), all baselines and their implementations, hyperparameter settings "
        "in a markdown table, evaluation metrics with mathematical definitions, hardware and runtime info.\n"
        "METHOD NAMES IN TABLES: Use SHORT abbreviations (4-8 chars) for method names "
        "in tables. Define abbreviation mappings in a footnote. "
        "NEVER put method names longer than 20 characters in table cells.\n\n"
        f"Outline:\n{outline}\n\n"
        "Output markdown with ## headers. Continue from where Part 1 ended."
    )
    try:
        resp2 = _chat_with_prompt(llm, system, call2_user, max_tokens=_paper_max_tokens, retries=1)
        part2 = resp2.content.strip()
    except Exception:  # noqa: BLE001
        logger.error("Stage 17: Part 2 LLM call failed after retry — using placeholder")
        part2 = (
            "## Method\n[PLACEHOLDER — LLM call failed. Please regenerate this stage.]\n\n"
            "## Experiments\n[PLACEHOLDER]"
        )
    sections.append(part2)
    logger.info("Stage 17: Part 2 (Method+Experiments) — %d chars", len(part2))
    # --- Call 3: Results + Discussion + Limitations + Conclusion ---
    call3_user = (
        f"{preamble}\n\n"
        f"{topic_constraint}"
        f"{exp_metrics_instruction}\n\n"
        f"{narrative_writing_rules}\n"
        f"{anti_hedging_rules}\n"
        f"{anti_repetition_rules}\n\n"
        # IMP-21: Citation instruction for Results + Discussion + Conclusion
        "CITATION REQUIREMENT: The Discussion section MUST cite at least 3-5 papers "
        "when comparing findings with prior work. The Conclusion may cite 1-2 "
        "foundational references.\n"
        f"{citation_instruction}\n\n"
        "You are completing a paper. The sections written so far are:\n\n"
        f"---\n{part1}\n\n{part2}\n---\n\n"
        "Now write the final sections, maintaining consistency:\n\n"
        "7. **Results** (600-800 words):\n"
        " - START with an AGGREGATED results table (Table 1): rows = methods, columns = metrics.\n"
        " Each cell = mean \u00b1 std across seeds. Bold the best value per column.\n"
        " EVERY table MUST have a descriptive caption that allows understanding without "
        " reading the main text. NEVER use just 'Table 1' as a caption.\n"
        " - Follow with a PER-REGIME table (Table 2) breaking down by easy/hard regimes.\n"
        " - Include a STATISTICAL COMPARISON table (Table 3): paired t-tests between key methods.\n"
        " - NEVER dump raw per-seed numbers in the main text. Aggregate first, then discuss.\n"
        " - MUST include at least 2 figures using markdown image syntax: \n"
        " One figure MUST be a performance comparison chart. Figures MUST be referenced "
        " in text: 'As shown in Figure 1, ...'\n"
        "8. **Discussion** (400-600 words): interpretation of key findings, unexpected results, "
        "comparison with prior work (CITE 3-5 papers here!), practical implications.\n"
        "9. **Limitations** (200-300 words): honest assessment of scope, dataset, methodology. "
        "ALL caveats consolidated HERE — nowhere else in the paper.\n"
        "10. **Conclusion** (100-200 words MAXIMUM — this is a HARD LIMIT): "
        "Summarize contributions in 2-3 sentences. State main finding in 1 sentence. "
        "Suggest 2-3 concrete future directions in 1-2 sentences. "
        "Do NOT repeat any specific numbers from Results. Do NOT restate the abstract. "
        "A good conclusion is SHORT and forward-looking.\n\n"
        "CRITICAL FORMATTING RULES FOR ALL SECTIONS:\n"
        "- Write as FLOWING PROSE paragraphs, NOT bullet-point lists\n"
        "- NEVER dump raw metric paths like 'config/method_name/seed_3/primary_metric'\n"
        "- All numbers must be rounded to 4 decimal places maximum\n"
        "- Every table MUST have a descriptive caption (not just 'Table 1')\n"
        "- Use \\begin{algorithm} or pseudocode notation, NOT \\begin{verbatim}\n\n"
        "Output markdown with ## headers. Do NOT include a References section."
    )
    try:
        resp3 = _chat_with_prompt(llm, system, call3_user, max_tokens=_paper_max_tokens, retries=1)
        part3 = resp3.content.strip()
    except Exception:  # noqa: BLE001
        logger.error("Stage 17: Part 3 LLM call failed after retry — using placeholder")
        part3 = (
            "## Results\n[PLACEHOLDER — LLM call failed. Please regenerate this stage.]\n\n"
            "## Discussion\n[PLACEHOLDER]\n\n"
            "## Limitations\n[PLACEHOLDER]\n\n"
            "## Conclusion\n[PLACEHOLDER]"
        )
    sections.append(part3)
    logger.info("Stage 17: Part 3 (Results+Discussion+Limitations+Conclusion) — %d chars", len(part3))
    # Combine all sections
    draft = "\n\n".join(sections)
    # R32: Strip data verification preamble that LLMs sometimes emit before
    # the actual paper. The preamble typically starts with "## Tested Conditions"
    # or similar headings and ends before "## Title".
    # Fix: use the module-level `re` instead of a redundant local
    # `import re as _re_strip` (re is already imported at module scope).
    _title_match = re.search(r"^## Title\b", draft, re.MULTILINE)
    if _title_match and _title_match.start() > 200:
        _stripped = draft[_title_match.start():]
        logger.info(
            "R32: Stripped %d-char preamble before '## Title'",
            _title_match.start(),
        )
        draft = _stripped
    total_words = len(draft.split())
    logger.info("Stage 17: Full draft — %d chars, ~%d words", len(draft), total_words)
    return draft
# ---------------------------------------------------------------------------
# Draft quality validation (section balance + bullet-point density)
# ---------------------------------------------------------------------------
# Sections where bullets/numbered lists are acceptable. Headings are matched
# lowercased; _validate_draft_quality() applies a laxer 50% bullet-density
# threshold to these sections instead of the default 25%.
_BULLET_LENIENT_SECTIONS = frozenset({
    "introduction", "limitations", "limitation",
    "limitations and future work", "abstract",
})
# Main body sections used for balance ratio check: _validate_draft_quality()
# warns when the largest-to-smallest word-count ratio among these exceeds 3x.
_BALANCE_SECTIONS = frozenset({
    "introduction", "related work", "method", "experiments", "results",
    "discussion",
})
def _validate_draft_quality(
    draft: str,
    stage_dir: Path | None = None,
) -> dict[str, Any]:
    """Validate a paper draft for section balance and prose quality.

    Checks:
      1. Per-section word count vs ``SECTION_WORD_TARGETS``.
      2. Bullet-point / numbered-list density per section.
      3. Largest-to-smallest main-section word-count ratio.
      4. Citation count/recency, Abstract/Conclusion length, raw metric
         paths, weasel words, duplicate words, AI boilerplate, Related Work
         depth, and statistical rigor in results sections.

    Fix: the citation-recency warning previously ended with a literal
    ``>=30%%`` — ``%%`` is only an escape in %-formatting, not in f-strings.

    Returns a dict with ``section_analysis``, ``overall_warnings``, and
    ``revision_directives``. Optionally writes ``draft_quality.json`` to
    *stage_dir*.
    """
    from researchclaw.prompts import SECTION_WORD_TARGETS, _SECTION_TARGET_ALIASES
    _heading_re = re.compile(r"^(#{1,4})\s+(.+)$", re.MULTILINE)
    matches = list(_heading_re.finditer(draft))
    # Split the draft into (heading, level, body) records, one per heading.
    sections_data: list[dict[str, Any]] = []
    for i, m in enumerate(matches):
        level = len(m.group(1))
        heading = m.group(2).strip()
        start = m.end()
        end = matches[i + 1].start() if i + 1 < len(matches) else len(draft)
        body = draft[start:end].strip()
        sections_data.append({
            "heading": heading,
            "heading_lower": heading.strip().lower(),
            "level": level,
            "body": body,
        })
    section_analysis: list[dict[str, Any]] = []
    overall_warnings: list[str] = []
    revision_directives: list[str] = []
    main_section_words: dict[str, int] = {}
    _bullet_re = re.compile(r"^\s*[-*]\s+", re.MULTILINE)
    _numbered_re = re.compile(r"^\s*\d+\.\s+", re.MULTILINE)
    # BUG-24: Accumulate subsection (H3+) word counts into parent H2 sections
    _subsection_words: dict[str, int] = {}
    _current_parent = ""
    for sec in sections_data:
        if sec["level"] <= 2:
            _current_parent = sec["heading_lower"]
            _subsection_words.setdefault(_current_parent, 0)
        else:
            # Add subsection words to parent
            _subsection_words[_current_parent] = (
                _subsection_words.get(_current_parent, 0) + len(sec["body"].split())
            )
    for sec in sections_data:
        if sec["level"] > 2:
            continue
        heading_lower: str = sec["heading_lower"]
        body: str = sec["body"]
        # BUG-24: Include subsection words in the parent's word count
        word_count = len(body.split()) + _subsection_words.get(heading_lower, 0)
        canon = heading_lower
        if canon not in SECTION_WORD_TARGETS:
            canon = _SECTION_TARGET_ALIASES.get(heading_lower, "")
        entry: dict[str, Any] = {
            "heading": sec["heading"],
            "word_count": word_count,
            "canonical": canon,
        }
        if canon and canon in SECTION_WORD_TARGETS:
            lo, hi = SECTION_WORD_TARGETS[canon]
            entry["target"] = [lo, hi]
            if word_count < int(lo * 0.7):
                overall_warnings.append(
                    f"{sec['heading']} is severely under target "
                    f"({word_count} words, target {lo}-{hi})"
                )
                revision_directives.append(
                    f"EXPAND {sec['heading']} from {word_count} to {lo}+ words. "
                    f"Add substantive content \u2014 do NOT pad with filler."
                )
                entry["status"] = "severely_short"
            elif word_count < lo:
                overall_warnings.append(
                    f"{sec['heading']} is under target "
                    f"({word_count} words, target {lo}-{hi})"
                )
                revision_directives.append(
                    f"Expand {sec['heading']} from {word_count} to {lo}+ words."
                )
                entry["status"] = "short"
            elif word_count > int(hi * 1.3):
                overall_warnings.append(
                    f"{sec['heading']} exceeds target "
                    f"({word_count} words, target {lo}-{hi})"
                )
                revision_directives.append(
                    f"Compress {sec['heading']} from {word_count} to {hi} words or fewer."
                )
                entry["status"] = "long"
            else:
                entry["status"] = "ok"
        if body:
            # Bullet/numbered density; lenient sections tolerate up to 50%.
            total_lines = len([ln for ln in body.splitlines() if ln.strip()])
            bullet_lines = len(_bullet_re.findall(body)) + len(_numbered_re.findall(body))
            density = bullet_lines / total_lines if total_lines > 0 else 0.0
            entry["bullet_density"] = round(density, 2)
            threshold = 0.50 if heading_lower in _BULLET_LENIENT_SECTIONS else 0.25
            if density > threshold and total_lines >= 4:
                overall_warnings.append(
                    f"{sec['heading']} has {bullet_lines}/{total_lines} "
                    f"bullet/numbered lines ({density:.0%} density, "
                    f"threshold {threshold:.0%})"
                )
                revision_directives.append(
                    f"REWRITE {sec['heading']} as flowing academic prose. "
                    f"Convert bullet points to narrative paragraphs."
                )
                entry["bullet_status"] = "high"
            else:
                entry["bullet_status"] = "ok"
        canon_balance = canon or heading_lower
        if canon_balance in _BALANCE_SECTIONS:
            main_section_words[canon_balance] = word_count
        section_analysis.append(entry)
    if len(main_section_words) >= 2:
        wc_values = list(main_section_words.values())
        max_wc = max(wc_values)
        min_wc = min(wc_values)
        if min_wc > 0 and max_wc / min_wc > 3.0:
            largest = max(main_section_words, key=main_section_words.get)  # type: ignore[arg-type]
            smallest = min(main_section_words, key=main_section_words.get)  # type: ignore[arg-type]
            overall_warnings.append(
                f"Section imbalance: {largest} ({max_wc} words) vs "
                f"{smallest} ({min_wc} words) \u2014 ratio {max_wc / min_wc:.1f}x"
            )
            revision_directives.append(
                f"Rebalance sections: expand {smallest} and/or compress {largest} "
                f"to achieve more even section lengths."
            )
    # --- C-4/C-5: Citation count and recency checks ---
    _cite_pattern = re.compile(r"\[([a-zA-Z][a-zA-Z0-9_-]*\d{4}[a-zA-Z0-9]*)\]")
    cited_keys = set(_cite_pattern.findall(draft))
    if cited_keys:
        n_citations = len(cited_keys)
        if n_citations < 15:
            overall_warnings.append(
                f"Only {n_citations} unique citations found (target: >=15 for a full paper)"
            )
            revision_directives.append(
                f"Add more references — a top-venue paper typically cites 25-40 works. "
                f"Currently only {n_citations} unique citations."
            )
        # Check recency: count citations with year >= current_year - 2
        _year_pat = re.compile(r"(\d{4})")
        import datetime as _dt_cit
        _cur_year = _dt_cit.datetime.now().year
        recent_count = sum(
            1 for k in cited_keys
            for m in [_year_pat.search(k)]
            if m and int(m.group(1)) >= _cur_year - 2
        )
        recency_ratio = recent_count / n_citations if n_citations > 0 else 0.0
        if recency_ratio < 0.3 and n_citations >= 10:
            # Fix: was ">=30%%" — f-strings do not treat % specially, so the
            # doubled percent appeared verbatim in the warning text.
            overall_warnings.append(
                f"Citation recency low: only {recent_count}/{n_citations} "
                f"({recency_ratio:.0%}) from last 3 years (target: >=30%)"
            )
    # --- Abstract and Conclusion length enforcement ---
    for sec in sections_data:
        hl = sec["heading_lower"]
        body_text: str = sec["body"]
        wc = len(body_text.split())
        if hl == "abstract" and wc > 250:
            overall_warnings.append(
                f"Abstract is too long: {wc} words (target: 150-220 words)"
            )
            revision_directives.append(
                f"COMPRESS the Abstract from {wc} to 150-220 words. "
                f"Remove raw metric values, redundant context, and self-references."
            )
        if hl in ("conclusion", "conclusions", "conclusion and future work"):
            if wc > 300:
                overall_warnings.append(
                    f"Conclusion is too long: {wc} words (target: 100-200 words)"
                )
                revision_directives.append(
                    f"COMPRESS the Conclusion from {wc} to 100-200 words. "
                    f"Do NOT repeat specific metric values from Results. "
                    f"Summarize findings in 2-3 sentences, then 2-3 future directions."
                )
    # --- Raw metric path detection (log dumps in prose) ---
    _raw_path_re = re.compile(
        r"\\texttt\{[a-zA-Z0-9_/.-]+(?:/[a-zA-Z0-9_/.-]+){2,}",
    )
    raw_path_count = len(_raw_path_re.findall(draft))
    if raw_path_count > 3:
        overall_warnings.append(
            f"Raw metric paths in prose: {raw_path_count} instances of "
            f"\\texttt{{config/path/metric}} style dumps"
        )
        revision_directives.append(
            "REMOVE raw experiment log paths from prose. Replace "
            "\\texttt{config/metric/path} with human-readable metric names "
            "and summarize values in tables, not inline text."
        )
    # --- Writing quality lint ---
    _weasel_words = re.compile(
        r"\b(various|many|several|quite|fairly|really|very|rather|"
        r"somewhat|relatively|arguably|interestingly|importantly|"
        r"it is well known that|it is obvious that|clearly)\b",
        re.IGNORECASE,
    )
    _duplicate_words = re.compile(r"\b(\w+)\s+\1\b", re.IGNORECASE)
    weasel_count = len(_weasel_words.findall(draft))
    dup_matches = _duplicate_words.findall(draft)
    # "that that" / "had had" are grammatical English — do not flag them.
    dup_count = len([d for d in dup_matches if d.lower() not in ("that", "had")])
    if weasel_count > 20:
        overall_warnings.append(
            f"High weasel-word count: {weasel_count} instances "
            f"(consider replacing vague words with precise language)"
        )
        revision_directives.append(
            "Replace vague hedging words (various, several, quite, fairly, "
            "rather, somewhat) with precise quantities or remove them."
        )
    if dup_count > 0:
        overall_warnings.append(
            f"Duplicate adjacent words found: {dup_count} instance(s) "
            f"(e.g., 'the the', 'is is')"
        )
        revision_directives.append(
            "Fix duplicate adjacent words (likely typos)."
        )
    # --- AI-slop / boilerplate detection ---
    _BOILERPLATE_PHRASES = [
        "delves into", "delve into", "it is worth noting",
        "it should be noted", "it is important to note",
        "leverage the power of", "leverages the power of",
        "in this paper, we propose", "in this work, we propose",
        "to the best of our knowledge",
        "in the realm of", "in the landscape of",
        "plays a crucial role", "plays a pivotal role",
        "groundbreaking", "cutting-edge", "state-of-the-art",
        "game-changing", "paradigm shift",
        "a myriad of", "a plethora of",
        "aims to bridge the gap", "bridge the gap",
        "shed light on", "sheds light on",
        "pave the way", "paves the way",
        "the advent of", "with the advent of",
        "in recent years", "in recent times",
        "has gained significant attention",
        "has attracted considerable interest",
        "has emerged as a promising",
        "a comprehensive overview",
        "a holistic approach", "holistic understanding",
        "showcasing the efficacy", "demonstrate the efficacy",
        "multifaceted", "underscores the importance",
        "navigate the complexities",
        "harness the potential", "harnessing the power",
        "it is imperative to", "it is crucial to",
        "a nuanced understanding", "nuanced approach",
        "robust and scalable", "seamlessly integrates",
        "the intricacies of", "intricate interplay",
        "facilitate a deeper understanding",
        "a testament to",
    ]
    draft_lower = draft.lower()
    boilerplate_hits: list[str] = []
    for phrase in _BOILERPLATE_PHRASES:
        count = draft_lower.count(phrase)
        if count > 0:
            boilerplate_hits.extend([phrase] * count)
    if len(boilerplate_hits) > 5:
        unique_phrases = sorted(set(boilerplate_hits))[:5]
        overall_warnings.append(
            f"AI boilerplate detected: {len(boilerplate_hits)} instances "
            f"of generic LLM phrases (e.g., {', '.join(repr(p) for p in unique_phrases[:3])})"
        )
        revision_directives.append(
            "REWRITE sentences containing AI-generated boilerplate phrases. "
            "Replace generic language (e.g., 'delves into', 'it is worth noting', "
            "'leverages the power of', 'plays a crucial role', 'paves the way') "
            "with precise, specific academic language."
        )
    # --- Related work depth check ---
    _rw_headings = {"related work", "related works", "background", "literature review"}
    rw_body = ""
    for sec in sections_data:
        if sec["heading_lower"] in _rw_headings and sec["level"] <= 2:
            rw_body = sec["body"]
            break
    if rw_body and len(rw_body.split()) > 50:
        _comparative_pats = re.compile(
            r"\b(unlike|in contrast|whereas|while .+ focus|"
            r"however|differ(?:s|ent)|our (?:method|approach) .+ instead|"
            r"we (?:instead|differ)|compared to|as opposed to|"
            r"goes beyond|extends|improves upon|addresses the limitation)\b",
            re.IGNORECASE,
        )
        sentences = [s.strip() for s in re.split(r"[.!?]+", rw_body) if s.strip()]
        comparative_sents = sum(1 for s in sentences if _comparative_pats.search(s))
        ratio = comparative_sents / len(sentences) if sentences else 0.0
        if ratio < 0.15 and len(sentences) >= 5:
            overall_warnings.append(
                f"Related Work is purely descriptive: only {comparative_sents}/{len(sentences)} "
                f"sentences ({ratio:.0%}) contain comparative language (target: >=15%)"
            )
            revision_directives.append(
                "REWRITE Related Work to critically compare with prior methods. "
                "Use phrases like 'unlike X, our approach...', 'in contrast to...', "
                "'while X focuses on... we address...' for at least 20% of sentences."
            )
    # --- Statistical rigor check (result sections) ---
    _results_headings = {"results", "experiments", "experimental results", "evaluation"}
    results_body = ""
    for sec in sections_data:
        if sec["heading_lower"] in _results_headings and sec["level"] <= 2:
            results_body += sec["body"] + "\n"
    if results_body and len(results_body.split()) > 100:
        has_std = bool(re.search(r"\u00b1|\\pm|\bstd\b|\\std\b|standard deviation", results_body, re.IGNORECASE))
        has_ci = bool(re.search(r"confidence interval|\bCI\b|95%|p-value|p\s*<", results_body, re.IGNORECASE))
        has_seeds = bool(re.search(r"(?:seed|run|trial)s?\s*[:=]\s*\d|averaged?\s+over\s+\d+\s+(?:seed|run|trial)", results_body, re.IGNORECASE))
        if not has_std and not has_ci and not has_seeds:
            overall_warnings.append(
                "No statistical measures found in results (no std, CI, p-values, or multi-seed reporting)"
            )
            revision_directives.append(
                "ADD error bars (\u00b1std), confidence intervals, or note the number of "
                "random seeds used. Single-run results without variance reporting "
                "are insufficient for top venues."
            )
    result: dict[str, Any] = {
        "section_analysis": section_analysis,
        "overall_warnings": overall_warnings,
        "revision_directives": revision_directives,
    }
    if stage_dir is not None:
        (stage_dir / "draft_quality.json").write_text(
            json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8"
        )
    if overall_warnings:
        logger.warning(
            "Draft quality: %d warning(s) \u2014 %s",
            len(overall_warnings),
            "; ".join(overall_warnings[:3]),
        )
    else:
        logger.info("Draft quality: all checks passed")
    return result
def _review_compiled_pdf(
    pdf_path: Path,
    llm: LLMClient,
    topic: str,
) -> dict[str, Any]:
    """Multi-dimensional LLM review of compiled paper (AI-Scientist style).

    Scores the paper on 7 academic review dimensions (1-10 each),
    identifies specific strengths/weaknesses, and provides an overall
    accept/reject recommendation with confidence.

    Parameters
    ----------
    pdf_path:
        Path to the compiled PDF. The review is actually performed on the
        sibling ``.tex`` file (same stem), not the PDF bytes — see note below.
    llm:
        Chat-capable LLM client used to generate the review.
    topic:
        Research topic string injected verbatim into the review prompt.

    Returns
    -------
    dict
        Dimensional scores, issues, decision, and a computed ``mean_score``
        (unweighted mean of the valid dimension scores). Returns an empty
        dict when the PDF or TeX file is missing, when the LLM response
        does not parse into a dict containing ``overall_score``, or when
        the LLM call raises.
    """
    if not pdf_path.exists():
        return {}
    # Use source-based review since not all models support vision
    tex_path = pdf_path.with_suffix(".tex")
    if not tex_path.exists():
        return {}
    # Truncate the LaTeX source to bound prompt size; 12k chars keeps the
    # preamble, abstract, and most of the body for a typical paper.
    tex_content = tex_path.read_text(encoding="utf-8")[:12000]
    review_prompt = (
        "You are a senior Area Chair at a top AI conference (NeurIPS/ICML/ICLR) "
        "reviewing a paper submission. Provide a rigorous, structured review.\n\n"
        f"PAPER TOPIC: {topic}\n\n"
        f"LaTeX source:\n```latex\n{tex_content}\n```\n\n"
        "REVIEW INSTRUCTIONS:\n"
        "Score each dimension 1-10 (1=unacceptable, 5=borderline, 8=strong accept, "
        "10=best paper candidate). Be critical but fair.\n\n"
        "DIMENSIONS:\n"
        "1. SOUNDNESS: Are claims well-supported? Is methodology correct? "
        "Are there logical gaps or unsupported claims?\n"
        "2. PRESENTATION: Is the writing clear, flowing, and professional? "
        "Are there grammar errors, bullet lists in prose sections, or "
        "boilerplate phrases? Is it free of AI-generated slop?\n"
        "3. CONTRIBUTION: Is the contribution significant? Does it advance "
        "the field beyond incremental improvement?\n"
        "4. ORIGINALITY: Is the approach novel? Does it differentiate clearly "
        "from prior work?\n"
        "5. CLARITY: Are the method and results easy to understand? Are figures "
        "and tables well-designed with descriptive captions?\n"
        "6. SIGNIFICANCE: Would the community benefit from this work? Does it "
        "open new research directions?\n"
        "7. REPRODUCIBILITY: Are experimental details sufficient to reproduce "
        "results? Are hyperparameters, datasets, and metrics clearly stated?\n\n"
        "Also evaluate:\n"
        "- Are all figures referenced in the text?\n"
        "- Are tables properly formatted (booktabs style, no vertical rules)?\n"
        "- Does the related work critically compare, not just list papers?\n"
        "- Are statistical measures (std, CI, multiple seeds) reported?\n"
        "- Is there a clear limitations section?\n\n"
        "Return a JSON object:\n"
        "{\n"
        '  "soundness": N,\n'
        '  "presentation": N,\n'
        '  "contribution": N,\n'
        '  "originality": N,\n'
        '  "clarity": N,\n'
        '  "significance": N,\n'
        '  "reproducibility": N,\n'
        '  "overall_score": N,\n'
        '  "confidence": N,\n'
        '  "decision": "accept" or "reject",\n'
        '  "strengths": ["strength1", "strength2", ...],\n'
        '  "weaknesses": ["weakness1", "weakness2", ...],\n'
        '  "critical_issues": ["issue requiring revision", ...],\n'
        '  "minor_issues": ["formatting/typo issues", ...],\n'
        '  "summary": "2-3 sentence overall assessment"\n'
        "}\n"
    )
    try:
        resp = llm.chat(
            messages=[{"role": "user", "content": review_prompt}],
            system=(
                "You are a meticulous, critical academic reviewer. "
                "You have reviewed 100+ papers at top venues. "
                "Score honestly — most papers deserve 4-6, not 7-9. "
                "Flag any sign of AI-generated boilerplate."
            ),
        )
        # _safe_json_loads returns the fallback ({}) on malformed output,
        # so a non-JSON reply silently yields an empty review dict.
        review_data = _safe_json_loads(resp.content, {})
        if isinstance(review_data, dict) and "overall_score" in review_data:
            # Compute the unweighted mean over whichever of the seven
            # dimension scores the model actually returned.
            dim_scores = {
                k: review_data.get(k, 0)
                for k in (
                    "soundness", "presentation", "contribution",
                    "originality", "clarity", "significance",
                    "reproducibility",
                )
            }
            # Only positive numeric scores count; missing/zero/garbage
            # dimensions are excluded from the mean rather than dragging it down.
            valid = {k: v for k, v in dim_scores.items() if isinstance(v, (int, float)) and v > 0}
            if valid:
                review_data["mean_score"] = round(sum(valid.values()) / len(valid), 2)
            return review_data
    except Exception as exc:  # noqa: BLE001
        # Best-effort review: any failure (network, provider, parsing) is
        # logged at debug level and the caller receives an empty dict.
        logger.debug("PDF review LLM call failed: %s", exc)
    return {}
def _check_ablation_effectiveness(
exp_summary: dict[str, Any],
threshold: float = 0.02,
) -> list[str]:
"""P7: Check if ablation results are within *threshold* of baseline.
Returns a list of warning strings for ineffective ablations.
Threshold tightened from 5% to 2% (Improvement C) — ablations with
< 2% relative difference AND < 1pp absolute difference are flagged
as TRIVIAL.
"""
warnings: list[str] = []
cond_summaries = exp_summary.get("condition_summaries", {})
if not isinstance(cond_summaries, dict) or not cond_summaries:
return warnings
# Find baseline/control condition
baseline_name = None
baseline_mean = None
for name, data in cond_summaries.items():
if not isinstance(data, dict):
continue
name_lower = name.lower()
if any(tag in name_lower for tag in ("baseline", "control", "vanilla", "standard")):
metrics = data.get("metrics") or {}
if not isinstance(metrics, dict):
metrics = {}
# Use the first metric that has a _mean suffix or the first available
for mk, mv in metrics.items():
if mk.endswith("_mean"):
baseline_name = name
baseline_mean = float(mv)
break
if baseline_mean is None:
for mk, mv in metrics.items():
try:
baseline_name = name
baseline_mean = float(mv)
break
except (TypeError, ValueError):
continue
if baseline_name:
break
if baseline_name is None or baseline_mean is None:
return warnings
# Check each ablation condition
for name, data in cond_summaries.items():
if not isinstance(data, dict):
continue
name_lower = name.lower()
if name == baseline_name:
continue
if not any(tag in name_lower for tag in ("ablation", "no_", "without", "reduced")):
continue
metrics = data.get("metrics") or {}
if not isinstance(metrics, dict):
metrics = {}
for mk, mv in metrics.items():
if not mk.endswith("_mean"):
continue
try:
abl_val = float(mv)
except (TypeError, ValueError):
continue
if baseline_mean != 0:
rel_diff = abs(abl_val - baseline_mean) / abs(baseline_mean)
else:
rel_diff = abs(abl_val - baseline_mean)
abs_diff = abs(abl_val - baseline_mean)
# Improvement C: Tighter check — both relative < threshold
# AND absolute < 1pp → TRIVIAL
if rel_diff < threshold and abs_diff < 1.0:
warnings.append(
f"TRIVIAL: Ablation '{name}' {mk}={abl_val:.4f} is within "
f"{rel_diff:.1%} (abs {abs_diff:.4f}pp) of baseline "
f"'{baseline_name}' {mk}={baseline_mean:.4f} — "
f"ablation is ineffective"
)
elif rel_diff < threshold:
warnings.append(
f"Ablation '{name}' {mk}={abl_val:.4f} is within "
f"{rel_diff:.1%} of baseline '{baseline_name}' "
f"{mk}={baseline_mean:.4f} — ablation may be ineffective"
)
break # Only check the first _mean metric per condition
# Improvement C: Prepend CRITICAL summary if >50% trivial
trivial_count = sum(1 for w in warnings if w.startswith("TRIVIAL:"))
if trivial_count > 0 and len(warnings) > 0 and trivial_count / len(warnings) > 0.5:
warnings.insert(0, (
f"CRITICAL: {trivial_count}/{len(warnings)} ablations are trivially "
f"similar to baseline (<{threshold:.0%} relative, <1pp absolute). "
f"The ablation design is likely broken — components are not effectively removed."
))
return warnings
def _detect_result_contradictions(
exp_summary: dict[str, Any],
metric_direction: str = "maximize",
) -> list[str]:
"""P10: Detect contradictions in experiment results before paper writing.
Returns a list of advisory strings to inject into paper writing prompt.
"""
advisories: list[str] = []
cond_summaries = exp_summary.get("condition_summaries", {})
if not isinstance(cond_summaries, dict) or not cond_summaries:
return advisories
# Collect primary metric means per condition
means: dict[str, float] = {}
for name, data in cond_summaries.items():
if not isinstance(data, dict):
continue
metrics = data.get("metrics", {})
for mk, mv in metrics.items():
if mk.endswith("_mean"):
try:
means[name] = float(mv)
except (TypeError, ValueError):
pass
break
if len(means) < 2:
return advisories
# Check 1: All methods within noise margin (2% relative spread)
vals = list(means.values())
val_range = max(vals) - min(vals)
val_mean = sum(vals) / len(vals)
if val_mean != 0 and (val_range / abs(val_mean)) < 0.02:
advisories.append(
"NULL RESULT: All methods produce nearly identical primary metric values "
f"(range={val_range:.4f}, mean={val_mean:.4f}). Frame this as a null result — "
"the methods are statistically indistinguishable. Do NOT claim any method "
"is superior. Discuss possible explanations (task too easy/hard, metric "
"insensitive, insufficient differentiation in methods)."
)
# Check 2: Control/simple baseline outperforms proposed method
# BUG-P1: Respect metric_direction — "higher is better" vs "lower is better"
_maximize = metric_direction == "maximize"
baseline_val = None
baseline_name = None
proposed_val = None
proposed_name = None
for name, val in means.items():
name_lower = name.lower()
if any(tag in name_lower for tag in ("baseline", "control", "random", "vanilla")):
if baseline_val is None or (_maximize and val > baseline_val) or (not _maximize and val < baseline_val):
baseline_val = val
baseline_name = name
elif any(tag in name_lower for tag in ("proposed", "our", "novel", "method")):
if proposed_val is None or (_maximize and val > proposed_val) or (not _maximize and val < proposed_val):
proposed_val = val
proposed_name = name
if baseline_val is not None and proposed_val is not None:
_baseline_wins = (baseline_val > proposed_val) if _maximize else (baseline_val < proposed_val)
if _baseline_wins:
advisories.append(
f"NEGATIVE RESULT: Baseline '{baseline_name}' ({baseline_val:.4f}) "
f"outperforms proposed method '{proposed_name}' ({proposed_val:.4f}). "
"This is a NEGATIVE result. Do NOT claim the proposed method is superior. "
"Frame as 'An Empirical Study of...' or 'When X Falls Short'. "
"Discuss why the baseline won and what this implies for future work."
)
return advisories
def _execute_paper_draft(
stage_dir: Path,
run_dir: Path,
config: RCConfig,
adapters: AdapterBundle,
*,
llm: LLMClient | None = None,
prompts: PromptManager | None = None,
) -> StageResult:
outline = _read_prior_artifact(run_dir, "outline.md") or ""
preamble = _build_context_preamble(
config,
run_dir,
include_goal=True,
include_hypotheses=True,
include_analysis=True,
include_experiment_data=True, # WS-5.1: inject real experiment data
)
# BUG-222: Read PROMOTED BEST experiment_summary for the paper prompt.
# Previous code (R21-1) picked the "richest" experiment_summary across
# all stage-14* dirs. After REFINE regression, a later iteration with
# more conditions but worse quality could win, feeding the LLM regressed
# data. Now: prefer experiment_summary_best.json (written by
# _promote_best_stage14()), fall back to richest stage-14* for
# non-REFINE runs.
exp_summary_text = None
_best_path = run_dir / "experiment_summary_best.json"
if _best_path.is_file():
try:
_text = _best_path.read_text(encoding="utf-8")
_parsed = _safe_json_loads(_text, {})
if isinstance(_parsed, dict) and (
_parsed.get("condition_summaries") or _parsed.get("metrics_summary")
):
exp_summary_text = _text
logger.info("BUG-222: Using promoted experiment_summary_best.json")
except OSError:
pass
if exp_summary_text is None:
# Fallback: pick richest stage-14* (pre-BUG-222 behavior)
_best_metric_count = 0
for _s14_dir in sorted(run_dir.glob("stage-14*")):
_candidate = _s14_dir / "experiment_summary.json"
if _candidate.is_file():
_text = _candidate.read_text(encoding="utf-8")
_parsed = _safe_json_loads(_text, {})
if isinstance(_parsed, dict):
_mcount = _parsed.get("total_metric_keys", 0) or len(
_parsed.get("metrics_summary", {})
)
_paired_count = len(_parsed.get("paired_comparisons", []))
_cond_count = len(_parsed.get("condition_summaries", {}))
_score = _mcount + _paired_count * 10 + _cond_count * 5
if _score > _best_metric_count:
_best_metric_count = _score
exp_summary_text = _text
logger.info(
"R21-1 fallback: Selected %s (score=%d)",
_s14_dir.name, _score,
)
if exp_summary_text is None:
exp_summary_text = _read_prior_artifact(run_dir, "experiment_summary.json")
exp_metrics_instruction = ""
has_real_metrics = False
_verified_registry = None # Phase 1: anti-fabrication verified data registry
# BUG-108: Load refinement_log so VerifiedRegistry has per-iteration metrics
_refinement_log_for_vr: dict | None = None
_rl_candidates = sorted(run_dir.glob("stage-13*/refinement_log.json"), reverse=True)
_rl_path = _rl_candidates[0] if _rl_candidates else None
if _rl_path and _rl_path.is_file():
try:
_refinement_log_for_vr = json.loads(_rl_path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
pass
if exp_summary_text:
exp_summary = _safe_json_loads(exp_summary_text, {})
# Phase 1: Build VerifiedRegistry from experiment data
if isinstance(exp_summary, dict):
try:
from researchclaw.pipeline.verified_registry import VerifiedRegistry
# BUG-222: Use best_only=True to ensure paper tables reflect
# only the promoted best iteration, not regressed data
_verified_registry = VerifiedRegistry.from_run_dir(
run_dir,
metric_direction=config.experiment.metric_direction,
best_only=True,
)
logger.info(
"Stage 17: VerifiedRegistry — %d verified values, %d conditions",
len(_verified_registry.values),
len(_verified_registry.condition_names),
)
except Exception as _vr_exc:
logger.warning("Stage 17: Failed to build VerifiedRegistry: %s", _vr_exc)
if isinstance(exp_summary, dict) and exp_summary.get("metrics_summary"):
has_real_metrics = True
exp_metrics_instruction = (
"\n\nIMPORTANT: Use the ACTUAL experiment results provided in the context. "
"All numbers in the Results and Experiments sections MUST reference real data. "
"Do NOT write 'no quantitative results yet' or use placeholder numbers. "
"Cite specific metrics with their actual values.\n"
)
# Collect raw experiment stdout metrics as hard constraint for the paper
raw_metrics_block, _has_parsed_metrics = _collect_raw_experiment_metrics(run_dir)
if raw_metrics_block:
# BUG-23: Raw stdout alone is not sufficient — require either
# metrics_summary data, parsed metrics from run JSONs,
# OR at least 3 condition= patterns in raw block
_has_condition_pattern = len(re.findall(
r"condition[=:]", raw_metrics_block, re.IGNORECASE
)) >= 3
if has_real_metrics or _has_parsed_metrics or _has_condition_pattern:
has_real_metrics = True
exp_metrics_instruction += raw_metrics_block
# R18-1 + R19-6: Inject paired statistical comparisons AND condition summaries
if exp_summary_text:
exp_summary_parsed = _safe_json_loads(exp_summary_text, {})
if isinstance(exp_summary_parsed, dict):
# R19-6: Inject experiment scale header so LLM knows the data richness
_total_conds = exp_summary_parsed.get("total_conditions")
_total_mkeys = exp_summary_parsed.get("total_metric_keys")
if _total_conds or _total_mkeys:
scale_block = "\n\n## EXPERIMENT SCALE\n"
if _total_conds:
scale_block += f"- Total conditions tested: {_total_conds}\n"
if _total_mkeys:
scale_block += f"- Total metric keys collected: {_total_mkeys}\n"
scale_block += (
"- This is a MULTI-SEED experiment. Report mean +/- std across seeds.\n"
"- Do NOT describe results as 'single run' or 'preliminary'.\n"
)
exp_metrics_instruction += scale_block
# Improvement B: Inject seed insufficiency warnings
_seed_warns = exp_summary_parsed.get("seed_insufficiency_warnings", [])
if _seed_warns:
_sw_block = (
"\n\n## SEED INSUFFICIENCY WARNINGS\n"
"Some conditions were run with fewer than 3 seeds. "
"Results for these conditions MUST be footnoted as preliminary.\n"
"All tables MUST show mean ± std format. Single-run values "
"MUST be footnoted with '†single seed — interpret with caution'.\n"
)
for _sw in _seed_warns:
_sw_block += f"- {_sw}\n"
exp_metrics_instruction += _sw_block
# R19-6 + R33: Inject condition summaries with CIs
cond_summaries = exp_summary_parsed.get("condition_summaries", {})
if isinstance(cond_summaries, dict) and cond_summaries:
cond_block = "\n\n## PER-CONDITION SUMMARY (use in Results tables)\n"
for cname, cdata in sorted(cond_summaries.items()):
cond_block += f"\n### {cname}\n"
if not isinstance(cdata, dict):
continue
sr = cdata.get("success_rate")
if sr is not None:
try:
cond_block += f"- Success rate: {float(sr):.1%}\n"
except (ValueError, TypeError):
cond_block += f"- Success rate: {sr}\n"
ns = cdata.get("n_seeds") or cdata.get("n_seed_metrics")
if ns:
cond_block += f"- Seeds: {ns}\n"
ci_lo = cdata.get("ci95_low")
ci_hi = cdata.get("ci95_high")
if ci_lo is not None and ci_hi is not None:
try:
cond_block += f"- Bootstrap 95% CI: [{float(ci_lo):.4f}, {float(ci_hi):.4f}]\n"
except (ValueError, TypeError):
cond_block += f"- Bootstrap 95% CI: [{ci_lo}, {ci_hi}]\n"
cm = cdata.get("metrics") or {}
if isinstance(cm, dict) and cm:
for mk, mv in sorted(cm.items()):
if isinstance(mv, (int, float)):
cond_block += f"- {mk}: {mv:.4f}\n"
else:
cond_block += f"- {mk}: {mv}\n"
exp_metrics_instruction += cond_block
# R18-1: Inject paired statistical comparisons
paired = exp_summary_parsed.get("paired_comparisons", [])
if paired:
paired_block = "\n\n## PAIRED STATISTICAL COMPARISONS (use these in Results)\n"
paired_block += f"Total: {len(paired)} paired tests computed.\n"
for pc in paired:
if not isinstance(pc, dict):
continue
method = pc.get("method", "?")
baseline = pc.get("baseline", "?")
regime = pc.get("regime", "all")
md = pc.get("mean_diff", "?")
sd = pc.get("std_diff", "?")
ts = pc.get("t_stat", "?")
pv = pc.get("p_value", "?")
ci_lo = pc.get("ci95_low")
ci_hi = pc.get("ci95_high")
ci_str = ""
if ci_lo is not None and ci_hi is not None:
try:
ci_str = f", 95% CI [{float(ci_lo):.3f}, {float(ci_hi):.3f}]"
except (ValueError, TypeError):
ci_str = f", 95% CI [{ci_lo}, {ci_hi}]"
paired_block += (
f"- {method} vs {baseline} (regime={regime}): "
f"mean_diff={md}, std_diff={sd}, "
f"t={ts}, p={pv}{ci_str}\n"
)
exp_metrics_instruction += paired_block
# R24: Method naming map — translate generic condition labels
_cond_names = list(cond_summaries.keys()) if isinstance(cond_summaries, dict) and cond_summaries else []
if _cond_names:
naming_block = (
"\n\n## METHOD NAMING (CRITICAL — do NOT use generic labels in the paper)\n"
"The condition labels below come from the experiment code. In the paper, "
"you MUST use DESCRIPTIVE algorithm names, not generic labels.\n"
"- If a condition name is already descriptive (e.g., 'random_search', "
"'bayesian_optimization', 'ppo_policy'), use it directly as a proper name.\n"
"- If a condition name is generic (e.g., 'baseline_1', 'method_variant_1'), "
"you MUST infer the algorithm from the experiment code/context and use the "
"real algorithm name (e.g., 'Random Search', 'Bayesian Optimization', "
"'PPO', 'Curiosity-Driven RL').\n"
"- NEVER write `baseline_1` or `method_variant_1` in the paper text.\n"
f"- Conditions to name: {_cond_names}\n"
)
exp_metrics_instruction += naming_block
# IMP-8: Inject broken ablation warnings
abl_warnings = exp_summary_parsed.get("ablation_warnings", [])
if abl_warnings:
broken_block = (
"\n\n## BROKEN ABLATIONS (DO NOT discuss as valid results)\n"
"The following ablation conditions produced IDENTICAL outputs, "
"indicating implementation bugs. Do NOT present their differences "
"as findings. Mention them ONLY in a 'Limitations' sub-section "
"as known implementation issues:\n"
)
for _aw in abl_warnings:
broken_block += f"- {_aw}\n"
broken_block += (
"\nIf you reference these conditions, state explicitly: "
"'Due to an implementation defect, conditions X and Y produced "
"identical outputs; their comparison is therefore uninformative.'\n"
)
exp_metrics_instruction += broken_block
# R25: Statistical table format requirement
if paired:
stat_table_block = (
"\n\n## STATISTICAL TABLE REQUIREMENT (MANDATORY in Results section)\n"
"The Results section MUST include a statistical comparison table with columns:\n"
"| Comparison | Mean Diff | Std Diff | t-statistic | p-value | Significance |\n"
"Use the PAIRED STATISTICAL COMPARISONS data above to fill this table.\n"
"Mark significance: *** (p<0.001), ** (p<0.01), * (p<0.05), n.s.\n"
"This is non-negotiable — a top-venue paper MUST have statistical tests.\n"
)
exp_metrics_instruction += stat_table_block
# R26: Metric definition requirement
exp_metrics_instruction += (
"\n\n## METRIC DEFINITIONS (MANDATORY in Experiments section)\n"
"The Experiments section MUST define each metric:\n"
"- **Primary metric**: what it measures, how it is computed, range, direction "
"(higher/lower is better), and units if applicable.\n"
"- **Secondary metric**: same details.\n"
"- For time-to-event metrics: explain the horizon, what constitutes success, "
"and how failures are handled (e.g., set to max horizon).\n"
"- These definitions MUST appear BEFORE any results tables.\n"
)
# R27: Multi-seed framing enforcement
_any_seeds = any(
(cond_summaries.get(c) or {}).get("n_seed_metrics", 0) > 1
for c in _cond_names
) if _cond_names else False
if _any_seeds:
exp_metrics_instruction += (
"\n\n## MULTI-SEED EXPERIMENT FRAMING (CRITICAL)\n"
"This experiment uses MULTIPLE independent random seeds per condition.\n"
"- Report mean +/- std (or SE) for all metrics.\n"
"- NEVER describe this as 'a single run' or '1 benchmark-artifact run'.\n"
"- Frame as: 'We evaluate each method across N seeds per regime.'\n"
"- The seed-level data IS the evidence base — it is NOT a single observation.\n"
"- Include per-regime breakdowns (easy vs hard) as separate rows in tables.\n"
)
# BUG-003: Inject actual evaluated datasets as a hard constraint
if exp_summary_text:
_ds_parsed = _safe_json_loads(exp_summary_text, {})
if isinstance(_ds_parsed, dict):
_datasets: set[str] = set()
# Extract from condition names (often contain dataset info)
for _cname in (_ds_parsed.get("condition_summaries") or {}).keys():
_datasets.add(str(_cname))
# Extract from explicit "datasets" field if present
for _ds in (_ds_parsed.get("datasets") or []):
if isinstance(_ds, str):
_datasets.add(_ds)
# Extract from "benchmark" or "dataset" fields
for _key in ("benchmark", "dataset", "dataset_name"):
_dv = _ds_parsed.get(_key)
if isinstance(_dv, str) and _dv:
_datasets.add(_dv)
if _datasets:
exp_metrics_instruction += (
"\n\n## ACTUAL EVALUATED DATASETS (HARD CONSTRAINT)\n"
"The following datasets/conditions were ACTUALLY tested in experiments:\n"
+ "".join(f"- {d}\n" for d in sorted(_datasets))
+ "\nCRITICAL: Do NOT claim evaluation on any dataset not listed above.\n"
"Do NOT fabricate results for datasets you did not run experiments on.\n"
"If you reference other datasets, clearly state they are 'not evaluated "
"in this work' or are 'left for future work'.\n"
)
# P7: Ablation effectiveness check
if exp_summary_text:
_exp_parsed_p7 = _safe_json_loads(exp_summary_text, {})
if isinstance(_exp_parsed_p7, dict):
_abl_warnings = _check_ablation_effectiveness(_exp_parsed_p7)
if _abl_warnings:
_abl_block = (
"\n\n## ABLATION EFFECTIVENESS WARNINGS\n"
"The following ablations showed minimal effect (within 5% of baseline). "
"Discuss this honestly — it may indicate the ablated component is not "
"important, or the ablation was not properly implemented:\n"
)
for _aw in _abl_warnings:
_abl_block += f"- {_aw}\n"
exp_metrics_instruction += _abl_block
logger.warning("P7: Ablation effectiveness warnings: %s", _abl_warnings)
# P10: Contradiction detection
if exp_summary_text:
_exp_parsed_p10 = _safe_json_loads(exp_summary_text, {})
if isinstance(_exp_parsed_p10, dict):
_contradictions = _detect_result_contradictions(
_exp_parsed_p10, metric_direction=config.experiment.metric_direction
)
if _contradictions:
_contra_block = (
"\n\n## RESULT INTERPRETATION ADVISORIES (CRITICAL — read before writing)\n"
)
for _ca in _contradictions:
_contra_block += f"- {_ca}\n"
exp_metrics_instruction += _contra_block
logger.warning("P10: Contradiction advisories: %s", _contradictions)
# R10: HARD BLOCK — refuse to write paper when all data is simulated
all_simulated = True
for stage_subdir in sorted(run_dir.glob("stage-*/runs")):
for run_file in sorted(stage_subdir.glob("*.json")):
if run_file.name == "results.json":
continue
try:
_payload = json.loads(run_file.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
continue
if isinstance(_payload, dict) and _payload.get("status") != "simulated":
all_simulated = False
break
if not all_simulated:
break
if all_simulated:
logger.error(
"BLOCKED: All experiment data is simulated (mode='simulated'). "
"Cannot write a paper based on formulaic fake data. "
"Switch to experiment.mode='sandbox' and re-run."
)
(stage_dir / "paper_draft.md").write_text(
"# Paper Draft Blocked\n\n"
"**Reason**: All experiment results are from simulated mode "
"(formulaic data: `0.3 + idx * 0.03`). "
"These are not real experimental results.\n\n"
"**Action Required**: Set `experiment.mode: 'sandbox'` in "
"config.arc.yaml and re-run the pipeline.",
encoding="utf-8",
)
return StageResult(
stage=Stage.PAPER_DRAFT,
status=StageStatus.FAILED,
artifacts=("paper_draft.md",),
evidence_refs=(),
)
# R4-2: HARD BLOCK — refuse to write paper with no real data (ML/empirical domains)
# For non-empirical domains (math proofs, theoretical economics), allow proceeding
_domain_id, _domain_name, _domain_venues = _detect_domain(
config.research.topic, config.research.domains
)
_empirical_domains = {"ml", "engineering", "biology", "chemistry"}
if not has_real_metrics:
if _domain_id in _empirical_domains:
logger.error(
"BLOCKED: Cannot write paper — experiment produced NO metrics. "
"The pipeline will not fabricate results."
)
(stage_dir / "paper_draft.md").write_text(
"# Paper Draft Blocked\n\n"
"**Reason**: Experiment stage produced no metrics (status: failed/timeout). "
"Cannot write a paper without real experimental data.\n\n"
"**Action Required**: Fix experiment execution or increase time_budget_sec.",
encoding="utf-8",
)
return StageResult(
stage=Stage.PAPER_DRAFT,
status=StageStatus.FAILED,
artifacts=("paper_draft.md",),
evidence_refs=(),
)
else:
logger.warning(
"No experiment metrics found, but domain '%s' may be non-empirical "
"(theoretical/mathematical). Proceeding with paper draft.",
_domain_name,
)
# R11-5: Experiment quality minimum threshold before paper writing
# Parse analysis.md for quality rating and condition completeness
analysis_text = _read_best_analysis(run_dir)
_quality_warnings: list[str] = []
# Check 1: Was the analysis quality rating very low?
import re as _re_q
_rating_match = _re_q.search(
r"(?:quality\s+rating|result\s+quality)[:\s]*\**(\d+)\s*/\s*10",
analysis_text,
_re_q.IGNORECASE,
)
if _rating_match:
_analysis_rating = int(_rating_match.group(1))
if _analysis_rating <= 3:
_quality_warnings.append(
f"Analysis rated experiment quality {_analysis_rating}/10"
)
# BUG-23: If quality rating is ≤ 2, force has_real_metrics = False
# to prevent fabricated results even if stdout had stray numbers.
# R5-BUG-05: Skip override when _has_parsed_metrics is True — the
# analysis.md may be stale (from pre-refinement Stage 14) while
# Stage 13 refinement produced real parsed metrics.
if _analysis_rating <= 2 and has_real_metrics and not _has_parsed_metrics:
logger.warning(
"BUG-23 guard: Analysis quality %d/10 \u2264 2 — "
"overriding has_real_metrics to False (experiment likely failed)",
_analysis_rating,
)
has_real_metrics = False
# Check 2: Are baselines missing?
_analysis_lower = analysis_text.lower()
if "no" in _analysis_lower and "baseline" in _analysis_lower:
if any(phrase in _analysis_lower for phrase in [
"no baseline", "no bo", "no random", "baselines are missing",
"missing baselines", "baseline coverage is missing",
]):
_quality_warnings.append("Baselines appear to be missing from results")
# Check 3: Is the metric undefined?
if any(phrase in _analysis_lower for phrase in [
"metric is undefined", "primary_metric is undefined",
"undefined metric", "metric undefined",
]):
_quality_warnings.append("Primary metric is undefined (direction/units/formula unknown)")
# Check 4: Very few conditions completed
_condition_count = len(_re_q.findall(
r"condition[=:\s]+\w+.*?(?:mean|primary_metric)",
raw_metrics_block or "",
_re_q.IGNORECASE,
))
if _quality_warnings:
_warning_block = "\n".join(f" - {w}" for w in _quality_warnings)
logger.warning(
"Stage 17: Experiment quality concerns detected before paper writing:\n%s",
_warning_block,
)
# Inject quality warnings into the paper writing prompt so the LLM
# writes an appropriately hedged paper
exp_metrics_instruction += (
"\n\n## EXPERIMENT QUALITY WARNINGS (address these honestly in the paper)\n"
+ "\n".join(f"- {w}" for w in _quality_warnings)
+ "\n\nBecause of these issues, the paper MUST:\n"
"- Use hedged language ('preliminary', 'pilot', 'initial exploration')\n"
"- NOT claim definitive comparisons between methods\n"
"- Dedicate a substantial Limitations section to these gaps\n"
"- Frame the contribution as methodology/framework, not empirical findings\n"
)
# Save warnings for tracking
(stage_dir / "quality_warnings.json").write_text(
json.dumps(_quality_warnings, indent=2), encoding="utf-8"
)
# Phase 1: Inject pre-built results tables from VerifiedRegistry
if _verified_registry is not None:
try:
from researchclaw.templates.results_table_builder import (
build_results_tables,
build_condition_whitelist,
)
_prebuilt_tables = build_results_tables(
_verified_registry,
metric_direction=_verified_registry.metric_direction,
)
_condition_whitelist = build_condition_whitelist(_verified_registry)
if _prebuilt_tables:
_tables_block = "\n\n".join(t.latex_code for t in _prebuilt_tables)
exp_metrics_instruction += (
"\n\n## PRE-BUILT RESULTS TABLES (MANDATORY — copy verbatim)\n"
"The tables below were AUTO-GENERATED from verified experiment data.\n"
"You MUST include these tables in the Results section EXACTLY as shown.\n"
"Do NOT modify any numbers. Do NOT add rows with fabricated data.\n"
"You MAY adjust formatting (bold, alignment) but NOT numerical values.\n\n"
+ _tables_block
)
logger.info("Stage 17: Injected pre-built results tables into prompt")
if _condition_whitelist:
exp_metrics_instruction += (
"\n\n## VERIFIED CONDITIONS (ONLY mention these in the paper)\n"
+ _condition_whitelist
+ "\nDo NOT discuss conditions not in this list. Do NOT invent new conditions.\n"
)
except Exception as _tb_exc:
logger.warning("Stage 17: Failed to build pre-built tables: %s", _tb_exc)
# R4-2: Anti-fabrication data integrity instruction
exp_metrics_instruction += (
"\n\n## CRITICAL: Data Integrity Rules\n"
"- You may ONLY report numbers that appear in the experiment data above\n"
"- If the experiment data is incomplete (fewer conditions than planned), report\n"
" ONLY the conditions that were actually run\n"
"- Do NOT extrapolate, interpolate, or 'fill in' missing cells in tables\n"
"- Do NOT invent confidence intervals, p-values, or statistical tests unless\n"
" the actual data supports them\n"
"- If only N conditions completed, simply report results for those N conditions\n"
" without repeating apologies or disclaimers about missing conditions\n"
"- Any table cell without real data must show '\u2014' (not a plausible number)\n"
"- FORBIDDEN: generating numbers that 'look right' based on your training data\n"
)
# IMP-6 + FA: Inject chart references into paper draft prompt
# Prefer FigureAgent's figure_plan.json (rich descriptions) over raw file scan
# BUG-FIX: figure_plan.json may be a list (from FigureAgent planner) or a dict
# (from executor overwrite). The orchestrator writes a list at planning time;
# the executor overwrites with a dict only when figure_count > 0. If the
# FigureAgent renders 0 charts the list persists, and calling .get() on it
# raises AttributeError.
_fa_descriptions = ""
# BUG-178: Iterate in reverse order so we read the LATEST stage-14
# iteration's figure plan, matching Stage 22 which copies charts
# from the newest iteration.
for _s14_dir in sorted(run_dir.glob("stage-14*"), reverse=True):
# Prefer the final plan (dict with figure_descriptions) if it exists
for _fp_name in ("figure_plan_final.json", "figure_plan.json"):
_fp_path = _s14_dir / _fp_name
if not _fp_path.exists():
continue
try:
_fp_data = json.loads(_fp_path.read_text(encoding="utf-8"))
if isinstance(_fp_data, dict):
_fa_descriptions = _fp_data.get("figure_descriptions", "")
elif isinstance(_fp_data, list) and _fp_data:
# List format from FigureAgent planner — synthesize descriptions
_desc_parts = ["## PLANNED FIGURES (from figure plan)\n"]
for _fig in _fp_data:
if isinstance(_fig, dict):
_fid = _fig.get("figure_id", "unnamed")
_ftitle = _fig.get("title", "")
_fcap = _fig.get("caption", "")
_fsec = _fig.get("section", "results")
_desc_parts.append(
f"- **{_fid}** ({_fsec}): {_ftitle}\n {_fcap}"
)
if len(_desc_parts) > 1:
_fa_descriptions = "\n".join(_desc_parts)
except (json.JSONDecodeError, OSError):
pass
if _fa_descriptions:
break
if _fa_descriptions:
break
if _fa_descriptions:
exp_metrics_instruction += "\n\n" + _fa_descriptions
logger.info("Stage 17: Injected FigureAgent figure descriptions into paper draft prompt")
else:
# Fallback: scan for chart files from the LATEST stage-14 iteration
# BUG-178: Must use reverse order to match Stage 22 chart copy behavior
_chart_files: list[str] = []
for _s14_dir in sorted(run_dir.glob("stage-14*"), reverse=True):
_charts_path = _s14_dir / "charts"
if _charts_path.is_dir():
_found = sorted(_charts_path.glob("*.png"))
if _found:
_chart_files = [f.name for f in _found]
break # Use only the latest iteration's charts
if _chart_files:
_chart_block = (
"\n\n## AVAILABLE FIGURES (embed in the paper)\n"
"The following figures were generated from actual experiment data. "
"You MUST reference at least 1-2 of these in the Results section "
"using markdown image syntax: ``\n\n"
)
for _cf_name in _chart_files:
_label = _cf_name.replace("_", " ").replace(".png", "").title()
_chart_block += f"- `charts/{_cf_name}` \u2014 {_label}\n"
_chart_block += (
"\nFor each figure referenced, write a descriptive caption and "
"discuss what the figure shows in 2-3 sentences.\n"
)
exp_metrics_instruction += _chart_block
logger.info(
"Stage 17: Injected %d chart references into paper draft prompt",
len(_chart_files),
)
# WS-5.5: Framework diagram placeholder instruction
exp_metrics_instruction += (
"\n\n## FRAMEWORK DIAGRAM PLACEHOLDER\n"
"In the Method/Approach section, include a placeholder for the methodology "
"framework overview figure. Insert this exactly:\n\n"
"```\n"
"\n"
"**Figure N.** Overview of the proposed methodology. "
"[A detailed framework diagram will be generated separately and inserted here.]\n"
"```\n\n"
"This figure should be referenced in the text as 'Figure N' and discussed briefly "
"(1-2 sentences describing the overall pipeline/architecture flow). "
"The actual image will be generated post-hoc using a text-to-image model.\n"
)
# P5: Extract hyperparameters from results.json for paper Method section
_hp_table = ""
for _s14_dir in sorted(run_dir.glob("stage-14*")):
for _run_file in sorted(_s14_dir.glob("runs/*.json")):
try:
_run_data = json.loads(_run_file.read_text(encoding="utf-8"))
if isinstance(_run_data, dict) and _run_data.get("hyperparameters"):
_hp = _run_data["hyperparameters"]
if isinstance(_hp, dict) and _hp:
_hp_table = "\n\n## HYPERPARAMETERS (include as a table in the Method section)\n"
_hp_table += "| Hyperparameter | Value |\n|---|---|\n"
for _hk, _hv in sorted(_hp.items()):
_hp_table += f"| {_hk} | {_hv} |\n"
_hp_table += (
"\nThis table MUST appear in the Method/Experiments section. "
"Include ALL hyperparameters used, with justification for key choices.\n"
)
break
except (json.JSONDecodeError, OSError):
continue
if _hp_table:
break
# Also check staging dirs for results.json
if not _hp_table:
for _staging_dir in sorted(run_dir.glob("stage-*/runs/_docker_*")):
_rjson = _staging_dir / "results.json"
if _rjson.is_file():
try:
_rdata = json.loads(_rjson.read_text(encoding="utf-8"))
if isinstance(_rdata, dict) and _rdata.get("hyperparameters"):
_hp = _rdata["hyperparameters"]
if isinstance(_hp, dict) and _hp:
_hp_table = "\n\n## HYPERPARAMETERS (include as a table in the Method section)\n"
_hp_table += "| Hyperparameter | Value |\n|---|---|\n"
for _hk, _hv in sorted(_hp.items()):
_hp_table += f"| {_hk} | {_hv} |\n"
_hp_table += (
"\nThis table MUST appear in the Method/Experiments section. "
"Include ALL hyperparameters used, with justification for key choices.\n"
)
break
except (json.JSONDecodeError, OSError):
continue
if _hp_table:
exp_metrics_instruction += _hp_table
# F2.6: Build citation list from references.bib / candidates with cite_keys
citation_instruction = ""
bib_text = _read_prior_artifact(run_dir, "references.bib")
# P3: Pre-verify citations before paper draft — remove hallucinated refs
if bib_text and bib_text.strip():
from researchclaw.literature.verify import (
filter_verified_bibtex,
verify_citations as _verify_cit,
)
try:
_pre_report = _verify_cit(bib_text, inter_verify_delay=0.5)
_kept = _pre_report.verified + _pre_report.suspicious
_removed = _pre_report.hallucinated
if _removed > 0:
bib_text = filter_verified_bibtex(
bib_text, _pre_report, include_suspicious=True
)
(stage_dir / "references_preverified.bib").write_text(
bib_text, encoding="utf-8"
)
logger.info(
"P3: Pre-verification kept %d/%d citations (removed %d hallucinated)",
_kept, _pre_report.total, _removed,
)
except Exception as exc:
logger.warning("P3: Pre-verification failed, using original bib: %s", exc)
candidates_text = _read_prior_artifact(run_dir, "candidates.jsonl")
if candidates_text:
cite_lines: list[str] = []
for row_text in candidates_text.strip().splitlines():
row = _safe_json_loads(row_text, {})
if isinstance(row, dict) and row.get("cite_key"):
authors_info = ""
if isinstance(row.get("authors"), list) and row["authors"]:
first_author = row["authors"][0]
if isinstance(first_author, dict):
# BUG-38: name may be non-str (tuple/list) — force str
_name = first_author.get("name", "")
authors_info = _name if isinstance(_name, str) else str(_name)
elif isinstance(first_author, str):
authors_info = first_author
if len(row["authors"]) > 1:
authors_info += " et al."
title = row.get("title", "")
cite_lines.append(
f"- [{row['cite_key']}] \u2192 TITLE: \"{title}\" "
f"| {authors_info} "
f"({row.get('venue', '')}, {row.get('year', '')}, "
f"cited {row.get('citation_count', 0)} times) "
f"| ONLY cite this key when discussing: {title}"
)
if cite_lines:
citation_instruction = (
"\n\nAVAILABLE REFERENCES (use [cite_key] to cite in the text):\n"
+ "\n".join(cite_lines)
+ "\n\nCRITICAL CITATION RULES:\n"
"- In the body text, cite using [cite_key] format, e.g. [smith2024transformer].\n"
"- Do NOT write a References section \u2014 it will be auto-generated from the bibliography file.\n"
"- Do NOT invent any references or arXiv IDs not in the above list.\n"
"- You may cite a subset, but NEVER fabricate citations or change arXiv IDs.\n"
"- SEMANTIC MATCHING: Before citing a reference, verify that its TITLE matches\n"
" the concept you are discussing. Do NOT use an unrelated cite_key just\n"
" because it sounds similar.\n"
"- If no reference in the list matches the concept you want to cite,\n"
" write 'prior work has shown...' WITHOUT a citation, rather than using\n"
" a mismatched reference.\n"
"- Each [cite_key] MUST correspond to the paper whose title is shown\n"
" next to that key in the list above. Cross-check before citing.\n"
"\nCITATION QUANTITY & QUALITY CONSTRAINTS:\n"
"- Cite 25-40 unique references in the paper body. The Related Work\n"
" section alone should cite at least 15 references.\n"
"- Every citation MUST be directly relevant to the paper's topic.\n"
"- DO NOT cite papers from unrelated domains (wireless communication, "
"manufacturing, UAV, etc.).\n"
"- Prefer well-known, highly-cited papers over obscure ones.\n"
"- If unsure whether a paper exists or is relevant, DO NOT cite it.\n"
)
if llm is not None:
_pm = prompts or PromptManager()
topic_constraint = _pm.block("topic_constraint", topic=config.research.topic)
# --- Section-by-section writing (3 calls) for conference-grade depth ---
draft = _write_paper_sections(
llm=llm,
pm=_pm,
run_dir=run_dir,
preamble=preamble,
topic_constraint=topic_constraint,
exp_metrics_instruction=exp_metrics_instruction,
citation_instruction=citation_instruction,
outline=outline,
model_name=config.llm.primary_model,
)
# R7: Strip LLM-generated References section — it often fabricates arXiv IDs.
import re as _re_r7
ref_pattern = _re_r7.compile(
r'^(#{1,2}\s*References.*)', _re_r7.MULTILINE | _re_r7.DOTALL
)
ref_match = ref_pattern.search(draft)
if ref_match:
draft = draft[:ref_match.start()].rstrip()
logger.info("Stage 17: Stripped LLM-generated References section (R7 fix)")
else:
# Build template with real data if available
results_section = "Template results summary."
if exp_summary_text:
exp_summary = _safe_json_loads(exp_summary_text, {})
if isinstance(exp_summary, dict) and exp_summary.get("metrics_summary"):
lines = ["Experiment results:"]
for mk, mv in exp_summary["metrics_summary"].items():
if isinstance(mv, dict):
lines.append(
f"- {mk}: mean={mv.get('mean')}, min={mv.get('min')}, "
f"max={mv.get('max')}, n={mv.get('count')}"
)
results_section = "\n".join(lines)
draft = f"""# Draft Title
## Abstract
Template draft abstract.
## Introduction
Template introduction for {config.research.topic}.
## Related Work
Template related work.
## Method
Template method description.
## Experiments
Template experimental setup.
## Results
{results_section}
## Limitations
Template limitations.
## Conclusion
Template conclusion.
## References
Template references.
Generated: {_utcnow_iso()}
"""
(stage_dir / "paper_draft.md").write_text(draft, encoding="utf-8")
# Validate draft quality (section balance + bullet density)
_validate_draft_quality(draft, stage_dir=stage_dir)
return StageResult(
stage=Stage.PAPER_DRAFT,
status=StageStatus.DONE,
artifacts=("paper_draft.md",),
evidence_refs=("stage-17/paper_draft.md",),
)
================================================
FILE: researchclaw/pipeline/stage_impls/_review_publish.py
================================================
"""Stages 18-23: Peer review, paper revision, quality gate, knowledge archive, export/publish, and citation verify."""
from __future__ import annotations
import json
import logging
import math
import re
from collections import Counter
from pathlib import Path
from typing import Any
import yaml # noqa: F401 — available for downstream use
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain # noqa: F401
from researchclaw.pipeline._helpers import (
StageResult,
_build_context_preamble,
_chat_with_prompt,
_collect_experiment_results, # noqa: F401
_default_quality_report,
_extract_paper_title,
_find_prior_file,
_generate_framework_diagram_prompt,
_generate_neurips_checklist,
_get_evolution_overlay,
_read_best_analysis,
_read_prior_artifact,
_safe_json_loads,
_topic_constraint_block, # noqa: F401
_utcnow_iso,
reconcile_figure_refs,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers imported from executor.py (not yet moved to _helpers.py).
# Lazy-imported inside functions to avoid circular import when executor.py
# imports this module.
# ---------------------------------------------------------------------------
def _get_collect_raw_experiment_metrics():
    """Lazily resolve Stage 17's raw-metrics collector.

    Imported inside the function body (not at module level) to avoid a
    circular import: executor.py imports this module, which in turn needs
    a helper living in ``_paper_writing.py``.
    """
    from researchclaw.pipeline.stage_impls._paper_writing import (
        _collect_raw_experiment_metrics as _impl,
    )
    return _impl
def _get_review_compiled_pdf():
    """Lazily resolve Stage 17's compiled-PDF review helper.

    The function-scope import breaks the circular dependency between this
    module and ``_paper_writing.py`` (see module header comment).
    """
    from researchclaw.pipeline.stage_impls._paper_writing import (
        _review_compiled_pdf as _impl,
    )
    return _impl
# ---------------------------------------------------------------------------
# _collect_experiment_evidence
# ---------------------------------------------------------------------------
def _collect_experiment_evidence(run_dir: Path) -> str:
    """Collect actual experiment parameters and results for peer review.

    Gathers four kinds of ground-truth evidence from the run directory so the
    reviewer prompt can cross-check the paper's methodology claims:

    1. The experiment entry point (``main.py``) source, truncated to 3000 chars.
    2. Up to five sandbox run-result JSONs (metrics, runtime, stderr excerpt).
    3. A refinement-log summary (iteration count, convergence, best metric).
    4. The actual number of executed runs across all stage ``runs/`` dirs.

    Parameters
    ----------
    run_dir:
        Root directory of the pipeline run.

    Returns
    -------
    str
        A markdown "Actual Experiment Evidence" section, or ``""`` when no
        evidence was found.
    """
    evidence_parts: list[str] = []
    # 1. Read experiment code to find actual trial count, methods used.
    # NOTE(review): _read_prior_artifact presumably returns a path string for
    # directory-style keys like "experiment/" — the is_dir() check guards this.
    exp_dir = _read_prior_artifact(run_dir, "experiment/")
    if exp_dir and Path(exp_dir).is_dir():
        main_py = Path(exp_dir) / "main.py"
        if main_py.exists():
            code = main_py.read_text(encoding="utf-8")
            evidence_parts.append(f"### Actual Experiment Code (main.py)\n```python\n{code[:3000]}\n```")
    # 2. Read sandbox run results (actual metrics, runtime, stderr)
    runs_text = _read_prior_artifact(run_dir, "runs/")
    if runs_text and Path(runs_text).is_dir():
        for run_file in sorted(Path(runs_text).glob("*.json"))[:5]:
            payload = _safe_json_loads(run_file.read_text(encoding="utf-8"), {})
            if isinstance(payload, dict):
                summary = {
                    "metrics": payload.get("metrics"),
                    "elapsed_sec": payload.get("elapsed_sec"),
                    "timed_out": payload.get("timed_out"),
                }
                stderr = payload.get("stderr", "")
                if stderr:
                    summary["stderr_excerpt"] = stderr[:500]
                evidence_parts.append(
                    f"### Run Result: {run_file.name}\n```json\n{json.dumps(summary, indent=2)}\n```"
                )
    # 3. Read refinement log for actual iteration count
    refine_log_text = _read_prior_artifact(run_dir, "refinement_log.json")
    if refine_log_text:
        try:
            rlog = json.loads(refine_log_text)
            # FIX: json.loads may legally return a list/str/number; calling
            # .get() on a non-dict raised AttributeError, which the except
            # clause below does not catch. Guard with an isinstance check.
            if isinstance(rlog, dict):
                summary = {
                    "iterations_executed": len(rlog.get("iterations", [])),
                    "converged": rlog.get("converged"),
                    "stop_reason": rlog.get("stop_reason"),
                    "best_metric": rlog.get("best_metric"),
                }
                evidence_parts.append(
                    f"### Refinement Summary\n```json\n{json.dumps(summary, indent=2)}\n```"
                )
        except (json.JSONDecodeError, TypeError):
            pass
    # 4. Count actual number of experiment runs. results.json is an aggregate
    # file, not an individual run, so it is excluded from the count.
    actual_run_count = 0
    for stage_subdir in sorted(run_dir.glob("stage-*/runs")):
        for rf in stage_subdir.glob("*.json"):
            if rf.name != "results.json":
                actual_run_count += 1
    if actual_run_count > 0:
        evidence_parts.append(
            f"### Actual Trial Count\n"
            f"**The experiment was executed {actual_run_count} time(s).** "
            f"If the paper claims a different number of trials, this is a CRITICAL discrepancy."
        )
    if not evidence_parts:
        return ""
    return (
        "\n\n## Actual Experiment Evidence\n"
        "Use the evidence below to verify the paper's methodology claims.\n\n"
        + "\n\n".join(evidence_parts)
    )
# ---------------------------------------------------------------------------
# Stage 18: Peer Review
# ---------------------------------------------------------------------------
def _execute_peer_review(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 18: generate peer reviews of the paper draft.

    Reads the Stage 17 draft and collected experiment evidence, appends any
    automated draft-quality warnings found in ``draft_quality.json``, and asks
    the LLM for reviews. With no LLM configured, a fixed two-reviewer template
    is written instead. The result is saved as ``reviews.md``.
    """
    draft = _read_prior_artifact(run_dir, "paper_draft.md") or ""
    experiment_evidence = _collect_experiment_evidence(run_dir)
    # Surface Stage 17's automated draft-quality warnings (best effort —
    # any read/parse failure silently yields an empty suffix).
    quality_suffix = ""
    dq_path = _find_prior_file(run_dir, "draft_quality.json")
    if dq_path and dq_path.exists():
        try:
            dq_report = json.loads(dq_path.read_text(encoding="utf-8"))
            dq_warnings = dq_report.get("overall_warnings", [])
            if dq_warnings:
                bullets = "\n".join(f"- {w}" for w in dq_warnings)
                quality_suffix = (
                    "\n\nAUTOMATED QUALITY ISSUES (flag these in your review):\n"
                    + bullets
                    + "\n"
                )
        except Exception:  # noqa: BLE001
            pass
    if llm is None:
        # No LLM available — fall back to the static review template.
        reviews = """# Reviews
## Reviewer A
- Strengths: Clear problem statement.
- Weaknesses: Limited ablation details.
- Actionable revisions: Add uncertainty analysis and stronger baselines.
## Reviewer B
- Strengths: Reproducibility focus.
- Weaknesses: Discussion underdeveloped.
- Actionable revisions: Expand limitations and broader impact.
"""
    else:
        pm = prompts or PromptManager()
        overlay = _get_evolution_overlay(run_dir, "peer_review")
        sp = pm.for_stage(
            "peer_review",
            evolution_overlay=overlay,
            topic=config.research.topic,
            draft=draft,
            experiment_evidence=experiment_evidence,
        )
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user + quality_suffix,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        reviews = resp.content
    (stage_dir / "reviews.md").write_text(reviews, encoding="utf-8")
    return StageResult(
        stage=Stage.PEER_REVIEW,
        status=StageStatus.DONE,
        artifacts=("reviews.md",),
        evidence_refs=("stage-18/reviews.md",),
    )
# ---------------------------------------------------------------------------
# Stage 19: Paper Revision
# ---------------------------------------------------------------------------
def _execute_paper_revision(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 19: revise the paper draft to address the Stage 18 reviews.

    The revision prompt is assembled from the draft, the reviews (prefixed with
    mandatory quality-fix directives from ``draft_quality.json`` and suffixed
    with a data-integrity guard built from real experiment metrics), plus style
    blocks. A length guard retries once when the revision shrinks below 80% of
    the draft; if both attempts are short, the original draft is preserved and
    the short output is archived as internal revision notes. Writes
    ``paper_revised.md``.
    """
    draft = _read_prior_artifact(run_dir, "paper_draft.md") or ""
    reviews = _read_prior_artifact(run_dir, "reviews.md") or ""
    draft_word_count = len(draft.split())
    # R4-2: Collect real metrics for anti-fabrication guard in revision
    # BUG-47: _collect_raw_experiment_metrics returns tuple[str, bool], must unpack
    _raw_metrics_tuple = _get_collect_raw_experiment_metrics()(run_dir)
    raw_metrics_revision = _raw_metrics_tuple[0] if isinstance(_raw_metrics_tuple, tuple) else (_raw_metrics_tuple or "")
    data_integrity_revision = ""
    if raw_metrics_revision:
        data_integrity_revision = (
            raw_metrics_revision
            + "\nDATA INTEGRITY: Do NOT add new numbers that are not in the "
            "experiment data above. If a reviewer asks for additional results "
            "you do not have, state 'Due to computational constraints, "
            "this analysis was not conducted' instead of fabricating data.\n"
        )
    if llm is not None:
        _pm = prompts or PromptManager()
        try:
            _ws_revision = _pm.block("writing_structure")
        # FIX: `except (KeyError, Exception)` was redundant — KeyError is a
        # subclass of Exception, so the plain Exception clause is equivalent.
        except Exception:  # noqa: BLE001
            _ws_revision = ""
        # IMP-20/25/31/24: Load style blocks for revision prompt (each block
        # is optional; missing blocks degrade to an empty string).
        _rev_blocks: dict[str, str] = {}
        for _bname in ("academic_style_guide", "narrative_writing_rules",
                       "anti_hedging_rules", "anti_repetition_rules"):
            try:
                _rev_blocks[_bname] = _pm.block(_bname)
            except Exception:  # noqa: BLE001
                _rev_blocks[_bname] = ""
        # Load draft quality directives from Stage 17 (best effort)
        _quality_prefix = ""
        _quality_json_path = _find_prior_file(run_dir, "draft_quality.json")
        if _quality_json_path and _quality_json_path.exists():
            try:
                _dq = json.loads(_quality_json_path.read_text(encoding="utf-8"))
                _dq_directives = _dq.get("revision_directives", [])
                if _dq_directives:
                    _quality_prefix = (
                        "MANDATORY QUALITY FIXES (address ALL of these):\n"
                        + "\n".join(f"- {d}" for d in _dq_directives)
                        + "\n\n"
                    )
            except Exception:  # noqa: BLE001
                pass
        _overlay = _get_evolution_overlay(run_dir, "paper_revision")
        sp = _pm.for_stage(
            "paper_revision",
            evolution_overlay=_overlay,
            topic_constraint=_pm.block("topic_constraint", topic=config.research.topic),
            writing_structure=_ws_revision,
            draft=draft,
            reviews=_quality_prefix + reviews + data_integrity_revision,
            **_rev_blocks,
        )
        # R10-Fix2: Ensure max_tokens is sufficient for full paper revision
        revision_max_tokens = sp.max_tokens
        if revision_max_tokens and draft_word_count > 0:
            # ~1.5 tokens per word, 20% headroom
            min_tokens_needed = int(draft_word_count * 1.5 * 1.2)
            if revision_max_tokens < min_tokens_needed:
                revision_max_tokens = min_tokens_needed
                logger.info(
                    "Stage 19: Increased max_tokens from %d to %d to fit full paper revision",
                    sp.max_tokens,
                    revision_max_tokens,
                )
        # R10-Fix4: Retry on timeout for paper revision (critical stage)
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=revision_max_tokens,
            retries=2,
        )
        revised = resp.content
        revised_word_count = len(revised.split())
        # Length guard: if revision is shorter than 80% of draft, retry once
        if draft_word_count > 500 and revised_word_count < int(draft_word_count * 0.8):
            logger.warning(
                "Paper revision (%d words) is shorter than draft (%d words). "
                "Retrying with stronger length enforcement.",
                revised_word_count,
                draft_word_count,
            )
            retry_user = (
                f"CRITICAL LENGTH REQUIREMENT: The draft is {draft_word_count} words. "
                f"Your revision MUST be at least {draft_word_count} words — ideally longer. "
                f"Do NOT summarize or condense ANY section. Copy each section verbatim "
                f"and ONLY make targeted improvements to address reviewer comments. "
                f"If a section has no reviewer comments, include it UNCHANGED.\n\n"
                + sp.user
            )
            resp2 = _chat_with_prompt(
                llm, sp.system, retry_user,
                json_mode=sp.json_mode, max_tokens=revision_max_tokens,
            )
            revised2 = resp2.content
            revised2_word_count = len(revised2.split())
            if revised2_word_count >= int(draft_word_count * 0.8):
                revised = revised2
            elif revised2_word_count > revised_word_count:
                # Retry improved but still not enough — use the longer version
                revised = revised2
                logger.warning(
                    "Retry improved (%d → %d words) but still shorter than draft (%d).",
                    revised_word_count,
                    revised2_word_count,
                    draft_word_count,
                )
            else:
                # Both attempts produced short output — preserve full original draft
                logger.warning(
                    "Retry also produced short output (%d words). "
                    "Falling back to FULL ORIGINAL DRAFT to prevent content loss.",
                    revised2_word_count,
                )
                # Extract useful revision points as appendix
                revision_words = revised.split()
                revision_summary = (
                    " ".join(revision_words[:500]) + "\n\n*(Revision summary truncated)*"
                    if len(revision_words) > 500
                    else revised
                )
                if revision_summary.strip():
                    # Save revision notes to internal file, not paper body
                    (stage_dir / "revision_notes_internal.md").write_text(
                        revision_summary, encoding="utf-8"
                    )
                revised = draft
    else:
        revised = draft
    (stage_dir / "paper_revised.md").write_text(revised, encoding="utf-8")
    return StageResult(
        stage=Stage.PAPER_REVISION,
        status=StageStatus.DONE,
        artifacts=("paper_revised.md",),
        evidence_refs=("stage-19/paper_revised.md",),
    )
# ---------------------------------------------------------------------------
# Stage 20: Quality Gate
# ---------------------------------------------------------------------------
def _execute_quality_gate(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 20: score the revised paper and gate export on quality.

    Loads the richest experiment summary from the run, asks the LLM for a
    quality report (capping the score at 3.0 when the experiment failed with
    no metrics), writes ``quality_report.json`` and ``fabrication_flags.json``,
    and then either passes, degrades (when ``graceful_degradation`` is set),
    or fails the stage depending on the score vs. the configured threshold.
    """
    revised = _read_prior_artifact(run_dir, "paper_revised.md") or ""
    report: dict[str, Any] | None = None
    # BUG-25 + BUG-180: Load the RICHEST experiment summary for cross-checking.
    # _read_prior_artifact returns the first match in reverse-sorted order,
    # which may be a repair stage with 0 conditions. Instead, scan all
    # stage-14* experiment summaries and pick the one with the most data.
    _exp_summary: dict[str, Any] = {}
    _exp_summary_text = ""
    # Richness = number of condition summaries; -1 so any dict (even empty) wins.
    _best_richness = -1
    for _es_path in sorted(run_dir.glob("stage-14*/experiment_summary.json")):
        try:
            _es_text = _es_path.read_text(encoding="utf-8")
            _es_data = _safe_json_loads(_es_text, {})
            if not isinstance(_es_data, dict):
                continue
            _richness = len(_es_data.get("condition_summaries", {}))
            if _richness > _best_richness:
                _best_richness = _richness
                _exp_summary = _es_data
                _exp_summary_text = _es_text
        except OSError:
            continue
    # Also check experiment_summary_best.json at run root
    _root_best = run_dir / "experiment_summary_best.json"
    if _root_best.is_file():
        try:
            _rb_text = _root_best.read_text(encoding="utf-8")
            _rb_data = _safe_json_loads(_rb_text, {})
            if isinstance(_rb_data, dict):
                _rb_rich = len(_rb_data.get("condition_summaries", {}))
                if _rb_rich > _best_richness:
                    # NOTE: _best_richness is intentionally NOT updated here —
                    # the root-best file only replaces the selected summary.
                    _exp_summary = _rb_data
                    _exp_summary_text = _rb_text
        except OSError:
            pass
    # Fallback to _read_prior_artifact if nothing found above
    if not _exp_summary:
        _exp_summary_text = _read_prior_artifact(run_dir, "experiment_summary.json") or ""
        _exp_summary = _safe_json_loads(_exp_summary_text, {}) if _exp_summary_text else {}
    # Decide whether the experiment should be treated as failed (drives both
    # the score cap below and the fabrication flags written for Stage 22).
    _exp_failed = False
    if isinstance(_exp_summary, dict):
        _best_run = _exp_summary.get("best_run", {})
        if isinstance(_best_run, dict):
            _exp_failed = (
                _best_run.get("status") == "failed"
                and not _best_run.get("metrics")
            )
        # Also check if metrics_summary is empty
        if not _exp_summary.get("metrics_summary"):
            _exp_failed = True
        # BUG-180: If we found real condition data, don't mark as failed
        if _best_richness > 0:
            _exp_failed = False
    if llm is not None:
        _pm = prompts or PromptManager()
        # IMP-33: Evaluate the full paper instead of truncating to 12K chars.
        # Split into chunks if very long, but prefer sending the full text.
        paper_for_eval = revised[:40000] if len(revised) > 40000 else revised
        # BUG-25: Inject experiment status into quality gate prompt
        _exp_context = ""
        if _exp_summary and isinstance(_exp_summary, dict):
            _exp_status_keys = {
                k: _exp_summary.get(k) for k in (
                    "total_conditions", "total_metric_keys",
                    "metrics_summary",
                ) if _exp_summary.get(k) is not None
            }
            # BUG-180: Include condition count from condition_summaries
            _cond_summ = _exp_summary.get("condition_summaries", {})
            if isinstance(_cond_summ, dict) and _cond_summ:
                _exp_status_keys["completed_conditions"] = len(_cond_summ)
                _exp_status_keys["condition_names"] = list(_cond_summ.keys())[:20]
            if _best_run := _exp_summary.get("best_run"):
                _exp_status_keys["best_run_status"] = (
                    _best_run.get("status") if isinstance(_best_run, dict) else str(_best_run)
                )
            _exp_context = (
                "\n\nExperiment summary (for cross-checking reported numbers):\n"
                + json.dumps(_exp_status_keys, indent=2, default=str)[:4000]
                + "\n\nCross-check: If the experiment status is 'failed' with "
                "empty metrics, any numerical results in tables constitute "
                "fabrication. Penalize severely.\n"
            )
        _overlay = _get_evolution_overlay(run_dir, "quality_gate")
        sp = _pm.for_stage(
            "quality_gate",
            evolution_overlay=_overlay,
            quality_threshold=str(config.research.quality_threshold),
            revised=paper_for_eval + _exp_context,
        )
        resp = _chat_with_prompt(
            llm,
            sp.system,
            sp.user,
            json_mode=sp.json_mode,
            max_tokens=sp.max_tokens,
        )
        parsed = _safe_json_loads(resp.content, {})
        if isinstance(parsed, dict):
            report = parsed
        # BUG-25: If experiment failed with no metrics, cap the quality score
        if report is not None and _exp_failed:
            _orig_score = report.get("score_1_to_10", 5)
            if isinstance(_orig_score, (int, float)) and _orig_score > 3:
                report["score_1_to_10"] = min(_orig_score, 3.0)
                report.setdefault("weaknesses", []).append(
                    "Experiment failed with no metrics — any reported numerical "
                    "results are unsupported and likely fabricated."
                )
                logger.warning(
                    "BUG-25: Experiment failed — capping quality score from %.1f to 3.0",
                    _orig_score,
                )
    if report is None:
        report = _default_quality_report(config.research.quality_threshold)
    report.setdefault("generated", _utcnow_iso())
    (stage_dir / "quality_report.json").write_text(
        json.dumps(report, indent=2), encoding="utf-8"
    )
    # T2.1: Enforce quality gate — fail if score below threshold
    score = report.get("score_1_to_10", 0)
    # BUG-R5-01: score can be string from LLM JSON — coerce to float
    if not isinstance(score, (int, float)):
        try:
            score = float(score)
        except (TypeError, ValueError):
            score = 0
    verdict = report.get("verdict", "proceed")
    threshold = config.research.quality_threshold or 5.0
    # --- Fabrication flag: collect real metrics for Stage 22 sanitization ---
    _fabrication_info: dict[str, Any] = {
        "experiment_failed": _exp_failed,
        "quality_score": score,
        "real_metric_values": [],
    }
    if isinstance(_exp_summary, dict):
        # Collect ALL real numeric values from experiment_summary.json
        _cond_summaries = _exp_summary.get("condition_summaries", {})
        if isinstance(_cond_summaries, dict):
            for cond_name, cond_data in _cond_summaries.items():
                if not isinstance(cond_data, dict):
                    continue
                cond_status = cond_data.get("status", "")
                if cond_status == "failed":
                    continue  # skip failed conditions
                for k, v in cond_data.items():
                    # Counters (seed_count etc.) are excluded — only metric-like
                    # numeric values go into the whitelist.
                    if isinstance(v, (int, float)) and k not in (
                        "seed_count", "total_steps", "training_steps",
                    ):
                        _fabrication_info["real_metric_values"].append(
                            round(float(v), 4)
                        )
        _ms = _exp_summary.get("metrics_summary", {})
        if isinstance(_ms, dict):
            for _mk, _mv in _ms.items():
                if isinstance(_mv, dict):
                    for _stat in ("mean", "min", "max"):
                        _sv = _mv.get(_stat)
                        if isinstance(_sv, (int, float)):
                            _fabrication_info["real_metric_values"].append(
                                round(float(_sv), 4)
                            )
    _fabrication_info["has_real_data"] = bool(
        _fabrication_info["real_metric_values"]
    )
    _fabrication_info["fabrication_suspected"] = (
        _exp_failed and not _fabrication_info["has_real_data"]
    )
    # Phase 1: Enhanced fabrication detection via VerifiedRegistry
    # BUG-108: Also pass refinement_log so NaN best_metric is properly handled
    _rl20_candidates = sorted(run_dir.glob("stage-13*/refinement_log.json"), reverse=True)
    _rl20_path = _rl20_candidates[0] if _rl20_candidates else None
    _rl20: dict | None = None
    if _rl20_path and _rl20_path.is_file():
        try:
            _rl20 = json.loads(_rl20_path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError):
            pass
    # NOTE(review): _rl20 is loaded above but never used below — the BUG-108
    # "also pass refinement_log" wiring appears incomplete; confirm intent.
    try:
        from researchclaw.pipeline.verified_registry import VerifiedRegistry as _VR20
        _vr20 = _VR20.from_run_dir(run_dir, metric_direction=config.experiment.metric_direction, best_only=True) if isinstance(_exp_summary, dict) else None
        if _vr20:
            _fabrication_info["verified_values_count"] = len(_vr20.values)
            _fabrication_info["verified_conditions"] = sorted(_vr20.condition_names)
    except Exception:
        pass
    (stage_dir / "fabrication_flags.json").write_text(
        json.dumps(_fabrication_info, indent=2), encoding="utf-8"
    )
    if isinstance(score, (int, float)) and score < threshold:
        if config.research.graceful_degradation:
            logger.warning(
                "Quality gate DEGRADED: score %.1f < threshold %.1f — "
                "continuing with sanitization (graceful_degradation=True)",
                score, threshold,
            )
            # Write degradation signal for downstream stages
            signal = {
                "score": score,
                "threshold": threshold,
                "verdict": verdict,
                "weaknesses": report.get("weaknesses", []),
                "generated": _utcnow_iso(),
            }
            (run_dir / "degradation_signal.json").write_text(
                json.dumps(signal, indent=2), encoding="utf-8"
            )
            return StageResult(
                stage=Stage.QUALITY_GATE,
                status=StageStatus.DONE,
                artifacts=("quality_report.json",),
                evidence_refs=("stage-20/quality_report.json",),
                decision="degraded",
            )
        logger.warning(
            "Quality gate FAILED: score %.1f < threshold %.1f (verdict=%s)",
            score, threshold, verdict,
        )
        return StageResult(
            stage=Stage.QUALITY_GATE,
            status=StageStatus.FAILED,
            artifacts=("quality_report.json", "fabrication_flags.json"),
            evidence_refs=("stage-20/quality_report.json",),
            error=f"Quality score {score:.1f}/10 below threshold {threshold:.1f}. "
            f"Paper needs revision before export.",
        )
    logger.info(
        "Quality gate PASSED: score %.1f >= threshold %.1f",
        score, threshold,
    )
    return StageResult(
        stage=Stage.QUALITY_GATE,
        status=StageStatus.DONE,
        artifacts=("quality_report.json", "fabrication_flags.json"),
        evidence_refs=("stage-20/quality_report.json",),
    )
# ---------------------------------------------------------------------------
# Stage 21: Knowledge Archive
# ---------------------------------------------------------------------------
def _execute_knowledge_archive(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 21: distill lessons learned and index every run artifact.

    Writes ``archive.md`` (LLM-authored when a client is available,
    otherwise a static skeleton) plus ``bundle_index.json`` listing every
    file under the run's ``stage-*`` directories.
    """
    paper_text = _read_prior_artifact(run_dir, "paper_revised.md") or ""
    analysis = _read_best_analysis(run_dir)
    decision = _read_prior_artifact(run_dir, "decision.md") or ""
    preamble = _build_context_preamble(config, run_dir, include_goal=True)

    if llm is None:
        # Offline fallback: static archive skeleton.
        archive = f"""# Knowledge Archive
## Lessons Learned
- Preserve strict metric reporting protocol.
- Keep refinement logs aligned with code changes.
## Reproducibility
- Include exact experiment script and schedule.
- Capture run-level JSON metrics.
## Future Work
- Extend robustness and external validity checks.
Generated: {_utcnow_iso()}
"""
    else:
        manager = prompts or PromptManager()
        overlay = _get_evolution_overlay(run_dir, "knowledge_archive")
        stage_prompt = manager.for_stage(
            "knowledge_archive",
            evolution_overlay=overlay,
            preamble=preamble,
            decision=decision,
            analysis=analysis,
            revised=paper_text[:15000],  # cap prompt context size
        )
        response = _chat_with_prompt(
            llm,
            stage_prompt.system,
            stage_prompt.user,
            json_mode=stage_prompt.json_mode,
            max_tokens=stage_prompt.max_tokens,
        )
        archive = response.content

    (stage_dir / "archive.md").write_text(archive, encoding="utf-8")

    # Index every artifact under stage-*/; exclude the index file itself,
    # which only matters when this stage is re-run.
    index_path = stage_dir / "bundle_index.json"
    files = [
        str(artifact.relative_to(run_dir))
        for stage_subdir in sorted(run_dir.glob("stage-*"))
        for artifact in sorted(stage_subdir.rglob("*"))
        if artifact.is_file() and artifact != index_path
    ]
    index = {
        "run_id": run_dir.name,
        "generated": _utcnow_iso(),
        "artifact_count": len(files),
        "artifacts": files,
    }
    index_path.write_text(json.dumps(index, indent=2), encoding="utf-8")
    return StageResult(
        stage=Stage.KNOWLEDGE_ARCHIVE,
        status=StageStatus.DONE,
        artifacts=("archive.md", "bundle_index.json"),
        evidence_refs=("stage-21/archive.md", "stage-21/bundle_index.json"),
    )
# ---------------------------------------------------------------------------
# _sanitize_fabricated_data helper
# ---------------------------------------------------------------------------
def _sanitize_fabricated_data(
paper: str,
run_dir: Path,
) -> tuple[str, dict[str, Any]]:
"""Replace unverified numerical data in markdown tables with '---'.
Loads experiment_summary.json as ground truth, extracts all verified
metric values, then scans markdown tables in Results/Experiment sections.
Numbers not matching any verified value (within 1% relative tolerance)
are replaced with ``---``.
Returns (sanitized_paper, sanitization_report).
"""
import re as _re_san
# --- 1. Build verified values set from experiment_summary.json ---
# BUG-222: After REFINE cycles, merging ALL stage-14* data creates a
# permissive registry that validates fabricated numbers from regressed
# iterations. Use ONLY the promoted best data as ground truth.
# experiment_summary_best.json is written by _promote_best_stage14() and
# contains the single best iteration's data.
verified_values: set[float] = set()
def _richness(path: Path) -> int:
"""Score an experiment_summary.json by how many conditions it has."""
try:
d = json.loads(path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
return -1
if not isinstance(d, dict):
return -1
conds = d.get("condition_summaries", {})
metrics = d.get("metrics_summary", {})
return len(conds) + len(metrics)
# BUG-222: Prefer experiment_summary_best.json (promoted best iteration).
# Only fall back to "richest stage-14*" scanning if best.json is missing
# (single-iteration runs without REFINE).
_root_best = run_dir / "experiment_summary_best.json"
if _root_best.exists() and _richness(_root_best) > 0:
exp_path = _root_best
else:
_candidates = list(run_dir.glob("stage-14*/experiment_summary.json"))
exp_path = max(_candidates, key=_richness) if _candidates else run_dir / "stage-14" / "experiment_summary.json"
if exp_path.exists():
try:
exp_data = json.loads(exp_path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
exp_data = {}
def _collect_numbers(obj: Any, depth: int = 0) -> None:
if depth > 10:
return
if isinstance(obj, (int, float)) and not isinstance(obj, bool):
import math as _math_vv
if _math_vv.isfinite(float(obj)):
verified_values.add(float(obj))
elif isinstance(obj, dict):
for v in obj.values():
_collect_numbers(v, depth + 1)
elif isinstance(obj, list):
for v in obj:
_collect_numbers(v, depth + 1)
# Extract from well-known keys
for key in (
"metrics_summary", "condition_summaries", "best_run",
"condition_metrics", "conditions", "ablation_results",
):
if key in exp_data:
_collect_numbers(exp_data[key])
# BUG-222: Removed BUG-206 refinement_log scanning. The original BUG-206
# rationale was "Stage 17 injects sandbox metrics, so the sanitizer must
# recognise them". But that created a loophole: after REFINE regression,
# the LLM would cite regressed iteration numbers and the sanitizer would
# pass them because they were in the refinement log. Now that Stage 17
# also uses only the promoted best data (BUG-222), there is no need to
# whitelist all sandbox metrics here.
if not verified_values:
report: dict[str, Any] = {
"sanitized": False,
"reason": "no verified values found in experiment_summary.json",
"tables_processed": 0,
"numbers_replaced": 0,
}
return paper, report
def _is_verified(num: float) -> bool:
"""Check if num matches any verified value within 1% relative tolerance.
BUG-R5-20: Also checks percentage/decimal cross-matching
(e.g., 73.42 in paper vs 0.7342 in experiment, or vice versa).
"""
for v in verified_values:
if v == 0.0:
if abs(num) < 1e-9:
return True
elif abs(num - v) / abs(v) <= 0.01:
return True
# Cross-match: num might be percentage form of v (or vice versa)
elif v != 0.0 and abs(num / 100.0 - v) / abs(v) <= 0.01:
return True
elif v != 0.0 and abs(num - v * 100.0) / abs(v * 100.0) <= 0.01:
return True
return False
# --- 2. Find and sanitize markdown tables ---
# BUG-175: Always-allowed set — common constants, hyperparameters, and
# structural values that should never be sanitized (matches paper_verifier.py).
_SANITIZER_ALWAYS_ALLOWED: set[float] = {
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 50.0, 100.0, 200.0,
0.5, 0.01, 0.001, 0.0001, 0.1, 0.05, 0.95, 0.99,
2024.0, 2025.0, 2026.0, 2027.0,
8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0,
224.0, 299.0, 384.0, # Common image sizes
# BUG-192: Common hyperparameter values
0.0003, 3e-4, 0.0005, 5e-4, 0.002, 2e-3, # learning rates
0.2, 0.3, 0.25, 0.7, 0.6, 0.8, # clip epsilon, dropout, gradient clip, GCE q, common HP
0.9, 0.999, 0.9999, # Adam betas, momentum
0.02, 0.03, # weight init std
1e-5, 1e-6, 1e-8, # epsilon, weight decay
300.0, 400.0, 500.0, # epochs
4096.0, 8192.0, # larger batch sizes / hidden dims
}
# Match markdown table blocks (header + separator + data rows)
table_pat = _re_san.compile(
r"((?:^[ \t]*\|.+\|[ \t]*\n)+" # one or more pipe-delimited lines
r")",
_re_san.MULTILINE,
)
# Match numbers in table cells (integers, decimals, percentages, scientific)
# BUG-175: Also exclude hyphen in lookaround to protect method names like
# "Cos-200", "StepLR-100" from partial number extraction.
# BUG-206: Include Unicode hyphens (U+2010 hyphen, U+2011 non-breaking
# hyphen, U+2013 en-dash) — LLMs frequently emit these instead of ASCII
# hyphens in model names like "ResNet‑34".
# BUG-206: Unicode hyphens placed before escaped ASCII hyphen (\\-)
# to avoid creating unintended character ranges in the class.
_HYPH = "\u2010\u2011\u2013\\-" # U+2010 + U+2011 + U+2013 + ASCII hyphen
num_pat = _re_san.compile(
f"(? str:
nonlocal numbers_replaced, numbers_kept
num_str = m.group(1)
pct = m.group(2)
try:
val = float(num_str)
except ValueError:
return m.group(0)
# BUG-175: Always allow common constants / hyperparameters
if val in _SANITIZER_ALWAYS_ALLOWED:
numbers_kept += 1
return m.group(0)
# BUG-175: Small integer exemption — counts, indices,
# epoch numbers, etc. (≤ 20 auto-pass)
if val == int(val) and abs(val) <= 20:
numbers_kept += 1
return m.group(0)
if _is_verified(val):
numbers_kept += 1
return m.group(0)
numbers_replaced += 1
replaced_values.append(num_str + pct)
return "---"
def _sanitize_table(match: _re_san.Match[str]) -> str:
    """Sanitize one markdown pipe-table by running ``_replace_num`` on data cells.

    Leaves untouched: pseudo-tables without a separator row, pure
    hyperparameter/config tables, statistical-analysis tables, the header
    and separator rows, the leading label column, and any column whose
    header classifies it as a hyperparameter column.
    """
    nonlocal numbers_replaced, numbers_kept, tables_processed
    table_text = match.group(0)
    lines = table_text.split("\n")
    # Check if this looks like a results/experiment table
    # (heuristic: has a separator row with dashes)
    has_separator = any(
        _re_san.match(r"^[ \t]*\|[\s:|-]+\|[ \t]*$", line)
        for line in lines
    )
    if not has_separator:
        return table_text
    # BUG-192: Detect hyperparameter/config tables and SKIP sanitization.
    # These tables contain design choices, not experimental results.
    _HP_TABLE_KW = {
        "hyperparameter", "hyper-parameter", "configuration", "config",
        "setting", "parameter", "learning rate", "lr", "batch size",
        "optimizer", "architecture", "schedule", "warmup", "decay",
        "dropout", "weight decay", "momentum", "epsilon", "clip",
    }
    # BUG-224: Statistical analysis tables contain derived values
    # (t-statistics, p-values, effect sizes) that are computed from
    # the experiment data but never appear in experiment_summary.json.
    # These tables should NOT be sanitized.
    _STAT_TABLE_KW = {
        "t-statistic", "t-stat", "t statistic", "p-value", "p value",
        "paired", "cohen", "effect size", "wilcoxon", "mann-whitney",
        "statistical", "significance", "confidence interval",
    }
    _RESULT_TABLE_KW = {
        "accuracy", "acc", "loss", "f1", "auroc", "auc", "precision",
        "recall", "bleu", "rouge", "reward", "return", "rmse", "mae",
        "mse", "error", "score", "metric", "performance", "improvement",
        "top-1", "top1", "top-5", "top5",
    }
    # Table classification is keyed off the header row only.
    _header_lower = lines[0].lower() if lines else ""
    _is_hp_table = any(kw in _header_lower for kw in _HP_TABLE_KW)
    _is_result_table = any(kw in _header_lower for kw in _RESULT_TABLE_KW)
    # BUG-224: Statistical analysis tables (t-tests, p-values) contain
    # derived values that are never in experiment_summary.json.
    _is_stat_table = any(kw in _header_lower for kw in _STAT_TABLE_KW)
    if _is_hp_table and not _is_result_table:
        return table_text  # Skip sanitization for HP/config tables
    if _is_stat_table:
        return table_text  # Skip sanitization for statistical test tables
    # BUG-184: Per-column HP detection — classify each column header
    # as HP-type (skip sanitization) or result-type (sanitize).
    # This handles mixed tables like "| Method | LR | Acc | F1 |"
    # where LR should be preserved but Acc/F1 are verified.
    _HP_COL_KW = {
        "lr", "learning rate", "batch", "epoch", "optimizer",
        "schedule", "warmup", "decay", "dropout", "momentum",
        "clip", "epsilon", "eps", "beta", "alpha", "gamma",
        "lambda", "weight decay", "wd", "temperature", "temp",
        "hidden", "dim", "layers", "heads", "steps", "iterations",
        "seed", "patience", "#param", "params", "size", "depth",
        "width", "channels", "kernel", "stride", "padding",
        # BUG-224: Statistical test columns (derived, not in experiment data)
        "t-stat", "t stat", "p-value", "p value", "p-val",
        "cohen", "effect", "ci lower", "ci upper", "difference",
    }
    _hp_cols: set[int] = set()  # column indices that are HP columns
    if lines:
        # Header is split with the same split("|") used on data rows below,
        # so the indices in _hp_cols line up cell-for-cell.
        _hdr_cells = lines[0].split("|")
        for _ci, _hc in enumerate(_hdr_cells):
            _hc_low = _hc.strip().lower()
            if any(kw in _hc_low for kw in _HP_COL_KW):
                _hp_cols.add(_ci)
    tables_processed += 1
    sanitized_lines: list[str] = []
    for i, line in enumerate(lines):
        # Skip header row and separator row
        is_separator = bool(
            _re_san.match(r"^[ \t]*\|[\s:|-]+\|[ \t]*$", line)
        )
        is_header = i == 0  # first line is typically the header
        if is_separator or is_header:
            sanitized_lines.append(line)
            continue
        # BUG-175: Split by pipe and only sanitize cells after
        # the first data column (which typically contains method
        # names, condition labels, etc.)
        cells = line.split("|")
        sanitized_cells: list[str] = []
        for ci, cell in enumerate(cells):
            # Skip first non-empty cell (method/label column),
            # empty edge cells, and BUG-184 HP-classified columns.
            # ci == 0 is the text before the leading pipe, ci == 1 is
            # the label column.
            if ci <= 1 or not cell.strip() or ci in _hp_cols:
                sanitized_cells.append(cell)
            else:
                sanitized_cells.append(
                    num_pat.sub(_replace_num, cell)
                )
        sanitized_lines.append("|".join(sanitized_cells))
    return "\n".join(sanitized_lines)
sanitized = table_pat.sub(_sanitize_table, paper)
# --- BUG-211: LaTeX tabular sanitization ---
# LLMs sometimes write results in LaTeX \begin{tabular} format inside
# the markdown paper (often within ```latex fences). The markdown
# table regex above misses these entirely, allowing fabricated numbers
# to pass through unchecked.
latex_tab_pat = _re_san.compile(
r"(\\begin\{tabular\}.*?\\end\{tabular\})",
_re_san.DOTALL,
)
# Keywords for HP-table vs result-table classification (reuse from above)
_LTX_HP_KW = {
"hyperparameter", "hyper-parameter", "configuration", "config",
"setting", "learning rate", "lr", "batch size", "optimizer",
}
_LTX_RESULT_KW = {
"accuracy", "acc", "loss", "f1", "auroc", "auc", "precision",
"recall", "reward", "score", "metric", "performance", "result",
}
# BUG-224: Statistical analysis LaTeX tables — derived values
_LTX_STAT_KW = {
"t-statistic", "t-stat", "t statistic", "p-value", "p value",
"paired", "cohen", "effect size", "statistical", "significance",
}
def _sanitize_latex_table(match: _re_san.Match[str]) -> str:
    r"""Sanitize one LaTeX ``tabular`` block (BUG-211).

    Classifies the table from ±300 chars of surrounding text (column
    spec, header row, and a possible preceding \caption), skips
    HP/config and statistical-analysis tables, then runs
    ``_replace_num`` over data cells of every row that appears after the
    first \midrule/\hline.
    """
    nonlocal tables_processed
    block = match.group(0)
    # Heuristic: look at the first ~300 chars (column spec + header row)
    # to decide HP vs result table. Also check preceding \caption if
    # the match is part of a \begin{table} environment — we can look
    # backwards a bit in the full text for the caption.
    _start = match.start()
    _context = sanitized[max(0, _start - 300):_start + 300].lower()
    _is_hp = any(kw in _context for kw in _LTX_HP_KW)
    _is_res = any(kw in _context for kw in _LTX_RESULT_KW)
    # BUG-224: Statistical test tables — derived values not in experiment data
    _is_stat = any(kw in _context for kw in _LTX_STAT_KW)
    if _is_hp and not _is_res:
        return block  # HP/config table — skip
    if _is_stat:
        return block  # Statistical analysis table — skip
    tables_processed += 1
    # Split into rows by \\ (LaTeX row separator).
    # We split on \\ but keep the delimiter so we can reconstruct.
    parts = _re_san.split(r"(\\\\)", block)
    result_parts: list[str] = []
    _seen_midrule = False  # True once past the header/body rule line
    for part in parts:
        # Preserve row separators as-is
        if part == "\\\\":
            result_parts.append(part)
            continue
        _stripped = part.strip()
        # Rule lines — no numbers to sanitize
        if _re_san.search(
            r"\\(hline|toprule|midrule|bottomrule|cline|cmidrule)",
            _stripped,
        ):
            if "midrule" in _stripped or "hline" in _stripped:
                _seen_midrule = True
            result_parts.append(part)
            continue
        # Column spec line (contains \begin{tabular}{...})
        if r"\begin{tabular}" in part:
            result_parts.append(part)
            continue
        # End line
        if r"\end{tabular}" in part:
            result_parts.append(part)
            continue
        # Header row: rows before the first \midrule/\hline
        if not _seen_midrule:
            result_parts.append(part)
            continue
        # Data row — split by & and sanitize cells after the first
        cells = part.split("&")
        sanitized_cells: list[str] = []
        for ci, cell in enumerate(cells):
            if ci == 0:
                # First cell is method/condition name — preserve
                sanitized_cells.append(cell)
            else:
                sanitized_cells.append(num_pat.sub(_replace_num, cell))
        result_parts.append("&".join(sanitized_cells))
    return "".join(result_parts)
sanitized = latex_tab_pat.sub(_sanitize_latex_table, sanitized)
# --- Improvement F: Prose-level anti-fabrication ---
# Scan Results/Experiments sections for inline numeric claims like
# "achieved 94.2% accuracy" or "obtained an AUROC of 0.87".
# Replace unverified numbers with "[value removed]".
prose_numbers_replaced = 0
_prose_pattern = _re_san.compile(
r"(?:achiev|obtain|reach|attain|yield|report|record|produc|demonstrat|show|observ)"
r"(?:ed|es|ing|s)?\s+"
r"(?:an?\s+)?(?:\w+\s+)?(?:of\s+)?"
r"(\d+\.?\d*)\s*"
r"(%|\\%)?",
_re_san.IGNORECASE,
)
# Only process lines in Results/Experiments sections
_in_results_section = False
_results_headers = _re_san.compile(
r"^#{1,3}\s*(Results|Experiments|Experimental|Evaluation|Ablation)",
_re_san.IGNORECASE,
)
_any_header = _re_san.compile(r"^#{1,3}\s+")
_sanitized_lines = []
for _line in sanitized.split("\n"):
if _results_headers.match(_line):
_in_results_section = True
elif _any_header.match(_line) and _in_results_section:
# Check if we're leaving Results for a different top-level section
_header_text = _line.lstrip("#").strip().lower()
if _header_text and not any(kw in _header_text for kw in
("result", "experiment", "ablation", "evaluation", "comparison")):
_in_results_section = False
if _in_results_section and "|" not in _line: # skip table rows
def _replace_prose_num(m: _re_san.Match[str]) -> str:
nonlocal prose_numbers_replaced
num_str = m.group(1)
try:
val = float(num_str)
except ValueError:
return m.group(0)
# Skip common constants / small integers
if val in _SANITIZER_ALWAYS_ALLOWED:
return m.group(0)
if val == int(val) and abs(val) <= 20:
return m.group(0)
if _is_verified(val):
return m.group(0)
prose_numbers_replaced += 1
return m.group(0).replace(num_str + (m.group(2) or ""), "[value removed]")
_line = _prose_pattern.sub(_replace_prose_num, _line)
_sanitized_lines.append(_line)
sanitized = "\n".join(_sanitized_lines)
report = {
"sanitized": numbers_replaced > 0 or prose_numbers_replaced > 0,
"tables_processed": tables_processed,
"numbers_replaced": numbers_replaced,
"numbers_kept": numbers_kept,
"prose_numbers_replaced": prose_numbers_replaced,
"verified_values_count": len(verified_values),
"replaced_samples": replaced_values[:20],
"generated": _utcnow_iso(),
}
return sanitized, report
# ---------------------------------------------------------------------------
# BUG-176: Missing citation resolution
# BUG-194: Validate search results to avoid replacing correct entries with
# garbage. Previous code searched by cite-key fragments (e.g.
# "he 2016 deep") which returned completely unrelated papers.
# Fix: (1) consult seminal_papers.yaml first, (2) require title-
# similarity validation for API results, (3) build better queries.
# ---------------------------------------------------------------------------
# Minimum title-similarity between search result and expected title/query
# for a result to be accepted. Prevents "Jokowi and the New Developmentalism"
# from replacing "Deep Residual Learning for Image Recognition".
_CITATION_RESOLVE_MIN_SIMILARITY = 0.30
def _load_seminal_papers_by_key() -> dict[str, dict]:
"""Load seminal_papers.yaml and index by cite_key.
Returns dict like::
{"he2016deep": {"title": "Deep Residual Learning...", "authors": "He et al.", ...}, ...}
Returns empty dict on any failure (missing file, bad YAML, etc.).
"""
try:
from researchclaw.data import _load_all as _load_seminal_all
all_papers = _load_seminal_all()
return {p["cite_key"]: p for p in all_papers if "cite_key" in p}
except Exception: # noqa: BLE001
return {}
def _seminal_to_bibtex(paper: dict, cite_key: str) -> str:
"""Convert a seminal_papers.yaml entry dict to a BibTeX string."""
title = paper.get("title", "Unknown")
authors = paper.get("authors", "Unknown")
year = paper.get("year", "")
venue = paper.get("venue", "")
# Decide entry type
venue_lower = (venue or "").lower()
is_conf = any(kw in venue_lower for kw in (
"neurips", "nips", "icml", "iclr", "cvpr", "eccv", "iccv",
"aaai", "acl", "emnlp", "naacl", "sigir", "kdd", "www",
"ijcai", "conference", "proc", "workshop",
))
if is_conf:
return (
f"@inproceedings{{{cite_key},\n"
f" title = {{{title}}},\n"
f" author = {{{authors}}},\n"
f" year = {{{year}}},\n"
f" booktitle = {{{venue}}},\n"
f"}}"
)
return (
f"@article{{{cite_key},\n"
f" title = {{{title}}},\n"
f" author = {{{authors}}},\n"
f" year = {{{year}}},\n"
f" journal = {{{venue}}},\n"
f"}}"
)
def _resolve_missing_citations(
missing_keys: set[str],
existing_bib: str,
) -> tuple[set[str], list[str]]:
"""Try to find BibTeX entries for citation keys not in references.bib.
Parses each cite_key (e.g. ``hendrycks2017baseline``) into an author name
and year, then searches academic APIs. Returns ``(resolved_keys,
new_bib_entries)`` where each entry is a complete BibTeX string.
BUG-194 fix: Three-layer resolution strategy:
1. **Seminal lookup** — check seminal_papers.yaml (zero API calls, exact match)
2. **API search with validation** — search Semantic Scholar / arXiv, but ONLY
accept results whose title has ≥ 30% word overlap with query terms.
Previously any year-matching result was blindly accepted, causing
foundational papers to be replaced with garbage.
3. **Skip** — if no confident match, leave the citation unresolved rather
than inject a wrong paper.
Gracefully returns empty results on any network failure.
"""
import re as _re176
import time as _time176
resolved: set[str] = set()
new_entries: list[str] = []
def _parse_cite_key(key: str) -> tuple[str, str, str]:
"""Extract (author, year, keyword_hint) from a citation key.
Common patterns:
``he2016deep`` → ("he", "2016", "deep")
``vaswani2017attention`` → ("vaswani", "2017", "attention")
``goodfellow2014generative`` → ("goodfellow", "2014", "generative")
"""
m = _re176.match(r"([a-zA-Z]+?)(\d{4})(.*)", key)
if m:
return m.group(1), m.group(2), m.group(3)
return key, "", ""
def _title_word_overlap(title: str, query_words: list[str]) -> float:
"""Word-overlap score between a paper title and query keywords.
Returns fraction of query words found in the title (0.0–1.0).
Used to validate that a search result is actually relevant.
"""
if not query_words:
return 0.0
title_lower = set(
_re176.sub(r"[^a-z0-9\s]", "", title.lower()).split()
) - {""}
if not title_lower:
return 0.0
matched = sum(1 for w in query_words if w.lower() in title_lower)
return matched / len(query_words)
# --- Layer 1: Seminal papers lookup (no API calls) ---
seminal_by_key = _load_seminal_papers_by_key()
for key in sorted(missing_keys):
if key in seminal_by_key and key not in existing_bib:
sp = seminal_by_key[key]
bib_entry = _seminal_to_bibtex(sp, key)
new_entries.append(bib_entry)
resolved.add(key)
logger.info(
"BUG-194: Resolved %r via seminal_papers.yaml → %r (%s)",
key, sp.get("title", "")[:60], sp.get("year", ""),
)
# Remaining keys that weren't in the seminal database AND aren't already
# present in the existing bib (no point re-resolving keys we already have).
remaining = sorted(
k for k in (missing_keys - resolved) if k not in existing_bib
)
if not remaining:
return resolved, new_entries
# --- Layer 2: API search with title-similarity validation ---
try:
from researchclaw.literature.search import search_papers
except ImportError:
logger.debug("BUG-176: literature.search not available, skipping resolution")
return resolved, new_entries
for key in remaining:
author, year, hint = _parse_cite_key(key)
if not author or not year:
continue
# BUG-194: Build a better search query.
# Instead of "he 2016 deep", use "he deep residual learning 2016" or
# at minimum, split camelCase hints into separate words.
# Split hint on word boundaries (camelCase or underscore).
hint_words = _re176.findall(r"[a-zA-Z]+", hint) if hint else []
# The query words used for validation
query_words = [author] + hint_words
# Build search query: author + hint words + year (year helps but isn't
# the primary discriminator anymore)
query_parts = [author] + hint_words + [year]
query = " ".join(query_parts)
try:
results = search_papers(query, limit=5, deduplicate=True)
except Exception as exc:
logger.debug("BUG-176: Search failed for %r: %s", key, exc)
continue
if not results:
logger.debug(
"BUG-194: No search results for %r (query=%r), skipping",
key, query,
)
continue
# BUG-194: Find best match by title-word-overlap AND year match.
# Previously the code just took the first year-matching result.
best = None
best_score = -1.0
for paper in results:
overlap = _title_word_overlap(paper.title, query_words)
year_bonus = 0.2 if str(paper.year) == year else 0.0
# Also give bonus for author name appearing in paper.authors
author_bonus = 0.0
if any(author.lower() in a.name.lower() for a in paper.authors):
author_bonus = 0.2
score = overlap + year_bonus + author_bonus
if score > best_score:
best_score = score
best = paper
if best is None:
continue
# BUG-194: Validate the result — require minimum similarity.
# This is the KEY fix: previously ANY result was accepted blindly.
overlap = _title_word_overlap(best.title, query_words)
if overlap < _CITATION_RESOLVE_MIN_SIMILARITY:
logger.info(
"BUG-194: Rejecting search result for %r — title %r has "
"too-low overlap (%.2f < %.2f) with query words %r",
key, best.title[:60], overlap,
_CITATION_RESOLVE_MIN_SIMILARITY, query_words,
)
continue
# Year must also match (or be within 1 year — sometimes conferences
# vs arXiv preprint have different years)
if year and best.year:
year_diff = abs(int(year) - int(best.year))
if year_diff > 1:
logger.info(
"BUG-194: Rejecting search result for %r — year mismatch "
"(%s vs %s, diff=%d)",
key, year, best.year, year_diff,
)
continue
# Generate BibTeX with the ORIGINAL cite_key (so \cite{key} works)
bib_entry = best.to_bibtex()
# Replace the auto-generated cite_key with the one used in the paper
orig_key_match = _re176.match(r"@(\w+)\{([^,]+),", bib_entry)
if orig_key_match:
bib_entry = bib_entry.replace(
f"@{orig_key_match.group(1)}{{{orig_key_match.group(2)},",
f"@{orig_key_match.group(1)}{{{key},",
1,
)
# Verify entry doesn't duplicate an existing key
if key not in existing_bib:
new_entries.append(bib_entry)
resolved.add(key)
logger.info(
"BUG-194: Resolved %r via API → %r (%s, overlap=%.2f)",
key, best.title[:60], best.year, overlap,
)
else:
logger.debug(
"BUG-194: Key %r already in bib, skipping API result", key,
)
# Rate limit: 0.5s between API calls
_time176.sleep(0.5)
return resolved, new_entries
# ---------------------------------------------------------------------------
# Stage 22: Export & Publish
# ---------------------------------------------------------------------------
def _execute_export_publish(
stage_dir: Path,
run_dir: Path,
config: RCConfig,
adapters: AdapterBundle,
*,
llm: LLMClient | None = None,
prompts: PromptManager | None = None,
) -> StageResult:
revised = _read_prior_artifact(run_dir, "paper_revised.md") or ""
if llm is not None:
_pm = prompts or PromptManager()
_overlay = _get_evolution_overlay(run_dir, "export_publish")
sp = _pm.for_stage("export_publish", evolution_overlay=_overlay, revised=revised)
resp = _chat_with_prompt(
llm,
sp.system,
sp.user,
json_mode=sp.json_mode,
max_tokens=sp.max_tokens,
)
final_paper = resp.content
# Content guard: reject LLM output that truncates the paper
if revised and len(final_paper) < 0.6 * len(revised):
logger.warning(
"Stage 22: LLM output is %.0f%% of input length — using original",
100 * len(final_paper) / max(len(revised), 1),
)
final_paper = revised
else:
final_paper = revised
if not final_paper.strip():
final_paper = "# Final Paper\n\nNo content generated."
# --- Always-on fabrication sanitization (Phase 1 anti-fabrication) ---
# Back up pre-sanitized version
(stage_dir / "paper_presanitized.md").write_text(
final_paper, encoding="utf-8"
)
# Sanitize unverified data in tables — always-on, not just degraded mode
final_paper, _san_report = _sanitize_fabricated_data(
final_paper, run_dir
)
(stage_dir / "sanitization_report.json").write_text(
json.dumps(_san_report, indent=2), encoding="utf-8"
)
if _san_report.get("numbers_replaced", 0) > 0:
logger.info(
"Stage 22: Fabrication sanitization — %d numbers replaced, %d kept",
_san_report.get("numbers_replaced", 0),
_san_report.get("numbers_kept", 0),
)
# Graceful degradation: insert notice only when quality gate was degraded
_degradation_signal_path = run_dir / "degradation_signal.json"
if _degradation_signal_path.exists():
try:
_deg_signal = json.loads(
_degradation_signal_path.read_text(encoding="utf-8")
)
except (json.JSONDecodeError, OSError):
_deg_signal = {}
# Insert degradation notice after abstract
_deg_score = _deg_signal.get("score", "N/A")
_deg_threshold = _deg_signal.get("threshold", "N/A")
_deg_notice = (
"\n\n> **Note:** This paper was produced in degraded mode. "
f"Quality gate score ({_deg_score}/{_deg_threshold}) was below "
"threshold. Unverified numerical results in tables have been "
"replaced with `---` and require independent verification.\n\n"
)
# Try to insert after ## Abstract section
_abstract_markers = ["## Abstract\n", "# Abstract\n"]
_notice_inserted = False
for _marker in _abstract_markers:
if _marker in final_paper:
_marker_end = final_paper.index(_marker) + len(_marker)
# Find the end of the abstract paragraph
_next_section = final_paper.find("\n## ", _marker_end)
_next_heading = final_paper.find("\n# ", _marker_end)
_insert_pos = min(
p for p in (_next_section, _next_heading)
if p > 0
) if any(p > 0 for p in (_next_section, _next_heading)) else len(final_paper)
final_paper = (
final_paper[:_insert_pos]
+ _deg_notice
+ final_paper[_insert_pos:]
)
_notice_inserted = True
break
if not _notice_inserted:
# Fallback: prepend to paper
final_paper = _deg_notice + final_paper
logger.info(
"Stage 22: Applied degraded-mode notice (score=%s, threshold=%s)",
_deg_score, _deg_threshold,
)
# IMP-3: Deduplicate "due to computational constraints" — keep at most 1
import re as _re_imp3
_CONSTRAINT_PAT = _re_imp3.compile(
r"[Dd]ue to computational constraints", _re_imp3.IGNORECASE
)
_matches = list(_CONSTRAINT_PAT.finditer(final_paper))
if len(_matches) > 1:
# Keep only the first occurrence; remove subsequent ones by
# deleting the enclosing sentence.
for m in reversed(_matches[1:]):
# Find sentence boundaries around the match
start = final_paper.rfind(".", 0, m.start())
start = start + 1 if start >= 0 else m.start()
end = final_paper.find(".", m.end())
end = end + 1 if end >= 0 else m.end()
sentence = final_paper[start:end].strip()
if sentence:
final_paper = final_paper[:start] + final_paper[end:]
final_paper = re.sub(r"[^\S\n]{2,}", " ", final_paper)
logger.info(
"Stage 22: Removed %d duplicate 'computational constraints' "
"disclaimers",
len(_matches) - 1,
)
# IMP-19 Layer 2: Ensure at least figures are referenced in the paper
import re as _re_fig
chart_files = []
# BUG-215: Also search stage-14* versioned dirs (stage-14_v1, etc.)
# in case stage-14/ was renamed and never recreated.
_chart_search_dirs = [stage_dir / "charts", run_dir / "stage-14" / "charts"]
for _s14_charts in sorted(run_dir.glob("stage-14*/charts"), reverse=True):
if _s14_charts not in _chart_search_dirs:
_chart_search_dirs.append(_s14_charts)
for _chart_src_dir in _chart_search_dirs:
if _chart_src_dir.is_dir():
chart_files.extend(sorted(_chart_src_dir.glob("*.png")))
# BUG-190: Also inject charts not already referenced in the paper.
# The old condition only fired when NO figures were present. Now we
# filter to only unreferenced charts, so partially-illustrated papers
# also get the remaining charts injected.
_already_referenced = set()
for _cf in chart_files:
if _cf.name in final_paper:
_already_referenced.add(_cf.name)
chart_files = [cf for cf in chart_files if cf.name not in _already_referenced]
if chart_files:
# Distribute figures to relevant sections based on filename keywords
_fig_placement: dict[str, list[str]] = {
"method": [], # architecture, method, model, pipeline diagrams
"result": [], # experiment, comparison, ablation charts
"intro": [], # concept, overview, illustration
}
_fig_counter = len(_already_referenced) # start numbering after existing figs
for cf in chart_files[:6]:
_fig_counter += 1
stem_lower = cf.stem.lower()
label = cf.stem.replace("_", " ").title()
fig_md = f""
if any(k in stem_lower for k in ("architecture", "model", "pipeline", "method", "flowchart")):
_fig_placement["method"].append(fig_md)
elif any(k in stem_lower for k in ("experiment", "comparison", "ablation", "result", "metric")):
_fig_placement["result"].append(fig_md)
elif any(k in stem_lower for k in ("concept", "overview", "illustration", "threat", "attack")):
_fig_placement["intro"].append(fig_md)
else:
_fig_placement["result"].append(fig_md) # default to results
# Insert figures at relevant section boundaries.
# BUG-200: Match both H1 (#) and H2 (##) headings — LLMs generate
# either level depending on the writing_structure prompt.
_section_markers = {
"method": ["# Method", "## Method", "# Methodology", "## Methodology",
"# Approach", "## Approach", "# Framework", "## Framework",
"## 3. Method", "## 3 Method"],
"result": ["# Results", "## Results", "# Experiments", "## Experiments",
"# Evaluation", "## Evaluation",
"## 5. Results", "## 4. Experiments", "## 5 Results"],
"intro": ["# Related Work", "## Related Work", "# Background",
"## Background", "## 2. Related", "## 2 Related Work"],
}
_total_inserted = 0
for category, figs in _fig_placement.items():
if not figs:
continue
fig_block = "\n\n" + "\n\n".join(figs) + "\n\n"
inserted = False
for marker in _section_markers.get(category, []):
if marker in final_paper:
# Insert BEFORE the marker section (so figure appears at end of previous section)
final_paper = final_paper.replace(marker, fig_block + marker, 1)
inserted = True
_total_inserted += len(figs)
break
if not inserted:
# Fallback: insert before Conclusion/Limitations/Discussion
for fallback in ["# Conclusion", "## Conclusion",
"# Limitations", "## Limitations",
"# Discussion", "## Discussion"]:
if fallback in final_paper:
final_paper = final_paper.replace(fallback, fig_block + fallback, 1)
inserted = True
_total_inserted += len(figs)
break
if not inserted:
# BUG-200: Last resort — insert before closing fence marker
# rather than appending after it (which puts content outside
# the markdown fence and gets dropped by converter).
_fence_end = final_paper.rfind("\n```")
if _fence_end > 0:
final_paper = (
final_paper[:_fence_end] + fig_block + final_paper[_fence_end:]
)
else:
final_paper += fig_block
_total_inserted += len(figs)
logger.info(
"IMP-19: Injected %d figure references into paper_final.md (distributed across sections)",
_total_inserted,
)
# IMP-24: Detect excessive number repetition
_numbers_found = _re_fig.findall(r"\b\d+\.\d{2,}\b", final_paper)
_num_counts = Counter(_numbers_found)
_repeated = {n: c for n, c in _num_counts.items() if c > 3}
if _repeated:
logger.warning(
"IMP-24: Numbers repeated >3 times: %s",
_repeated,
)
(stage_dir / "paper_final.md").write_text(final_paper, encoding="utf-8")
# --- Legacy fabrication sanitization (disabled — superseded by Phase 1 _sanitize_fabricated_data above) ---
# Kept but guarded: Phase 1 always-on sanitization handles this now.
# Only run if Phase 1 was somehow skipped (should never happen).
_fab_flags_text = _read_prior_artifact(run_dir, "fabrication_flags.json") or ""
_fab_flags = _safe_json_loads(_fab_flags_text, {}) if _fab_flags_text else {}
if (
isinstance(_fab_flags, dict)
and _fab_flags.get("fabrication_suspected")
and _san_report.get("numbers_replaced", 0) == 0 # Phase 1 didn't run/replace
):
import re as _re_fab
_real_vals = set()
for rv in _fab_flags.get("real_metric_values", []):
if isinstance(rv, (int, float)) and math.isfinite(rv):
_real_vals.add(str(round(rv, 4)))
_real_vals.add(str(round(rv, 2)))
_real_vals.add(str(round(rv, 1)))
if rv == int(rv):
_real_vals.add(str(int(rv)))
def _sanitize_number(m: _re_fab.Match) -> str: # type: ignore[name-defined]
"""Replace fabricated numbers with '--' but keep real ones."""
num_str = m.group(0)
# Keep the number if it matches any known real metric value
try:
num_val = float(num_str)
if not math.isfinite(num_val):
return "--"
rounded_strs = {
str(round(num_val, 4)),
str(round(num_val, 2)),
str(round(num_val, 1)),
*(
[str(int(num_val))] if num_val == int(num_val) else []
),
}
if rounded_strs & _real_vals:
return num_str # real value — keep it
except (ValueError, OverflowError):
return num_str
return "--"
# Only sanitize numbers in Results/Experiments/Evaluation/Ablation sections
_result_section_pat = _re_fab.compile(
r"(##\s*(?:\d+\.?\s*)?(?:Results|Experiments|Evaluation|Ablation"
r"|Experimental Results|Quantitative).*?)(?=\n##\s|\Z)",
_re_fab.DOTALL | _re_fab.IGNORECASE,
)
_sanitized_count = 0
def _sanitize_section(sec_match: _re_fab.Match) -> str: # type: ignore[name-defined]
nonlocal _sanitized_count
section_text = sec_match.group(0)
# Replace decimal numbers (e.g., 73.42, 0.891) but NOT integers
# that are likely structural (year, section number, figure number)
def _replace_in_section(m: _re_fab.Match) -> str: # type: ignore[name-defined]
nonlocal _sanitized_count
result = _sanitize_number(m)
if result == "--":
_sanitized_count += 1
return result
return _re_fab.sub(
r"\b\d+\.\d{1,6}\b", _replace_in_section, section_text
)
final_paper = _result_section_pat.sub(_sanitize_section, final_paper)
if _sanitized_count > 0:
logger.warning(
"Stage 22: Fabrication sanitization — blanked %d unsupported "
"numbers in Results sections (experiment had no real metrics)",
_sanitized_count,
)
# Rewrite the sanitized paper
(stage_dir / "paper_final.md").write_text(
final_paper, encoding="utf-8"
)
# Initialize artifacts list
artifacts = ["paper_final.md"]
# F2.7: Post-process citations — [cite_key] → \cite{cite_key}
# and copy final references.bib to export stage
_ay_map: dict[str, str] = {} # BUG-102: author-year → cite_key map
bib_text = _read_prior_artifact(run_dir, "references.bib")
if bib_text:
# Replace [cite_key] patterns in the final paper with \cite{cite_key}
# Collect all valid cite_keys from the bib file
import re as _re
valid_keys = set(_re.findall(r"@\w+\{([^,]+),", bib_text))
# BUG-102: Recover author-year citations → [cite_key] format.
# When Stage 19 (paper_revision) converts [cite_key] to [Author et al., 2024],
# the downstream regex can't match them. Build a reverse map from bib entries.
def _build_author_year_map(bib: str, keys: set[str]) -> dict[str, str]:
"""Build mapping from author-year patterns to cite_keys.
Returns dict like:
"Raissi et al., 2019" → "raissi2019physicsinformed"
"Tavella and Randall, 2000" → "tavella2000pricing"
"""
mapping: dict[str, str] = {}
# Parse each bib entry for author + year
# BUG-DA8-17: Allow newline OR whitespace before closing brace
# Use \n} or just } at start-of-line to avoid greedy cross-entry match
entry_pat = _re.compile(
r"@\w+\{([^,]+),\s*(.*?)(?:\n\}|^[ \t]*\})", _re.DOTALL | _re.MULTILINE
)
for m in entry_pat.finditer(bib):
key = m.group(1).strip()
if key not in keys:
continue
body = m.group(2)
# Extract author field
author_m = _re.search(
r"author\s*=\s*[\{\"](.*?)[\}\"]", body, _re.IGNORECASE
)
year_m = _re.search(
r"year\s*=\s*[\{\"]?(\d{4})[\}\"]?", body, _re.IGNORECASE
)
if not author_m or not year_m:
continue
author_raw = author_m.group(1).strip()
year = year_m.group(1)
# Parse author names (split on " and ")
authors = [a.strip() for a in _re.split(r"\s+and\s+", author_raw)]
# Extract last names
last_names = []
for a in authors:
if "," in a:
last_names.append(a.split(",")[0].strip())
else:
parts = a.split()
last_names.append(parts[-1] if parts else a)
if not last_names:
continue
# Generate author-year patterns:
# 1 author: "Smith, 2024"
# 2 authors: "Smith and Jones, 2024"
# 3+ authors: "Smith et al., 2024"
if len(last_names) == 1:
patterns = [f"{last_names[0]}, {year}"]
elif len(last_names) == 2:
patterns = [
f"{last_names[0]} and {last_names[1]}, {year}",
f"{last_names[0]} \\& {last_names[1]}, {year}",
]
else:
patterns = [
f"{last_names[0]} et al., {year}",
f"{last_names[0]} et al. {year}",
]
# Also add "Smith and Jones, 2024" for first two authors
patterns.append(
f"{last_names[0]} and {last_names[1]}, {year}"
)
for pat in patterns:
mapping[pat] = key
return mapping
_ay_map = _build_author_year_map(bib_text, valid_keys)
if _ay_map:
# Count how many author-year citations exist in the paper
_ay_found = 0
for _ay_pat in _ay_map:
if _ay_pat in final_paper:
_ay_found += 1
if _ay_found > 0:
logger.info(
"Stage 22: Found %d author-year citation patterns — "
"converting back to [cite_key] format.",
_ay_found,
)
# Sort by longest pattern first to avoid partial matches
for _ay_pat in sorted(_ay_map, key=len, reverse=True):
_ay_key = _ay_map[_ay_pat]
# Match [Author et al., 2024] or [Author and Jones, 2024; ...]
# Handle single-citation brackets
final_paper = final_paper.replace(
f"[{_ay_pat}]", f"[{_ay_key}]"
)
# Handle within multi-citation brackets [A et al., 2020; B et al., 2021]
# Replace the author-year segment only inside [...] brackets
final_paper = _re.sub(
r'\[([^\]]*?)' + _re.escape(_ay_pat) + r'([^\]]*?)\]',
lambda _m: '[' + _m.group(1) + _ay_key + _m.group(2) + ']',
final_paper,
)
# Fix multi-key brackets: [key1; key2] → [key1, key2]
# (author-year uses semicolons, cite-keys use commas)
def _fix_semicolon_cites(m_sc: _re.Match[str]) -> str:
inner = m_sc.group(1)
# Only convert if ALL segments look like cite keys
parts = [p.strip() for p in inner.split(";")]
_ck = r"[a-zA-Z][a-zA-Z0-9_-]*\d{4}[a-zA-Z0-9_]*"
if all(_re.fullmatch(_ck, p) for p in parts):
return "[" + ", ".join(parts) + "]"
return m_sc.group(0)
final_paper = _re.sub(
r"\[([^\]]+;[^\]]+)\]", _fix_semicolon_cites, final_paper
)
(stage_dir / "paper_final.md").write_text(
final_paper, encoding="utf-8"
)
# R10-Fix4: Citation cross-validation
# BUG-187: Also parse multi-key brackets like [key1, key2, key3].
# The old regex only matched single-key brackets [key2020word].
_cite_key_pat = r"[a-zA-Z]+\d{4}[a-zA-Z0-9_-]*"
cited_keys_in_paper: set[str] = set()
# Single-key brackets
for m in _re.finditer(rf"\[({_cite_key_pat})\]", final_paper):
cited_keys_in_paper.add(m.group(1))
# Multi-key brackets [key1, key2] or [key1; key2]
for m in _re.finditer(r"\[([^\]]{10,300})\]", final_paper):
inner = m.group(1)
# Only parse if it looks like citation keys (has year-like digits)
parts = _re.split(r"[,;]\s*", inner)
if all(_re.fullmatch(_cite_key_pat, p.strip()) for p in parts if p.strip()):
for p in parts:
if p.strip():
cited_keys_in_paper.add(p.strip())
if valid_keys and cited_keys_in_paper:
invalid_keys = cited_keys_in_paper - valid_keys
if invalid_keys:
logger.warning(
"Stage 22: Found %d citation keys in paper not in references.bib: %s",
len(invalid_keys),
", ".join(sorted(invalid_keys)[:20]),
)
# BUG-176: Try to resolve missing citations before removing them.
# Parse cite_key → search query, look up via academic APIs,
# and add found entries to references.bib.
resolved_keys: set[str] = set()
new_bib_entries: list[str] = []
if len(invalid_keys) <= 30: # Sanity: don't flood APIs
resolved_keys, new_bib_entries = _resolve_missing_citations(
invalid_keys, bib_text
)
if resolved_keys:
valid_keys.update(resolved_keys)
bib_text += "\n" + "\n\n".join(new_bib_entries) + "\n"
logger.info(
"Stage 22: Resolved %d/%d missing citations via API lookup",
len(resolved_keys), len(invalid_keys),
)
still_invalid = invalid_keys - resolved_keys
if still_invalid:
# IMP-29: Remove remaining unresolvable citations from
# BOTH single-key and multi-key brackets.
import re as _re_imp29
for bad_key in still_invalid:
# Remove single-key brackets
final_paper = final_paper.replace(f"[{bad_key}]", "")
# Remove from multi-key brackets: [good, BAD, good] → [good, good]
def _remove_from_multi(m: _re.Match) -> str:
inner = m.group(1)
parts = [p.strip() for p in _re.split(r"[,;]\s*", inner)]
filtered = [p for p in parts if p != bad_key]
if not filtered:
return ""
return "[" + ", ".join(filtered) + "]"
final_paper = _re_imp29.sub(
r"\[([^\]]*\b" + _re.escape(bad_key) + r"\b[^\]]*)\]",
_remove_from_multi,
final_paper,
)
# Clean up whitespace artifacts from removed citations
final_paper = _re_imp29.sub(r" +", " ", final_paper)
final_paper = _re_imp29.sub(r" ([.,;:)])", r"\1", final_paper)
(stage_dir / "paper_final.md").write_text(final_paper, encoding="utf-8")
if still_invalid:
(stage_dir / "invalid_citations.json").write_text(
json.dumps(sorted(still_invalid), indent=2), encoding="utf-8"
)
artifacts.append("invalid_citations.json")
if resolved_keys:
(stage_dir / "resolved_citations.json").write_text(
json.dumps(sorted(resolved_keys), indent=2), encoding="utf-8"
)
artifacts.append("resolved_citations.json")
final_paper_latex = final_paper # default: no citation conversion
if valid_keys:
_CITE_KEY_PAT = r"[a-zA-Z][a-zA-Z0-9_-]*\d{4}[a-zA-Z0-9]*"
# Step 1: Convert multi-key brackets [key1, key2] → \cite{key1, key2}
def _replace_multi_cite(m: _re.Match[str]) -> str:
keys = [k.strip() for k in m.group(1).split(",")]
matched = [k for k in keys if k in valid_keys]
if matched:
return "\\cite{" + ", ".join(matched) + "}"
return m.group(0)
final_paper_latex = _re.sub(
rf"\[({_CITE_KEY_PAT}(?:\s*,\s*{_CITE_KEY_PAT})+)\]",
_replace_multi_cite,
final_paper,
)
# Step 2: Convert single-key brackets [key] → \cite{key}
def _replace_cite(m: _re.Match[str]) -> str:
key = m.group(1)
if key in valid_keys:
return f"\\cite{{{key}}}"
return m.group(0)
final_paper_latex = _re.sub(
rf"\[({_CITE_KEY_PAT})\]", _replace_cite, final_paper_latex
)
# Step 3: Merge adjacent \cite{a} \cite{b} → \cite{a, b}
def _merge_adjacent_cites(m: _re.Match[str]) -> str:
keys = _re.findall(r"\\cite\{([^}]+)\}", m.group(0))
return "\\cite{" + ", ".join(keys) + "}"
final_paper_latex = _re.sub(
r"\\cite\{[^}]+\}(?:\s*\\cite\{[^}]+\})+",
_merge_adjacent_cites,
final_paper_latex,
)
(stage_dir / "paper_final_latex.md").write_text(
final_paper_latex, encoding="utf-8"
)
artifacts.append("paper_final_latex.md")
# IMP-1: Prune uncited bibliography entries — keep only keys
# that actually appear in the paper text (bracket or \cite form).
if valid_keys:
_all_cited: set[str] = set()
# Bracket-format citations [key]
_all_cited.update(
_re.findall(r"\[([a-zA-Z]+\d{4}[a-zA-Z0-9_-]*)\]", final_paper)
)
# \cite{key, key2} format (original + latex-converted)
for _src in (
final_paper,
final_paper_latex,
):
for _cm in _re.finditer(r"\\cite\{([^}]+)\}", _src):
_all_cited.update(
k.strip() for k in _cm.group(1).split(",")
)
uncited_keys = valid_keys - _all_cited
if uncited_keys:
bib_text = _remove_bibtex_entries(bib_text, uncited_keys)
logger.info(
"Stage 22: Pruned %d uncited bibliography entries "
"(kept %d)",
len(uncited_keys),
len(valid_keys) - len(uncited_keys),
)
# Write final references.bib
(stage_dir / "references.bib").write_text(bib_text, encoding="utf-8")
artifacts.append("references.bib")
logger.info(
"Stage 22: Exported references.bib with %d entries",
len(valid_keys) if valid_keys else 0,
)
# Conference template: generate .tex file
try:
from researchclaw.templates import get_template, markdown_to_latex
tpl = get_template(config.export.target_conference)
# Use the latex-citation-processed version if available
tex_source = final_paper_latex
# Append NeurIPS-style checklist if target is a ML conference
if tpl.name in ("neurips_2024", "neurips_2025", "icml_2025", "icml_2026",
"iclr_2025", "iclr_2026"):
_has_exp = bool(_read_prior_artifact(run_dir, "experiment_summary.json"))
_checklist = _generate_neurips_checklist(
has_experiments=_has_exp,
has_code=True,
)
if "NeurIPS Paper Checklist" not in tex_source:
tex_source = tex_source.rstrip() + "\n\n" + _checklist
_t = _extract_paper_title(tex_source)
tex_content = markdown_to_latex(
tex_source,
tpl,
title=_t if _t != "Untitled Paper" else "",
authors=config.export.authors,
bib_file=config.export.bib_file,
bib_entries=_ay_map or None,
)
(stage_dir / "paper.tex").write_text(tex_content, encoding="utf-8")
artifacts.append("paper.tex")
logger.info(
"Stage 22: Generated paper.tex for %s (%d chars)",
tpl.display_name,
len(tex_content),
)
# --- Phase 1 anti-fabrication: verify paper against VerifiedRegistry ---
_vresult = None # BUG-DA8-04: Initialize before try to avoid fragile dir() check
try:
from researchclaw.pipeline.paper_verifier import verify_paper as _verify_paper
# BUG-222: Use best_only=True to validate against promoted best data only
from researchclaw.pipeline.verified_registry import (
VerifiedRegistry as _VR22,
)
_vr22 = _VR22.from_run_dir(
run_dir,
metric_direction=config.experiment.metric_direction,
best_only=True,
)
if _vr22.values:
_vresult = _verify_paper(tex_content, _vr22)
(stage_dir / "paper_verification.json").write_text(
json.dumps({
"passed": _vresult.passed,
"severity": _vresult.severity,
"total_checked": _vresult.total_numbers_checked,
"total_verified": _vresult.total_numbers_verified,
"strict_violations": _vresult.strict_violations,
"lenient_violations": _vresult.lenient_violations,
"fabrication_rate": round(_vresult.fabrication_rate, 4),
"unverified_numbers": [
{"value": u.value, "line": u.line_number,
"section": u.section, "in_table": u.in_table}
for u in _vresult.unverified_numbers[:20]
],
"fabricated_conditions": [
{"name": fc.name, "line": fc.line_number}
for fc in _vresult.fabricated_conditions
],
"config_warnings": getattr(_vresult, "config_warnings", []),
"summary": _vresult.summary,
}, indent=2),
encoding="utf-8",
)
logger.info(
"Stage 22: Paper verification — %s (%d checked, %d verified, "
"%d strict violations, fabrication_rate=%.1f%%)",
_vresult.severity,
_vresult.total_numbers_checked,
_vresult.total_numbers_verified,
_vresult.strict_violations,
_vresult.fabrication_rate * 100,
)
except Exception as _pv_exc:
logger.debug("Stage 22: Paper verification skipped: %s", _pv_exc)
# BUG-23 P1: Enforce REJECT verdict — sanitize unverified numbers
if _vresult is not None and getattr(_vresult, "severity", None) == "REJECT":
logger.warning(
"Stage 22: Paper REJECTED by verifier (fabrication_rate=%.1f%%, "
"%d strict violations). Sanitizing unverified numbers.",
_vresult.fabrication_rate * 100,
_vresult.strict_violations,
)
# Replace unverified numbers in strict sections/tables with "---"
import re as _re_san2
# BUG-R49-02: Section names that sound like results but are
# actually protocol/setup sections should NOT trigger strict
# sanitization. Exempt sections containing "dataset", "setup",
# "protocol", "hyperparameter", or "implementation".
_STRICT_EXEMPT_KW = {"dataset", "setup", "protocol",
"hyperparameter", "implementation",
"hardware", "infrastructure"}
_sanitized_tex = tex_content
_san2_count = 0
for _uv in sorted(_vresult.unverified_numbers, key=lambda u: -u.line_number):
# Only sanitize strict-section / in-table numbers
_uv_section_lower = (_uv.section or "").lower()
_uv_is_strict = any(
s in _uv_section_lower
for s in ("results", "experiment", "evaluation",
"ablation", "comparison", "analysis")
)
# BUG-R49-02: Exempt protocol/setup sections from strict mode
if _uv_is_strict and any(
kw in _uv_section_lower for kw in _STRICT_EXEMPT_KW
):
_uv_is_strict = False
if _uv_is_strict or _uv.in_table:
_lines = _sanitized_tex.split("\n")
if 0 < _uv.line_number <= len(_lines):
_orig_line = _lines[_uv.line_number - 1]
# BUG-R49-01: Use word-boundary regex instead of
# naive substring matching to avoid replacing numbers
# inside identifiers (e.g. "18" in "ResNet18").
# BUG-206: Include ASCII hyphen and Unicode hyphens
# (U+2010 hyphen, U+2011 non-breaking hyphen,
# U+2013 en-dash) so that model variant numbers
# like "34" in "ResNet-34" or "ResNet‑34" are not
# mistaken for unverified experimental values.
# BUG-210: Include period (.) so that fractional
# parts of decimals in condition names like
# "ema_decay_0.9" are not treated as standalone
# numbers (prevents "0.9" → "0.---").
_BOUNDARY = "A-Za-z0-9_\u2010\u2011\u2013\\-."
for _rep in (
f"{_uv.value:.4f}".rstrip("0").rstrip("."),
f"{_uv.value:.3f}",
f"{_uv.value:.2f}",
f"{_uv.value:.1f}",
f"{_uv.value:g}",
str(_uv.value),
):
# Word boundary: number must NOT be adjacent to
# alphanumeric, underscore, or hyphen on either side.
_pat = (
rf"(? _page_limit:
logger.warning(
"BUG-27: Paper is %d pages (limit %d). "
"Consider tightening content in revision.",
_qc.page_count, _page_limit,
)
except Exception as _qc_exc: # noqa: BLE001
logger.debug("Stage 22: Quality checks skipped: %s", _qc_exc)
else:
logger.warning("Stage 22: LaTeX compilation verification FAILED: %s", _compile_result.errors[:3])
# Add compilation failure comment to .tex
_tex_path = stage_dir / "paper.tex"
if _tex_path.exists():
_tex_content = _tex_path.read_text(encoding="utf-8")
if "% WARNING: Compilation failed" not in _tex_content:
_tex_content = (
"% WARNING: Compilation failed. Errors:\n"
+ "".join(f"% {e}\n" for e in _compile_result.errors[:5])
+ _tex_content
)
_tex_path.write_text(_tex_content, encoding="utf-8")
except Exception as _compile_exc: # noqa: BLE001
logger.debug("Stage 22: Compile verification skipped: %s", _compile_exc)
except Exception as exc: # noqa: BLE001
logger.error("LaTeX generation failed: %s", exc, exc_info=True)
# (Charts, BUG-99 path fix, and remove_missing_figures are now handled
# BEFORE compile_latex() — see "Pre-compilation" block above.)
# --- Code packaging: multi-file directory or single file ---
exp_final_dir_path = _read_prior_artifact(run_dir, "experiment_final/")
if exp_final_dir_path and Path(exp_final_dir_path).is_dir():
import ast
code_dir = stage_dir / "code"
code_dir.mkdir(parents=True, exist_ok=True)
all_code_combined = ""
code_file_names: list[str] = []
for src in sorted(Path(exp_final_dir_path).glob("*.py")):
(code_dir / src.name).write_bytes(src.read_bytes())
all_code_combined += src.read_text(encoding="utf-8") + "\n"
code_file_names.append(src.name)
# Detect dependencies from all files
detected: set[str] = set()
known_packages = {
"numpy": "numpy",
"torch": "torch",
"tensorflow": "tensorflow",
"sklearn": "scikit-learn",
"scikit-learn": "scikit-learn",
"scipy": "scipy",
"pandas": "pandas",
"matplotlib": "matplotlib",
"seaborn": "seaborn",
"transformers": "transformers",
"datasets": "datasets",
"jax": "jax",
}
try:
tree = ast.parse(all_code_combined)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
top = alias.name.split(".")[0]
if top in known_packages:
detected.add(known_packages[top])
elif isinstance(node, ast.ImportFrom) and node.module:
top = node.module.split(".")[0]
if top in known_packages:
detected.add(known_packages[top])
except SyntaxError:
pass
requirements = sorted(detected)
(code_dir / "requirements.txt").write_text(
"\n".join(requirements) + ("\n" if requirements else ""),
encoding="utf-8",
)
paper_title = _extract_paper_title(final_paper)
file_list_md = "\n".join(f"- `{f}`" for f in code_file_names)
readme = (
f"# Code Package for {paper_title}\n\n"
"## Description\n"
"This directory contains the experiment project used for the paper.\n\n"
"## Project Files\n"
f"{file_list_md}\n\n"
"## How to Run\n"
"`python main.py`\n\n"
"## Dependencies\n"
"Install dependencies with `pip install -r requirements.txt` if needed.\n"
)
(code_dir / "README.md").write_text(readme, encoding="utf-8")
artifacts.append("code/")
logger.info(
"Stage 22: Packaged multi-file code release (%d files, %d deps)",
len(code_file_names),
len(requirements),
)
else:
# Backward compat: single-file packaging
code_payload = _read_prior_artifact(run_dir, "experiment_final.py")
if not code_payload:
code_payload = _read_prior_artifact(run_dir, "experiment.py")
if code_payload:
import ast
code_dir = stage_dir / "code"
code_dir.mkdir(parents=True, exist_ok=True)
(code_dir / "experiment.py").write_text(code_payload, encoding="utf-8")
detected_single: set[str] = set()
known_packages_single = {
"numpy": "numpy",
"torch": "torch",
"tensorflow": "tensorflow",
"sklearn": "scikit-learn",
"scikit-learn": "scikit-learn",
"scipy": "scipy",
"pandas": "pandas",
"matplotlib": "matplotlib",
"seaborn": "seaborn",
"transformers": "transformers",
"datasets": "datasets",
"jax": "jax",
}
try:
tree = ast.parse(code_payload)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
top = alias.name.split(".")[0]
if top in known_packages_single:
detected_single.add(known_packages_single[top])
elif isinstance(node, ast.ImportFrom) and node.module:
top = node.module.split(".")[0]
if top in known_packages_single:
detected_single.add(known_packages_single[top])
except SyntaxError:
pass
requirements = sorted(detected_single)
(code_dir / "requirements.txt").write_text(
"\n".join(requirements) + ("\n" if requirements else ""),
encoding="utf-8",
)
paper_title = _extract_paper_title(final_paper)
readme = (
f"# Code Package for {paper_title}\n\n"
"## Description\n"
"This directory contains the final experiment script used for the paper.\n\n"
"## How to Run\n"
"`python experiment.py`\n\n"
"## Dependencies\n"
"Install dependencies with `pip install -r requirements.txt` if needed.\n"
)
(code_dir / "README.md").write_text(readme, encoding="utf-8")
artifacts.append("code/")
logger.info(
"Stage 22: Packaged single-file code release with %d deps",
len(requirements),
)
# WS-5.5: Generate framework diagram prompt for methodology section
try:
_framework_prompt = _generate_framework_diagram_prompt(
final_paper, config, llm=llm
)
if _framework_prompt:
_chart_dir = stage_dir / "charts"
_chart_dir.mkdir(parents=True, exist_ok=True)
(_chart_dir / "framework_diagram_prompt.md").write_text(
_framework_prompt, encoding="utf-8"
)
logger.info("Stage 22: Generated framework diagram prompt → charts/framework_diagram_prompt.md")
except Exception as exc: # noqa: BLE001
logger.debug("Stage 22: Framework diagram prompt generation skipped: %s", exc)
return StageResult(
stage=Stage.EXPORT_PUBLISH,
status=StageStatus.DONE,
artifacts=tuple(artifacts),
evidence_refs=tuple(f"stage-22/{a}" for a in artifacts),
)
# ---------------------------------------------------------------------------
# Citation helpers
# ---------------------------------------------------------------------------
def _check_citation_relevance(
llm: Any,
topic: str,
results: list[Any],
) -> dict[str, float | None]:
"""Use LLM to assess relevance of each citation to the research topic.
Returns a dict mapping cite_key → relevance score (0.0–1.0).
Processes citations in batches of 30 to handle large bibliographies.
"""
citation_lines = []
for cr in results:
citation_lines.append(f"- [{cr.cite_key}] \"{cr.title}\"")
if not citation_lines:
return {}
all_scores: dict[str, float] = {}
_BATCH_SIZE = 30
for batch_start in range(0, len(citation_lines), _BATCH_SIZE):
batch = citation_lines[batch_start:batch_start + _BATCH_SIZE]
citations_text = "\n".join(batch)
prompt = (
f"Research topic: {topic}\n\n"
f"Rate the relevance of each citation to the research topic "
f"on a scale of 0.0 to 1.0.\n"
f"Return ONLY a JSON object mapping cite_key to relevance score.\n"
f"Example: {{\"smith2020\": 0.9, \"jones2019\": 0.2}}\n\n"
f"Citations:\n{citations_text}"
)
try:
resp = llm.chat(
[{"role": "user", "content": prompt}],
system="You assess citation relevance. Return only valid JSON.",
json_mode=True,
)
parsed = _safe_json_loads(resp.content, {})
if isinstance(parsed, dict):
for k, v in parsed.items():
if isinstance(v, (int, float)):
all_scores[k] = max(0.0, min(1.0, float(v)))
except Exception: # noqa: BLE001
logger.debug(
"Citation relevance check failed for batch %d–%d, skipping",
batch_start, batch_start + len(batch),
)
return all_scores
def _remove_bibtex_entries(bib_text: str, keys_to_remove: set[str]) -> str:
"""Remove BibTeX entries whose keys are in *keys_to_remove*."""
kept: list[str] = []
for m in re.finditer(r"@\w+\{([^,]+),", bib_text):
key = m.group(1).strip()
if key in keys_to_remove:
continue
# Find the full entry (from @ to the next @ or end)
start = m.start()
# Find balanced braces
depth = 0
end = start
for i in range(start, len(bib_text)):
if bib_text[i] == "{":
depth += 1
elif bib_text[i] == "}":
depth -= 1
if depth == 0:
end = i + 1
break
if end > start:
kept.append(bib_text[start:end])
return "\n\n".join(kept) + "\n" if kept else ""
def _remove_citations_from_text(text: str, keys_to_remove: set[str]) -> str:
"""Remove \\cite{key} and [key] references for specified citation keys."""
# Handle multi-key LaTeX cites: \cite{a,b,c} → filter keys inside braces
def _filter_cite(m: re.Match[str]) -> str:
keys = [k.strip() for k in m.group(1).split(",")]
kept = [k for k in keys if k not in keys_to_remove]
if not kept:
return ""
return f"\\cite{{{','.join(kept)}}}"
text = re.sub(r"\\cite\{([^}]+)\}", _filter_cite, text)
# Markdown: [key]
for key in keys_to_remove:
text = re.sub(rf"\[{re.escape(key)}\]", "", text)
return text
# ---------------------------------------------------------------------------
# Stage 23: Citation Verify
# ---------------------------------------------------------------------------
def _execute_citation_verify(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 23: verify every BibTeX entry against external sources, drop
    low-relevance / hallucinated citations, and emit cleaned artifacts.

    Outputs (written into *stage_dir*):
      - ``verification_report.json`` — per-citation verification results.
      - ``references_verified.bib`` — filtered/pruned bibliography.
      - ``paper_final_verified.md`` — annotated paper (only when a prior
        ``paper_final.md`` exists).

    *llm* enables the optional topical-relevance scoring; *prompts* and
    *adapters* are unused here but kept for the shared stage signature.
    """
    from researchclaw.literature.verify import (
        VerifyStatus,
        annotate_paper_hallucinations,
        filter_verified_bibtex,
        verify_citations,
    )
    bib_text = _read_prior_artifact(run_dir, "references.bib") or ""
    paper_text = _read_prior_artifact(run_dir, "paper_final.md") or ""
    # Fast path: nothing to verify — emit an empty-but-valid report so
    # downstream artifact checks still pass.
    if not bib_text.strip():
        report_data = {
            "summary": {
                "total": 0,
                "verified": 0,
                "suspicious": 0,
                "hallucinated": 0,
                "skipped": 0,
                "integrity_score": 1.0,
            },
            "results": [],
            "note": "No references.bib found — nothing to verify.",
        }
        (stage_dir / "verification_report.json").write_text(
            json.dumps(report_data, indent=2), encoding="utf-8"
        )
        (stage_dir / "references_verified.bib").write_text(
            "% No references to verify\n", encoding="utf-8"
        )
        return StageResult(
            stage=Stage.CITATION_VERIFY,
            status=StageStatus.DONE,
            artifacts=("verification_report.json", "references_verified.bib"),
            evidence_refs=(
                "stage-23/verification_report.json",
                "stage-23/references_verified.bib",
            ),
        )
    s2_api_key = getattr(config.llm, "s2_api_key", "") or ""
    from researchclaw.literature.verify import parse_bibtex_entries
    _n_entries = len(parse_bibtex_entries(bib_text))
    logger.info(
        "[citation-verify] Verifying %d references "
        "(DOI→CrossRef > OpenAlex > arXiv > S2)…",
        _n_entries,
    )
    report = verify_citations(bib_text, s2_api_key=s2_api_key)
    logger.info(
        "[citation-verify] Done: %d verified, %d suspicious, "
        "%d hallucinated, %d skipped (integrity: %.0f%%)",
        report.verified,
        report.suspicious,
        report.hallucinated,
        report.skipped,
        report.integrity_score * 100,
    )
    # --- Relevance check: assess topical relevance of verified citations ---
    if llm is not None and report.results:
        relevance_scores = _check_citation_relevance(
            llm, config.research.topic, report.results
        )
        for cr in report.results:
            score = relevance_scores.get(cr.cite_key)
            if score is not None:
                cr.relevance_score = score
    # FIX-5: Filter low-relevance citations and enforce hard cap
    RELEVANCE_THRESHOLD = 0.5
    MAX_CITATIONS = 60
    low_relevance_keys: set[str] = set()
    for cr in report.results:
        if cr.relevance_score is not None and cr.relevance_score < RELEVANCE_THRESHOLD:
            low_relevance_keys.add(cr.cite_key)
    # Hard cap: if still above MAX_CITATIONS after relevance filter, drop lowest
    # BUG-07 fix: Unscored citations (relevance_score=None) default to 0.7
    # because they passed API verification and are likely relevant.
    # Previously they defaulted to 0.0 which caused mass-deletion.
    _DEFAULT_RELEVANCE = 0.7
    remaining = [
        cr for cr in report.results
        if cr.cite_key not in low_relevance_keys
        and cr.status != VerifyStatus.HALLUCINATED
    ]
    if len(remaining) > MAX_CITATIONS:
        # Ascending sort → the head of the list is the lowest-relevance overflow.
        remaining.sort(
            key=lambda c: c.relevance_score if c.relevance_score is not None else _DEFAULT_RELEVANCE,
        )
        overflow = remaining[:len(remaining) - MAX_CITATIONS]
        for cr in overflow:
            low_relevance_keys.add(cr.cite_key)
        logger.info(
            "Stage 23: Hard cap applied, dropping %d additional low-relevance citations",
            len(overflow),
        )
    if low_relevance_keys:
        logger.info(
            "Stage 23: Filtering %d low-relevance citations (threshold=%.1f, cap=%d): %s",
            len(low_relevance_keys),
            RELEVANCE_THRESHOLD,
            MAX_CITATIONS,
            ", ".join(sorted(list(low_relevance_keys)[:20])),
        )
    (stage_dir / "verification_report.json").write_text(
        json.dumps(report.to_dict(), indent=2), encoding="utf-8"
    )
    verified_bib = filter_verified_bibtex(bib_text, report, include_suspicious=True)
    # Remove low-relevance entries from BibTeX
    if low_relevance_keys:
        verified_bib = _remove_bibtex_entries(verified_bib, low_relevance_keys)
    # BUG-26: If verification stripped >50% of entries (e.g. due to rate limiting),
    # fall back to the original bib to avoid breaking the paper's references
    original_count = len(re.findall(r"@\w+\{", bib_text))
    verified_count = len(re.findall(r"@\w+\{", verified_bib))
    if original_count > 0 and verified_count < original_count * 0.5:
        logger.warning(
            "Stage 23: Verification stripped %d→%d entries (>50%% loss). "
            "Keeping original bib to avoid breaking references.",
            original_count, verified_count,
        )
        verified_bib = bib_text
    # IMP-1: Also prune uncited entries from verified bib
    # BUG-182: Also scan LaTeX paper.tex (not just Markdown) for \cite{} keys.
    # The Markdown version may use [key] notation while LaTeX uses \cite{key}.
    if paper_text.strip():
        _vbib_keys = set(re.findall(r"@\w+\{([^,]+),", verified_bib))
        _cited_in_paper: set[str] = set()
        _cited_in_paper.update(
            re.findall(r"\[([a-zA-Z]+\d{4}[a-zA-Z0-9_-]*)\]", paper_text)
        )
        for _cm in re.finditer(r"\\cite\{([^}]+)\}", paper_text):
            _cited_in_paper.update(
                k.strip() for k in _cm.group(1).split(",")
            )
        # BUG-182: Also read stage-22/paper.tex for \cite{} keys
        # NOTE(review): assumes the export stage dir is always literally
        # "stage-22" (no versioned suffix) — confirm against run layout.
        _latex_paper = stage_dir.parent / "stage-22" / "paper.tex"
        if _latex_paper.exists():
            try:
                _latex_text = _latex_paper.read_text(encoding="utf-8")
                for _cm in re.finditer(r"\\cite[pt]?\{([^}]+)\}", _latex_text):
                    _cited_in_paper.update(
                        k.strip() for k in _cm.group(1).split(",")
                    )
            except OSError:
                # Best-effort: missing/unreadable LaTeX just skips this scan.
                pass
        _uncited_vbib = _vbib_keys - _cited_in_paper
        if _uncited_vbib:
            verified_bib = _remove_bibtex_entries(verified_bib, _uncited_vbib)
            logger.info(
                "Stage 23: Pruned %d uncited entries from verified bib "
                "(kept %d)",
                len(_uncited_vbib),
                len(_vbib_keys) - len(_uncited_vbib),
            )
    # BUG-100: If all entries were filtered out (low-relevance + uncited pruning),
    # write a comment instead of an empty file to avoid "Missing or empty output" error.
    if not verified_bib.strip():
        verified_bib = "% All citations were filtered out during verification\n"
        logger.warning(
            "Stage 23: All BibTeX entries filtered out — writing placeholder"
        )
    (stage_dir / "references_verified.bib").write_text(verified_bib, encoding="utf-8")
    artifacts = ["verification_report.json", "references_verified.bib"]
    if paper_text.strip():
        annotated = annotate_paper_hallucinations(paper_text, report)
        # Remove \cite{} and [cite_key] references for low-relevance entries
        if low_relevance_keys:
            annotated = _remove_citations_from_text(annotated, low_relevance_keys)
        (stage_dir / "paper_final_verified.md").write_text(annotated, encoding="utf-8")
        artifacts.append("paper_final_verified.md")
    logger.info(
        "Stage 23 citation verify: %d total, %d verified, %d suspicious, "
        "%d hallucinated, %d skipped (integrity=%.1f%%)",
        report.total,
        report.verified,
        report.suspicious,
        report.hallucinated,
        report.skipped,
        report.integrity_score * 100,
    )
    return StageResult(
        stage=Stage.CITATION_VERIFY,
        status=StageStatus.DONE,
        artifacts=tuple(artifacts),
        evidence_refs=tuple(f"stage-23/{a}" for a in artifacts),
    )
================================================
FILE: researchclaw/pipeline/stage_impls/_synthesis.py
================================================
"""Stages 7-8: Synthesis and hypothesis generation."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._helpers import (
StageResult,
_default_hypotheses,
_get_evolution_overlay,
_multi_perspective_generate,
_parse_jsonl_rows,
_read_prior_artifact,
_synthesize_perspectives,
_utcnow_iso,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
def _execute_synthesis(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 7: synthesize knowledge cards into ``synthesis.md``.

    With an LLM, the cards are fed into the "synthesis" prompt; without
    one, a deterministic placeholder document is written instead.
    """
    # Collect up to 24 knowledge-card markdown files as prompt context.
    # NOTE(review): _read_prior_artifact("cards/") appears to yield a
    # directory path here rather than file content — confirm in helper.
    card_dir = _read_prior_artifact(run_dir, "cards/") or ""
    cards_context = ""
    if card_dir:
        card_files = sorted(Path(card_dir).glob("*.md"))[:24]
        cards_context = "\n\n".join(
            f.read_text(encoding="utf-8") for f in card_files
        )
    if llm is None:
        # Offline fallback: fixed template stamped with the generation time.
        synthesis_md = f"""# Synthesis
## Cluster Overview
- Cluster A: Representation methods
- Cluster B: Training strategies
- Cluster C: Evaluation robustness
## Gap 1
Limited consistency across benchmark protocols.
## Gap 2
Under-reported failure behavior under distribution shift.
## Prioritized Opportunities
1. Unified experimental protocol
2. Robustness-aware evaluation suite
## Generated
{_utcnow_iso()}
"""
    else:
        manager = prompts or PromptManager()
        overlay = _get_evolution_overlay(run_dir, "synthesis")
        stage_prompt = manager.for_stage(
            "synthesis",
            evolution_overlay=overlay,
            topic=config.research.topic,
            cards_context=cards_context,
        )
        reply = llm.chat(
            [{"role": "user", "content": stage_prompt.user}],
            system=stage_prompt.system,
            max_tokens=stage_prompt.max_tokens or 8192,
        )
        synthesis_md = reply.content
    (stage_dir / "synthesis.md").write_text(synthesis_md, encoding="utf-8")
    return StageResult(
        stage=Stage.SYNTHESIS,
        status=StageStatus.DONE,
        artifacts=("synthesis.md",),
        evidence_refs=("stage-07/synthesis.md",),
    )
def _execute_hypothesis_gen(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 8: generate hypotheses via a multi-perspective debate, then
    run a best-effort novelty check against already-seen papers.
    """
    synthesis = _read_prior_artifact(run_dir, "synthesis.md") or ""
    if llm is None:
        hypotheses_md = _default_hypotheses(config.research.topic)
    else:
        manager = prompts or PromptManager()
        from researchclaw.prompts import DEBATE_ROLES_HYPOTHESIS  # noqa: PLC0415
        # --- Multi-perspective debate ---
        debate_dir = stage_dir / "perspectives"
        debate_vars = {"topic": config.research.topic, "synthesis": synthesis}
        perspectives = _multi_perspective_generate(
            llm, DEBATE_ROLES_HYPOTHESIS, debate_vars, debate_dir
        )
        if perspectives:
            # --- Synthesize debate outputs into final hypotheses ---
            hypotheses_md = _synthesize_perspectives(
                llm, perspectives, "hypothesis_synthesize", manager
            )
        else:
            # BUG-S2: an empty debate context would make the synthesis call
            # pure hallucination — fall back to defaults instead.
            logger.warning("All debate perspectives failed; using default hypotheses")
            hypotheses_md = _default_hypotheses(config.research.topic)
    (stage_dir / "hypotheses.md").write_text(hypotheses_md, encoding="utf-8")
    # --- Novelty check (non-blocking) ---
    novelty_artifacts: tuple[str, ...] = ()
    try:
        from researchclaw.literature.novelty import check_novelty  # noqa: PLC0415
        candidates_text = _read_prior_artifact(run_dir, "candidates.jsonl") or ""
        papers_seen = _parse_jsonl_rows(candidates_text) if candidates_text else []
        novelty_report = check_novelty(
            topic=config.research.topic,
            hypotheses_text=hypotheses_md,
            papers_already_seen=papers_seen,
            s2_api_key=getattr(config.llm, "s2_api_key", ""),
        )
        (stage_dir / "novelty_report.json").write_text(
            json.dumps(novelty_report, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        novelty_artifacts = ("novelty_report.json",)
        logger.info(
            "Novelty check: score=%.3f assessment=%s recommendation=%s",
            novelty_report["novelty_score"],
            novelty_report["assessment"],
            novelty_report["recommendation"],
        )
    except Exception:  # noqa: BLE001
        # Novelty scoring requires network access; never fail the stage on it.
        logger.warning("Novelty check failed (non-blocking)", exc_info=True)
    return StageResult(
        stage=Stage.HYPOTHESIS_GEN,
        status=StageStatus.DONE,
        artifacts=("hypotheses.md",) + novelty_artifacts,
        evidence_refs=("stage-08/hypotheses.md",),
    )
================================================
FILE: researchclaw/pipeline/stage_impls/_topic.py
================================================
"""Stages 1-2: Topic initialization and problem decomposition."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.hardware import detect_hardware, ensure_torch_available
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline._domain import _detect_domain
from researchclaw.pipeline._helpers import (
StageResult,
_get_evolution_overlay,
_read_prior_artifact,
_safe_json_loads,
_utcnow_iso,
)
from researchclaw.pipeline.stages import Stage, StageStatus
from researchclaw.prompts import PromptManager
logger = logging.getLogger(__name__)
def _execute_topic_init(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 1: write the research goal (``goal.md``), detect hardware,
    and optionally pre-install PyTorch for sandbox experiments.
    """
    topic = config.research.topic
    domains = (
        ", ".join(config.research.domains) if config.research.domains else "general"
    )
    if llm is None:
        # Offline fallback: deterministic SMART-goal template.
        goal_md = f"""# Research Goal
## Topic
{topic}
## Scope
Investigate the topic with emphasis on reproducible methods and measurable outcomes.
## SMART Goal
- Specific: Build a focused research plan for {topic}
- Measurable: Produce literature shortlist, hypotheses, experiment plan, and final paper
- Achievable: Complete through staged pipeline with gate checks
- Relevant: Aligned with project {config.project.name}
- Time-bound: Constrained by pipeline execution budget
## Constraints
- Quality threshold: {config.research.quality_threshold}
- Daily paper target: {config.research.daily_paper_count}
## Success Criteria
- At least 2 falsifiable hypotheses
- Executable experiment code and results analysis
- Revised paper passing quality gate
## Generated
{_utcnow_iso()}
"""
    else:
        manager = prompts or PromptManager()
        overlay = _get_evolution_overlay(run_dir, "topic_init")
        stage_prompt = manager.for_stage(
            "topic_init",
            evolution_overlay=overlay,
            topic=topic,
            domains=domains,
            project_name=config.project.name,
            quality_threshold=config.research.quality_threshold,
        )
        reply = llm.chat(
            [{"role": "user", "content": stage_prompt.user}],
            system=stage_prompt.system,
        )
        goal_md = reply.content
    (stage_dir / "goal.md").write_text(goal_md, encoding="utf-8")
    # --- Hardware detection (GPU / MPS / CPU) ---
    profile = detect_hardware()
    (stage_dir / "hardware_profile.json").write_text(
        json.dumps(profile.to_dict(), indent=2), encoding="utf-8"
    )
    if profile.warning:
        logger.warning("Hardware advisory: %s", profile.warning)
    else:
        logger.info("Hardware detected: %s (%s, %s MB VRAM)", profile.gpu_name, profile.gpu_type, profile.vram_mb)
    # --- Optionally ensure PyTorch is available ---
    if profile.has_gpu:
        if config.experiment.mode == "sandbox":
            if ensure_torch_available(config.experiment.sandbox.python_path, profile.gpu_type):
                logger.info("PyTorch is available for sandbox experiments")
            else:
                logger.warning("PyTorch could not be installed; sandbox will use CPU-only packages")
        elif config.experiment.mode == "docker":
            logger.info("Docker sandbox: PyTorch pre-installed in container image")
    return StageResult(
        stage=Stage.TOPIC_INIT,
        status=StageStatus.DONE,
        artifacts=("goal.md", "hardware_profile.json"),
        evidence_refs=("stage-01/goal.md", "stage-01/hardware_profile.json"),
    )
def _execute_problem_decompose(
    stage_dir: Path,
    run_dir: Path,
    config: RCConfig,
    adapters: AdapterBundle,
    *,
    llm: LLMClient | None = None,
    prompts: PromptManager | None = None,
) -> StageResult:
    """Stage 2: decompose the research goal into ``problem_tree.md`` and,
    when an LLM is available, run a non-blocking topic-quality check
    (``topic_evaluation.json``).
    """
    goal_text = _read_prior_artifact(run_dir, "goal.md") or ""
    if llm is not None:
        _pm = prompts or PromptManager()
        _overlay = _get_evolution_overlay(run_dir, "problem_decompose")
        sp = _pm.for_stage(
            "problem_decompose",
            evolution_overlay=_overlay,
            topic=config.research.topic,
            goal_text=goal_text,
        )
        resp = llm.chat(
            [{"role": "user", "content": sp.user}],
            system=sp.system,
        )
        body = resp.content
    else:
        # Offline fallback: deterministic decomposition template.
        body = f"""# Problem Decomposition
## Source
Derived from `goal.md` for topic: {config.research.topic}
## Sub-questions
1. Which problem settings and benchmarks define current SOTA?
2. Which methodological gaps remain unresolved?
3. Which hypotheses are testable under realistic constraints?
4. Which datasets and metrics best discriminate method quality?
5. Which failure modes can invalidate expected gains?
## Priority Ranking
1. Problem framing and benchmark setup
2. Gap identification and hypothesis formulation
3. Experiment and metric design
4. Failure analysis and robustness checks
## Risks
- Ambiguous task definition
- Dataset leakage or metric mismatch
## Generated
{_utcnow_iso()}
"""
    (stage_dir / "problem_tree.md").write_text(body, encoding="utf-8")
    # IMP-35: Topic/title quality pre-evaluation
    # Quick LLM check: is the topic well-scoped for a conference paper?
    if llm is not None:
        try:
            _eval_resp = llm.chat(
                [
                    {
                        "role": "user",
                        "content": (
                            "Evaluate this research topic for a top ML conference paper. "
                            "Score 1-10 on: (a) novelty, (b) specificity, (c) feasibility. "
                            "If overall score < 5, suggest a refined topic.\n\n"
                            f"Topic: {config.research.topic}\n\n"
                            "Reply as JSON: {\"novelty\": N, \"specificity\": N, "
                            "\"feasibility\": N, \"overall\": N, \"suggestion\": \"...\"}"
                        ),
                    }
                ],
                system=(
                    f"You are a senior {_detect_domain(config.research.topic, config.research.domains)[1]} "
                    f"researcher evaluating research topic quality."
                ),
            )
            # _safe_json_loads returns the fallback ({}) on malformed output.
            _eval_data = _safe_json_loads(_eval_resp.content, {})
            if isinstance(_eval_data, dict):
                # Missing "overall" defaults to 10 — i.e. treated as passing.
                overall = _eval_data.get("overall", 10)
                if isinstance(overall, (int, float)) and overall < 5:
                    logger.warning(
                        "IMP-35: Topic quality score %s/10 — consider refining: %s",
                        overall,
                        _eval_data.get("suggestion", ""),
                    )
                else:
                    logger.info("IMP-35: Topic quality score %s/10", overall)
                (stage_dir / "topic_evaluation.json").write_text(
                    json.dumps(_eval_data, indent=2), encoding="utf-8"
                )
        except Exception:  # noqa: BLE001
            # The quality check is advisory only — never fail the stage on it.
            logger.debug("IMP-35: Topic evaluation skipped (non-blocking)")
    return StageResult(
        stage=Stage.PROBLEM_DECOMPOSE,
        status=StageStatus.DONE,
        artifacts=("problem_tree.md",),
        evidence_refs=("stage-02/problem_tree.md",),
    )
================================================
FILE: researchclaw/pipeline/stages.py
================================================
"""23-stage ResearchClaw pipeline state machine.
Defines the stage sequence, status transitions, gate logic, and rollback rules.
Migrated from arc/state_machine.py (19 stages) with the following changes:
- SEARCH_PLAN + SOURCE_CONNECT → SEARCH_STRATEGY
- RELEVANCE_SCREEN + QUALITY_SCREEN → LITERATURE_SCREEN
- CLUSTER_TOPICS + GAP_ANALYSIS → SYNTHESIS
- EXPERIMENT_DESIGN split → EXPERIMENT_DESIGN + CODE_GENERATION
- EXECUTE split → EXPERIMENT_RUN + ITERATIVE_REFINE
- WRITE_DRAFT split → PAPER_OUTLINE + PAPER_DRAFT
- Added PAPER_REVISION, QUALITY_GATE, EXPORT_PUBLISH
- RETROSPECTIVE_ARCHIVE split → KNOWLEDGE_ARCHIVE (+ QUALITY_GATE + EXPORT_PUBLISH)
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Iterable
class Stage(IntEnum):
    """23-stage research pipeline.

    Values are ordinal: execution advances in increasing numeric order
    (see ``STAGE_SEQUENCE`` / ``NEXT_STAGE``). Stages marked ``GATE``
    also appear in ``GATE_STAGES`` and may require human approval.
    """
    # Phase A: Research Scoping
    TOPIC_INIT = 1
    PROBLEM_DECOMPOSE = 2
    # Phase B: Literature Discovery
    SEARCH_STRATEGY = 3
    LITERATURE_COLLECT = 4
    LITERATURE_SCREEN = 5  # GATE
    KNOWLEDGE_EXTRACT = 6
    # Phase C: Knowledge Synthesis
    SYNTHESIS = 7
    HYPOTHESIS_GEN = 8
    # Phase D: Experiment Design
    EXPERIMENT_DESIGN = 9  # GATE
    CODE_GENERATION = 10  # NEW
    RESOURCE_PLANNING = 11
    # Phase E: Experiment Execution
    EXPERIMENT_RUN = 12
    ITERATIVE_REFINE = 13  # NEW
    # Phase F: Analysis & Decision
    RESULT_ANALYSIS = 14
    RESEARCH_DECISION = 15
    # Phase G: Paper Writing
    PAPER_OUTLINE = 16
    PAPER_DRAFT = 17
    PEER_REVIEW = 18
    PAPER_REVISION = 19  # NEW
    # Phase H: Finalization
    QUALITY_GATE = 20  # GATE
    KNOWLEDGE_ARCHIVE = 21
    EXPORT_PUBLISH = 22
    CITATION_VERIFY = 23
class StageStatus(str, Enum):
    """Lifecycle status of a single stage; legal moves are listed in
    ``TRANSITION_MAP``. ``DONE`` is terminal (no outgoing transitions).
    """
    PENDING = "pending"  # created, not yet started
    RUNNING = "running"  # actively executing
    BLOCKED_APPROVAL = "blocked_approval"  # gate stage awaiting human decision
    APPROVED = "approved"  # human approved
    REJECTED = "rejected"  # human rejected (triggers rollback)
    PAUSED = "paused"  # suspended (approval timeout or post-failure pause)
    RETRYING = "retrying"  # re-running after a failure
    FAILED = "failed"  # last run errored
    DONE = "done"  # terminal: completed successfully
class TransitionEvent(str, Enum):
    """External events fed into ``advance()`` to drive status transitions."""
    START = "start"
    SUCCEED = "succeed"
    APPROVE = "approve"
    REJECT = "reject"
    TIMEOUT = "timeout"
    FAIL = "fail"
    RETRY = "retry"
    RESUME = "resume"
    PAUSE = "pause"
# ---------------------------------------------------------------------------
# Stage navigation
# ---------------------------------------------------------------------------
# Ordered stage list plus precomputed successor/predecessor lookup tables.
STAGE_SEQUENCE: tuple[Stage, ...] = tuple(Stage)
NEXT_STAGE: dict[Stage, Stage | None] = {
    stage: STAGE_SEQUENCE[idx + 1] if idx + 1 < len(STAGE_SEQUENCE) else None
    for idx, stage in enumerate(STAGE_SEQUENCE)
}
PREVIOUS_STAGE: dict[Stage, Stage | None] = {
    stage: STAGE_SEQUENCE[idx - 1] if idx > 0 else None
    for idx, stage in enumerate(STAGE_SEQUENCE)
}
# ---------------------------------------------------------------------------
# Gate stages — require approval before proceeding
# ---------------------------------------------------------------------------
GATE_STAGES: frozenset[Stage] = frozenset(
    {
        Stage.LITERATURE_SCREEN,
        Stage.EXPERIMENT_DESIGN,
        Stage.QUALITY_GATE,
    }
)
# Gate rollback targets: when a gate rejects, where to roll back
GATE_ROLLBACK: dict[Stage, Stage] = {
    Stage.LITERATURE_SCREEN: Stage.LITERATURE_COLLECT,  # reject → re-collect
    Stage.EXPERIMENT_DESIGN: Stage.HYPOTHESIS_GEN,  # reject → re-hypothesize
    Stage.QUALITY_GATE: Stage.PAPER_OUTLINE,  # reject → rewrite paper
}
# ---------------------------------------------------------------------------
# Research decision rollback targets (PIVOT/REFINE from Stage 15)
# ---------------------------------------------------------------------------
DECISION_ROLLBACK: dict[str, Stage] = {
    "pivot": Stage.HYPOTHESIS_GEN,  # Discard hypotheses, re-generate
    "refine": Stage.ITERATIVE_REFINE,  # Keep hypotheses, re-run experiments
}
MAX_DECISION_PIVOTS: int = 2  # Prevent infinite loops
# ---------------------------------------------------------------------------
# Noncritical stages — can be skipped on failure without aborting pipeline
# ---------------------------------------------------------------------------
NONCRITICAL_STAGES: frozenset[Stage] = frozenset(
    {
        Stage.QUALITY_GATE,  # 20: low quality should warn, not block deliverables
        Stage.KNOWLEDGE_ARCHIVE,  # 21: archival doesn't affect paper output
        # T3.4: CITATION_VERIFY removed — hallucinated citations MUST block export
    }
)
# ---------------------------------------------------------------------------
# Phase groupings (for UI and reporting)
# ---------------------------------------------------------------------------
PHASE_MAP: dict[str, tuple[Stage, ...]] = {
    "A: Research Scoping": (Stage.TOPIC_INIT, Stage.PROBLEM_DECOMPOSE),
    "B: Literature Discovery": (
        Stage.SEARCH_STRATEGY,
        Stage.LITERATURE_COLLECT,
        Stage.LITERATURE_SCREEN,
        Stage.KNOWLEDGE_EXTRACT,
    ),
    "C: Knowledge Synthesis": (Stage.SYNTHESIS, Stage.HYPOTHESIS_GEN),
    "D: Experiment Design": (
        Stage.EXPERIMENT_DESIGN,
        Stage.CODE_GENERATION,
        Stage.RESOURCE_PLANNING,
    ),
    "E: Experiment Execution": (Stage.EXPERIMENT_RUN, Stage.ITERATIVE_REFINE),
    "F: Analysis & Decision": (Stage.RESULT_ANALYSIS, Stage.RESEARCH_DECISION),
    "G: Paper Writing": (
        Stage.PAPER_OUTLINE,
        Stage.PAPER_DRAFT,
        Stage.PEER_REVIEW,
        Stage.PAPER_REVISION,
    ),
    "H: Finalization": (
        Stage.QUALITY_GATE,
        Stage.KNOWLEDGE_ARCHIVE,
        Stage.EXPORT_PUBLISH,
        Stage.CITATION_VERIFY,
    ),
}
# ---------------------------------------------------------------------------
# Transition logic
# ---------------------------------------------------------------------------
# Allowed status moves: current status → set of statuses reachable from it.
# DONE maps to the empty set, i.e. it is terminal.
TRANSITION_MAP: dict[StageStatus, frozenset[StageStatus]] = {
    StageStatus.PENDING: frozenset({StageStatus.RUNNING}),
    StageStatus.RUNNING: frozenset(
        {StageStatus.DONE, StageStatus.BLOCKED_APPROVAL, StageStatus.FAILED}
    ),
    StageStatus.BLOCKED_APPROVAL: frozenset(
        {StageStatus.APPROVED, StageStatus.REJECTED, StageStatus.PAUSED}
    ),
    StageStatus.APPROVED: frozenset({StageStatus.DONE}),
    StageStatus.REJECTED: frozenset({StageStatus.PENDING}),
    StageStatus.PAUSED: frozenset({StageStatus.RUNNING}),
    StageStatus.RETRYING: frozenset({StageStatus.RUNNING}),
    StageStatus.FAILED: frozenset({StageStatus.RETRYING, StageStatus.PAUSED}),
    StageStatus.DONE: frozenset(),
}
@dataclass(frozen=True)
class TransitionOutcome:
    """Immutable result of ``advance()``: the new state plus side-channel hints."""

    stage: Stage  # stage the machine is now at (a rejection may move it back)
    status: StageStatus  # new status for that stage
    next_stage: Stage | None  # stage to execute next; None at end of pipeline
    rollback_stage: Stage | None = None  # set when a gate rejection rolled back
    checkpoint_required: bool = False  # caller should persist a checkpoint now
    decision: str = "proceed"  # one of "proceed" / "block" / "retry" / "pivot"
def gate_required(
    stage: Stage,
    hitl_required_stages: Iterable[int] | None = None,
) -> bool:
    """Return True when *stage* needs human-in-the-loop approval.

    Non-gate stages never require approval. For gate stages, an explicit
    *hitl_required_stages* collection of int stage numbers overrides the
    default policy that every gate stage requires approval.
    """
    if stage not in GATE_STAGES:
        return False
    if hitl_required_stages is None:
        # Default: all gate stages require approval.
        return True
    return int(stage) in frozenset(hitl_required_stages)
def default_rollback_stage(stage: Stage) -> Stage:
    """Rollback target for *stage*: the configured gate target if any,
    otherwise the previous stage, otherwise the stage itself (the first
    stage has no predecessor).
    """
    configured = GATE_ROLLBACK.get(stage)
    if configured is not None:
        return configured
    previous = PREVIOUS_STAGE.get(stage)
    return previous if previous is not None else stage
def advance(
    stage: Stage,
    status: StageStatus,
    event: TransitionEvent | str,
    *,
    hitl_required_stages: Iterable[int] | None = None,
    rollback_stage: Stage | None = None,
) -> TransitionOutcome:
    """Compute the next state given current stage, status, and event.

    *event* may be a ``TransitionEvent`` or its string value.
    *hitl_required_stages* overrides which gate stages need approval;
    *rollback_stage* overrides the default rollback target.

    Raises ValueError on unsupported transitions.
    """
    ev = TransitionEvent(event)
    target_rollback = rollback_stage or default_rollback_stage(stage)

    # START: PENDING / RETRYING / PAUSED all (re)enter RUNNING.
    if ev is TransitionEvent.START and status in (
        StageStatus.PENDING,
        StageStatus.RETRYING,
        StageStatus.PAUSED,
    ):
        return TransitionOutcome(
            stage=stage, status=StageStatus.RUNNING, next_stage=stage
        )

    transition = (ev, status)

    if transition == (TransitionEvent.SUCCEED, StageStatus.RUNNING):
        if gate_required(stage, hitl_required_stages):
            # Gate stage finished its work: hold for human approval.
            return TransitionOutcome(
                stage=stage,
                status=StageStatus.BLOCKED_APPROVAL,
                next_stage=stage,
                checkpoint_required=False,
                decision="block",
            )
        return TransitionOutcome(
            stage=stage,
            status=StageStatus.DONE,
            next_stage=NEXT_STAGE[stage],
            checkpoint_required=True,
        )

    if transition == (TransitionEvent.APPROVE, StageStatus.BLOCKED_APPROVAL):
        return TransitionOutcome(
            stage=stage,
            status=StageStatus.DONE,
            next_stage=NEXT_STAGE[stage],
            checkpoint_required=True,
        )

    if transition == (TransitionEvent.REJECT, StageStatus.BLOCKED_APPROVAL):
        # Human rejection rolls the machine back to the rollback target.
        return TransitionOutcome(
            stage=target_rollback,
            status=StageStatus.PENDING,
            next_stage=target_rollback,
            rollback_stage=target_rollback,
            checkpoint_required=True,
            decision="pivot",
        )

    if transition == (TransitionEvent.TIMEOUT, StageStatus.BLOCKED_APPROVAL):
        # Approval window expired: park the pipeline.
        return TransitionOutcome(
            stage=stage,
            status=StageStatus.PAUSED,
            next_stage=stage,
            checkpoint_required=True,
            decision="block",
        )

    if transition == (TransitionEvent.FAIL, StageStatus.RUNNING):
        return TransitionOutcome(
            stage=stage,
            status=StageStatus.FAILED,
            next_stage=stage,
            checkpoint_required=True,
            decision="retry",
        )

    if transition == (TransitionEvent.RETRY, StageStatus.FAILED):
        return TransitionOutcome(
            stage=stage,
            status=StageStatus.RETRYING,
            next_stage=stage,
            decision="retry",
        )

    if transition == (TransitionEvent.RESUME, StageStatus.PAUSED):
        return TransitionOutcome(
            stage=stage, status=StageStatus.RUNNING, next_stage=stage
        )

    if transition == (TransitionEvent.PAUSE, StageStatus.FAILED):
        return TransitionOutcome(
            stage=stage,
            status=StageStatus.PAUSED,
            next_stage=stage,
            checkpoint_required=True,
            decision="block",
        )

    raise ValueError(
        f"Unsupported transition: {status.value} + {ev.value} for stage {int(stage)}"
    )
================================================
FILE: researchclaw/pipeline/verified_registry.py
================================================
"""Verified Value Registry — ground truth for all experiment-sourced numbers.
Builds a whitelist of numeric values, condition names, and training config
from ``experiment_summary.json`` and ``refinement_log.json``. Used by
``paper_verifier.py`` and ``results_table_builder.py`` to ensure that
generated papers contain ONLY numbers grounded in real experiment data.
"""
from __future__ import annotations
import logging
import math
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# Infrastructure metric keys — allowed in paper without verification
# (timings, counters, and run metadata rather than scientific results;
# from_experiment() routes these into training_config instead of values).
_INFRA_KEYS: set[str] = {
    "elapsed_sec",
    "total_elapsed_seconds",
    "TIME_ESTIMATE",
    "SEED_COUNT",
    "time_budget_sec",
    "condition_count",
    "total_runs",
    "total_conditions",
    "total_metric_keys",
    "stopped_early",
}
# Metric key patterns for per-seed results (e.g. "DQN/0/metric")
# group(1)=condition name, group(2)=seed index, group(3)=metric key.
_PER_SEED_PATTERN = re.compile(r"^(.+)/(\d+)/(.+)$")
@dataclass
class ConditionResult:
    """Aggregated results for one experimental condition."""

    name: str
    per_seed_values: dict[int, float] = field(default_factory=dict)
    mean: float | None = None
    std: float | None = None
    n_seeds: int = 0
    aggregate_metric: float | None = None  # The condition-level metric

    def compute_stats(self) -> None:
        """Fill ``n_seeds``, ``mean`` and ``std`` from the finite per-seed
        values. With no finite values, mean/std stay ``None``; with a single
        value, std is 0.0; otherwise std is the sample (n-1) deviation.
        """
        finite: list[float] = []
        for v in self.per_seed_values.values():
            if _is_finite(v):
                finite.append(v)
        n = len(finite)
        self.n_seeds = n
        if n == 0:
            return
        avg = sum(finite) / n
        self.mean = avg
        if n < 2:
            self.std = 0.0
        else:
            squared_dev = 0.0
            for v in finite:
                squared_dev += (v - avg) ** 2
            self.std = math.sqrt(squared_dev / (n - 1))
@dataclass
class VerifiedRegistry:
    """Registry of all numbers grounded in experiment data."""

    # Verified numeric value -> provenance string (e.g. "best_run.metrics.acc").
    values: dict[float, str] = field(default_factory=dict)
    # Names of experimental conditions that were actually run.
    condition_names: set[str] = field(default_factory=set)
    # Condition name -> aggregated per-seed results.
    conditions: dict[str, ConditionResult] = field(default_factory=dict)
    primary_metric: float | None = None
    primary_metric_std: float | None = None
    metric_direction: str = "maximize"  # "maximize" or "minimize"
    # Infrastructure/config numbers (see _INFRA_KEYS) kept out of `values`.
    training_config: dict[str, Any] = field(default_factory=dict)
def add_value(self, value: float, source: str) -> None:
"""Register a verified numeric value with its provenance."""
if not _is_finite(value):
return
self.values[value] = source
# Also register common transformations
self._add_variants(value, source)
def _add_variants(self, value: float, source: str) -> None:
"""Register rounding variants and percentage conversions."""
# Rounded variants (2, 3, 4 decimal places)
for dp in (1, 2, 3, 4):
rounded = round(value, dp)
if rounded != value and rounded not in self.values:
self.values[rounded] = f"{source} (rounded to {dp}dp)"
# Percentage conversion: if value is in [0, 1], also register value*100
if 0.0 < abs(value) <= 1.0:
pct = value * 100.0
if pct not in self.values:
self.values[pct] = f"{source} (×100)"
for dp in (1, 2, 3, 4):
pct_r = round(pct, dp)
if pct_r not in self.values:
self.values[pct_r] = f"{source} (×100, {dp}dp)"
# If value > 1 and could be a percentage, also register value/100
if abs(value) > 1.0:
frac = value / 100.0
if frac not in self.values:
self.values[frac] = f"{source} (÷100)"
def is_verified(self, number: float, tolerance: float = 0.01) -> bool:
"""Check if *number* matches any verified value within relative tolerance."""
if not _is_finite(number):
return False
for v in self.values:
if v == 0.0:
if abs(number) < 1e-6:
return True
elif abs(number - v) / max(abs(v), 1e-9) <= tolerance:
return True
return False
def lookup(self, number: float, tolerance: float = 0.01) -> str | None:
"""Return the source description if *number* is verified, else None."""
if not _is_finite(number):
return None
for v, src in self.values.items():
if v == 0.0:
if abs(number) < 1e-6:
return src
elif abs(number - v) / max(abs(v), 1e-9) <= tolerance:
return src
return None
def verify_condition(self, name: str) -> bool:
"""Check if condition name was actually run."""
return name in self.condition_names
    @classmethod
    def from_experiment(
        cls,
        experiment_summary: dict,
        refinement_log: dict | None = None,
        *,
        metric_direction: str = "maximize",
    ) -> VerifiedRegistry:
        """Build registry from experiment artifacts.

        Parameters
        ----------
        experiment_summary:
            Parsed ``experiment_summary.json``.
        refinement_log:
            Parsed ``refinement_log.json`` (optional, provides richer per-seed data).
        metric_direction:
            ``"maximize"`` or ``"minimize"`` — used for best-result detection.
        """
        reg = cls(metric_direction=metric_direction)
        # --- 1. Extract condition-level and per-seed metrics ---
        best_run = experiment_summary.get("best_run", {})
        metrics = best_run.get("metrics", {})
        # Parse per-seed structure: "CondName/seed/metric_key" → value
        for key, value in metrics.items():
            if not isinstance(value, (int, float)) or not _is_finite(value):
                continue
            if key in _INFRA_KEYS:
                # Infrastructure numbers are config, not scientific results.
                reg.training_config[key] = value
                continue
            reg.add_value(value, f"best_run.metrics.{key}")
            m = _PER_SEED_PATTERN.match(key)
            if m:
                cond_name, seed_str, _metric_name = m.group(1), m.group(2), m.group(3)
                seed_idx = int(seed_str)
                if cond_name not in reg.conditions:
                    reg.conditions[cond_name] = ConditionResult(name=cond_name)
                reg.conditions[cond_name].per_seed_values[seed_idx] = value
                reg.condition_names.add(cond_name)
        # --- 2. Extract condition_summaries ---
        for cond_name, cond_data in experiment_summary.get("condition_summaries", {}).items():
            reg.condition_names.add(cond_name)
            if cond_name not in reg.conditions:
                reg.conditions[cond_name] = ConditionResult(name=cond_name)
            cond_metrics = cond_data.get("metrics", {})
            for mk, mv in cond_metrics.items():
                if isinstance(mv, (int, float)) and _is_finite(mv):
                    reg.add_value(mv, f"condition_summaries.{cond_name}.{mk}")
                    # NOTE(review): this overwrites aggregate_metric on every
                    # numeric key, so only the LAST metric in iteration order
                    # survives — confirm that is intended.
                    reg.conditions[cond_name].aggregate_metric = mv
        # --- 3. Extract metrics_summary (min/max/mean per key) ---
        for key, stats in experiment_summary.get("metrics_summary", {}).items():
            if key in _INFRA_KEYS:
                continue
            for stat_name in ("min", "max", "mean"):
                v = stats.get(stat_name)
                if isinstance(v, (int, float)) and _is_finite(v):
                    reg.add_value(v, f"metrics_summary.{key}.{stat_name}")
        # --- 4. Extract primary_metric ---
        pm = _extract_primary_metric(metrics)
        if pm is not None:
            reg.primary_metric = pm
            reg.add_value(pm, "primary_metric")
        pm_std = metrics.get("primary_metric_std")
        if isinstance(pm_std, (int, float)) and _is_finite(pm_std):
            reg.primary_metric_std = pm_std
            reg.add_value(pm_std, "primary_metric_std")
        # --- 5. Compute per-condition stats ---
        for cond in reg.conditions.values():
            cond.compute_stats()
            if cond.mean is not None:
                reg.add_value(cond.mean, f"{cond.name}.mean")
            if cond.std is not None and cond.std > 0:
                reg.add_value(cond.std, f"{cond.name}.std")
        # --- 6. Compute pairwise differences (for comparative claims) ---
        # Registers signed/absolute deltas and relative % improvements so the
        # paper verifier accepts derived comparison numbers.
        cond_list = [c for c in reg.conditions.values() if c.mean is not None]
        for i, c1 in enumerate(cond_list):
            for c2 in cond_list[i + 1 :]:
                diff = c1.mean - c2.mean  # type: ignore[operator]
                if _is_finite(diff):
                    reg.add_value(diff, f"diff({c1.name}-{c2.name})")
                    reg.add_value(abs(diff), f"|diff({c1.name},{c2.name})|")
                # Relative improvement
                if c2.mean and abs(c2.mean) > 1e-9:  # type: ignore[operator]
                    rel = (c1.mean - c2.mean) / abs(c2.mean) * 100.0  # type: ignore[operator]
                    if _is_finite(rel):
                        reg.add_value(rel, f"rel_improve({c1.name} vs {c2.name})")
                        reg.add_value(abs(rel), f"|rel_improve({c1.name},{c2.name})|")
        # --- 7. Enrich from refinement_log (best iteration only) ---
        if refinement_log:
            _enrich_from_refinement_log(reg, refinement_log)
        logger.info(
            "VerifiedRegistry: %d values, %d conditions (%s), primary_metric=%s",
            len(reg.values),
            len(reg.condition_names),
            ", ".join(sorted(reg.condition_names)),
            reg.primary_metric,
        )
        return reg
@classmethod
def from_run_dir(
    cls,
    run_dir: Path,
    *,
    metric_direction: str = "maximize",
    best_only: bool = False,
) -> VerifiedRegistry:
    """Build registry from experiment data sources in *run_dir*.

    Parameters
    ----------
    run_dir:
        Root directory of a pipeline run, containing ``stage-13*`` /
        ``stage-14*`` sub-directories and optionally
        ``experiment_summary_best.json``.
    metric_direction:
        ``"maximize"`` or ``"minimize"``; forwarded to the new registry.
    best_only:
        BUG-222: When True, use ONLY ``experiment_summary_best.json``
        (the promoted best iteration) as the ground truth. This prevents
        regressed REFINE iterations from polluting the verified value set.
        When False (default), merges all ``stage-14*`` data for backward
        compatibility (e.g., pre-built table generation that needs all
        condition names).

    Scans (when ``best_only=False``):
    1. All ``stage-14*/experiment_summary.json`` (sorted, every version)
    2. ``experiment_summary_best.json`` at run root (repair cycle output)
    3. All ``stage-13*/refinement_log.json`` for enrichment
    """
    import json as _json_rd

    target = cls(metric_direction=metric_direction)
    # NOTE: the former ``except (OSError, JSONDecodeError, Exception)``
    # tuples were redundant — ``Exception`` already subsumes both. The
    # intent (never let a bad data file abort the build) is unchanged.
    if best_only:
        # BUG-222: Only use promoted best data
        best_path = run_dir / "experiment_summary_best.json"
        if best_path.is_file():
            try:
                best_data = _json_rd.loads(best_path.read_text(encoding="utf-8"))
                if isinstance(best_data, dict):
                    sub = cls.from_experiment(best_data, metric_direction=metric_direction)
                    _merge_into(target, sub)
                    logger.debug("from_run_dir(best_only): using experiment_summary_best.json (%d values)", len(sub.values))
            except Exception:  # noqa: BLE001 — best-effort: skip corrupt file
                logger.debug("from_run_dir(best_only): failed to load experiment_summary_best.json", exc_info=True)
        if not target.values:
            # Fallback: no best.json or it was empty — use stage-14/ (non-versioned)
            s14_path = run_dir / "stage-14" / "experiment_summary.json"
            if s14_path.is_file():
                try:
                    es_data = _json_rd.loads(s14_path.read_text(encoding="utf-8"))
                    if isinstance(es_data, dict):
                        sub = cls.from_experiment(es_data, metric_direction=metric_direction)
                        _merge_into(target, sub)
                except Exception:  # noqa: BLE001 — deliberate best-effort fallback
                    logger.debug("from_run_dir(best_only): fallback stage-14 load failed", exc_info=True)
    else:
        # --- 1. All stage-14* experiment summaries ---
        for es_path in sorted(run_dir.glob("stage-14*/experiment_summary.json")):
            try:
                es_data = _json_rd.loads(es_path.read_text(encoding="utf-8"))
                if not isinstance(es_data, dict):
                    continue
                sub = cls.from_experiment(es_data, metric_direction=metric_direction)
                _merge_into(target, sub)
                logger.debug("from_run_dir: merged %s (%d values)", es_path.name, len(sub.values))
            except Exception:  # noqa: BLE001 — skip unreadable/corrupt summaries
                logger.debug("from_run_dir: skipping %s", es_path, exc_info=True)
        # --- 2. experiment_summary_best.json (repair cycle output) ---
        best_path = run_dir / "experiment_summary_best.json"
        if best_path.is_file():
            try:
                best_data = _json_rd.loads(best_path.read_text(encoding="utf-8"))
                if isinstance(best_data, dict):
                    sub = cls.from_experiment(best_data, metric_direction=metric_direction)
                    _merge_into(target, sub)
                    logger.debug("from_run_dir: merged experiment_summary_best.json (%d values)", len(sub.values))
            except Exception:  # noqa: BLE001
                logger.debug("from_run_dir: skipping experiment_summary_best.json", exc_info=True)
        # --- 3. All refinement logs (enrichment) ---
        for rl_path in sorted(run_dir.glob("stage-13*/refinement_log.json")):
            try:
                rl_data = _json_rd.loads(rl_path.read_text(encoding="utf-8"))
                if isinstance(rl_data, dict):
                    _enrich_from_refinement_log(target, rl_data)
                    logger.debug("from_run_dir: enriched from %s", rl_path.name)
            except Exception:  # noqa: BLE001
                logger.debug("from_run_dir: skipping %s", rl_path, exc_info=True)
    # Recompute per-condition stats after merging
    for cond in target.conditions.values():
        cond.compute_stats()
        if cond.mean is not None:
            target.add_value(cond.mean, f"{cond.name}.mean")
        if cond.std is not None and cond.std > 0:
            target.add_value(cond.std, f"{cond.name}.std")
    logger.info(
        "VerifiedRegistry.from_run_dir(%s): %d values, %d conditions (%s)",
        "best_only" if best_only else "all",
        len(target.values),
        len(target.condition_names),
        ", ".join(sorted(target.condition_names)) if target.condition_names else "none",
    )
    return target
@classmethod
def from_files(
    cls,
    experiment_summary_path: Path,
    refinement_log_path: Path | None = None,
    *,
    metric_direction: str = "maximize",
) -> VerifiedRegistry:
    """Convenience: build registry from file paths."""
    import json

    summary_data = json.loads(experiment_summary_path.read_text(encoding="utf-8"))
    refinement_data = None
    if refinement_log_path and refinement_log_path.exists():
        refinement_data = json.loads(refinement_log_path.read_text(encoding="utf-8"))
    return cls.from_experiment(summary_data, refinement_data, metric_direction=metric_direction)
def _merge_into(target: VerifiedRegistry, source: VerifiedRegistry) -> None:
    """Merge *source* values, conditions, and condition_names into *target*."""
    # Verified values: first writer wins — existing descriptions are kept.
    for value, description in source.values.items():
        target.values.setdefault(value, description)
    target.condition_names.update(source.condition_names)
    for cname, cresult in source.conditions.items():
        merged = target.conditions.setdefault(cname, ConditionResult(name=cname))
        # Merge per-seed values (source wins on conflict — later data is better)
        merged.per_seed_values.update(cresult.per_seed_values)
        if cresult.aggregate_metric is not None:
            merged.aggregate_metric = cresult.aggregate_metric
    # Keep the best primary metric for the configured optimization direction.
    if source.primary_metric is not None:
        if target.primary_metric is None:
            target.primary_metric = source.primary_metric
        else:
            choose = max if target.metric_direction == "maximize" else min
            target.primary_metric = choose(target.primary_metric, source.primary_metric)
    if source.primary_metric_std is not None:
        # Only update std if the source's primary_metric actually won
        if target.primary_metric == source.primary_metric:
            target.primary_metric_std = source.primary_metric_std
    target.training_config.update(source.training_config)
def _enrich_from_refinement_log(reg: VerifiedRegistry, refinement_log: dict) -> None:
    """Add values from the best refinement iteration.

    Registers ``best_metric`` plus every iteration's metric, and folds the
    best iteration's per-seed sandbox metrics into the per-condition results.
    """
    best_metric = refinement_log.get("best_metric")
    if isinstance(best_metric, (int, float)) and _is_finite(best_metric):
        reg.add_value(best_metric, "refinement_log.best_metric")
    best_version = refinement_log.get("best_version", "")
    iterations = refinement_log.get("iterations", [])
    for it in iterations:
        ver = it.get("version_dir", "")
        # BUG fix: with an empty best_version, ``best_version in ver`` was
        # True for EVERY iteration ("" is a substring of any string), so
        # per-seed values from non-best iterations leaked in (last wins).
        # Require a non-empty best_version for the best-iteration match.
        is_best_iteration = bool(best_version) and (ver == best_version or best_version in ver)
        metric = it.get("metric")
        if isinstance(metric, (int, float)) and _is_finite(metric):
            reg.add_value(metric, f"refinement_log.iteration.{ver}")
        # Extract per-seed values from sandbox stdout if available
        for sandbox_key in ("sandbox", "sandbox_after_fix"):
            sandbox = it.get(sandbox_key, {})
            if not isinstance(sandbox, dict):
                continue
            sb_metrics = sandbox.get("metrics", {})
            if not isinstance(sb_metrics, dict):
                continue
            for mk, mv in sb_metrics.items():
                if not (isinstance(mv, (int, float)) and _is_finite(mv)) or mk in _INFRA_KEYS:
                    continue
                reg.add_value(mv, f"refinement.{ver}.{sandbox_key}.{mk}")
                # Parse per-seed keys here too
                m = _PER_SEED_PATTERN.match(mk)
                if m:
                    cond_name = m.group(1)
                    seed_idx = int(m.group(2))
                    reg.condition_names.add(cond_name)
                    if cond_name not in reg.conditions:
                        reg.conditions[cond_name] = ConditionResult(name=cond_name)
                    # Only record per-seed data from the best version
                    if is_best_iteration:
                        reg.conditions[cond_name].per_seed_values[seed_idx] = mv
def _extract_primary_metric(metrics: dict) -> float | None:
    """Extract primary_metric from metrics dict (None if absent/non-finite)."""
    candidate = metrics.get("primary_metric")
    if not isinstance(candidate, (int, float)) or not _is_finite(candidate):
        return None
    return float(candidate)
def _is_finite(value: Any) -> bool:
"""Check if value is a finite number (not NaN, not Inf, not bool)."""
if isinstance(value, bool):
return False
if not isinstance(value, (int, float)):
return False
return math.isfinite(value)
================================================
FILE: researchclaw/project/__init__.py
================================================
"""Multi-project management for AutoResearchClaw."""
from researchclaw.project.models import Idea, Project
from researchclaw.project.manager import ProjectManager
from researchclaw.project.scheduler import ProjectScheduler
from researchclaw.project.idea_pool import IdeaPool
__all__ = ["Idea", "Project", "ProjectManager", "ProjectScheduler", "IdeaPool"]
================================================
FILE: researchclaw/project/idea_pool.py
================================================
"""Idea pool: collect, evaluate, rank, and convert research ideas to projects."""
from __future__ import annotations
import json
import logging
import uuid
from pathlib import Path
from typing import Any
from researchclaw.project.models import Idea, Project
logger = logging.getLogger(__name__)
class IdeaPool:
    """Manage a pool of research ideas with evaluation and ranking.

    The pool is persisted as a single JSON file at *pool_path*; every
    mutating operation immediately rewrites that file.
    """

    def __init__(self, pool_path: str | Path) -> None:
        self.pool_path = Path(pool_path).expanduser().resolve()
        # In-memory index: idea id -> Idea.
        self.ideas: dict[str, Idea] = {}
        self._load()

    # ── persistence ───────────────────────────────────────────────
    def _load(self) -> None:
        """Load ideas from disk; unreadable/corrupt pools are logged and skipped."""
        if not self.pool_path.exists():
            return
        try:
            data = json.loads(self.pool_path.read_text(encoding="utf-8"))
            for entry in data.get("ideas", []):
                idea = Idea.from_dict(entry)
                self.ideas[idea.id] = idea
        except (OSError, json.JSONDecodeError, KeyError, ValueError, TypeError) as exc:
            # Robustness fix: also catch OSError (unreadable file) and
            # ValueError/TypeError (malformed entries, e.g. a bad ISO
            # timestamp in from_dict) — previously these crashed __init__.
            logger.warning("Failed to load idea pool: %s", exc)

    def _save(self) -> None:
        """Persist the whole pool to ``pool_path`` (creates parent dirs)."""
        self.pool_path.parent.mkdir(parents=True, exist_ok=True)
        data = {"ideas": [idea.to_dict() for idea in self.ideas.values()]}
        self.pool_path.write_text(
            json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8"
        )

    # ── CRUD ──────────────────────────────────────────────────────
    def add(self, title: str, description: str, domains: list[str] | None = None) -> Idea:
        """Add a new idea to the pool and persist immediately.

        Returns the created ``Idea``; its id is a random 8-hex-char handle.
        """
        idea_id = uuid.uuid4().hex[:8]
        idea = Idea(
            id=idea_id,
            title=title,
            description=description,
            domains=domains or [],
        )
        self.ideas[idea_id] = idea
        self._save()
        logger.info("Added idea %s: %s", idea_id, title)
        return idea

    def remove(self, idea_id: str) -> None:
        """Remove an idea from the pool. Raises KeyError for unknown ids."""
        if idea_id not in self.ideas:
            raise KeyError(f"Unknown idea: {idea_id}")
        del self.ideas[idea_id]
        self._save()

    def get(self, idea_id: str) -> Idea:
        """Get an idea by ID. Raises KeyError for unknown ids."""
        if idea_id not in self.ideas:
            raise KeyError(f"Unknown idea: {idea_id}")
        return self.ideas[idea_id]

    # ── evaluation ────────────────────────────────────────────────
    def evaluate(self, idea_id: str, feasibility: float, novelty: float) -> dict[str, Any]:
        """Set feasibility and novelty scores for an idea.

        Scores are clamped to [0, 1] and the idea's status becomes
        "evaluated". Returns a summary including the composite score.
        """
        idea = self.get(idea_id)
        idea.feasibility = max(0.0, min(1.0, feasibility))
        idea.novelty = max(0.0, min(1.0, novelty))
        idea.status = "evaluated"
        self._save()
        return {
            "id": idea.id,
            "feasibility": idea.feasibility,
            "novelty": idea.novelty,
            "score": idea.score,
        }

    def rank(self) -> list[Idea]:
        """Return all ideas sorted by composite score (descending)."""
        return sorted(self.ideas.values(), key=lambda i: i.score, reverse=True)

    # ── conversion ────────────────────────────────────────────────
    def to_project(self, idea_id: str, config_path: str, projects_dir: str | Path) -> Project:
        """Convert an idea into a project skeleton.

        The project name is derived from the idea title (lower-cased,
        spaces -> underscores, truncated to 40 chars); the idea's status
        becomes "planned".
        """
        idea = self.get(idea_id)
        # NOTE(review): local import — presumably avoids an import cycle
        # with project.manager; confirm before moving to module level.
        from researchclaw.project.manager import ProjectManager

        manager = ProjectManager(projects_dir)
        project = manager.create(
            name=idea.title.lower().replace(" ", "_")[:40],
            config_path=config_path,
            topic=idea.description,
        )
        idea.status = "planned"
        self._save()
        return project

    def list_all(self) -> list[Idea]:
        """Return all ideas sorted by creation time."""
        return sorted(self.ideas.values(), key=lambda i: i.created_at)
================================================
FILE: researchclaw/project/manager.py
================================================
"""Project manager: CRUD operations and status tracking for research projects."""
from __future__ import annotations
import json
import logging
import shutil
from pathlib import Path
from typing import Any
from researchclaw.project.models import Project
logger = logging.getLogger(__name__)
_REGISTRY_FILE = "registry.json"
class ProjectManager:
    """Manage multiple research projects with independent directories and configs.

    State is persisted in ``<projects_dir>/registry.json`` and reloaded on
    construction; every mutating operation rewrites the registry file.
    """

    def __init__(self, projects_dir: str | Path) -> None:
        self.projects_dir = Path(projects_dir).expanduser().resolve()
        # In-memory registry: project name -> Project.
        self.projects: dict[str, Project] = {}
        # Name of the currently active project (or None).
        self._active: str | None = None
        self._load_registry()

    # ── persistence ───────────────────────────────────────────────
    def _registry_path(self) -> Path:
        """Path of the on-disk registry JSON file."""
        return self.projects_dir / _REGISTRY_FILE

    def _load_registry(self) -> None:
        """Load project registry from disk; corrupt registries are logged and skipped."""
        path = self._registry_path()
        if not path.exists():
            return
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            for entry in data.get("projects", []):
                proj = Project.from_dict(entry)
                self.projects[proj.name] = proj
            self._active = data.get("active")
        except (OSError, json.JSONDecodeError, KeyError, ValueError, TypeError) as exc:
            # Robustness fix: also catch OSError (unreadable file) and
            # ValueError/TypeError (malformed entries, e.g. a bad ISO
            # timestamp in Project.from_dict) — previously these crashed
            # __init__.
            logger.warning("Failed to load project registry: %s", exc)

    def _save_registry(self) -> None:
        """Persist project registry to disk (creates projects_dir if needed)."""
        self.projects_dir.mkdir(parents=True, exist_ok=True)
        data = {
            "active": self._active,
            "projects": [p.to_dict() for p in self.projects.values()],
        }
        self._registry_path().write_text(
            json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8"
        )

    # ── CRUD ──────────────────────────────────────────────────────
    def create(
        self,
        name: str,
        config_path: str,
        topic: str | None = None,
    ) -> Project:
        """Create a new project with an independent directory and config copy.

        Copies *config_path* into the project directory when it exists
        (otherwise the original path string is stored as-is), creates an
        ``artifacts`` run directory, and makes the project active if no
        project was active yet. Raises ValueError on duplicate names.
        """
        if name in self.projects:
            raise ValueError(f"Project already exists: {name}")
        project_dir = self.projects_dir / name
        project_dir.mkdir(parents=True, exist_ok=True)
        # Copy config to project directory
        src = Path(config_path).expanduser().resolve()
        if src.exists():
            dst = project_dir / "config.yaml"
            shutil.copy2(src, dst)
            stored_config = str(dst)
        else:
            stored_config = config_path
        run_dir = str(project_dir / "artifacts")
        Path(run_dir).mkdir(parents=True, exist_ok=True)
        project = Project(
            name=name,
            config_path=stored_config,
            run_dir=run_dir,
            topic=topic or "",
        )
        self.projects[name] = project
        if self._active is None:
            self._active = name
        self._save_registry()
        logger.info("Created project: %s", name)
        return project

    def delete(self, name: str) -> None:
        """Remove project from registry. Does NOT delete artifacts on disk."""
        if name not in self.projects:
            raise KeyError(f"Unknown project: {name}")
        del self.projects[name]
        if self._active == name:
            # Fall back to any remaining project (insertion order), or None.
            self._active = next(iter(self.projects), None)
        self._save_registry()
        logger.info("Deleted project (registry only): %s", name)

    def get(self, name: str) -> Project:
        """Get a single project by name. Raises KeyError for unknown names."""
        if name not in self.projects:
            raise KeyError(f"Unknown project: {name}")
        return self.projects[name]

    def list_all(self) -> list[Project]:
        """Return all projects sorted by creation time."""
        return sorted(self.projects.values(), key=lambda p: p.created_at)

    def get_status(self) -> dict[str, Any]:
        """Summary of all project statuses (totals, active name, per-status counts)."""
        projects = self.list_all()
        return {
            "total": len(projects),
            "active": self._active,
            "by_status": _count_by(projects, "status"),
            "projects": [
                {"name": p.name, "status": p.status, "topic": p.topic}
                for p in projects
            ],
        }

    # ── project switching ─────────────────────────────────────────
    def switch(self, name: str) -> Project:
        """Set the active project. Raises KeyError for unknown names."""
        if name not in self.projects:
            raise KeyError(f"Unknown project: {name}")
        self._active = name
        self._save_registry()
        return self.projects[name]

    @property
    def active(self) -> Project | None:
        """Currently active project (None if unset or no longer registered)."""
        if self._active and self._active in self.projects:
            return self.projects[self._active]
        return None

    # ── comparison ────────────────────────────────────────────────
    def compare(self, name_a: str, name_b: str) -> dict[str, Any]:
        """Compare metrics and status of two projects."""
        a = self.get(name_a)
        b = self.get(name_b)
        return {
            "project_a": {"name": a.name, "status": a.status, "topic": a.topic, "metrics": a.metrics},
            "project_b": {"name": b.name, "status": b.status, "topic": b.topic, "metrics": b.metrics},
            "metric_diff": _metric_diff(a.metrics, b.metrics),
        }

    # ── run lifecycle ─────────────────────────────────────────────
    def start_run(self, name: str, run_id: str) -> str:
        """Mark a project as running with a new run ID; returns *run_id*."""
        proj = self.get(name)
        proj.status = "running"
        proj.last_run_id = run_id
        self._save_registry()
        return run_id

    def finish_run(self, name: str, status: str, metrics: dict[str, Any] | None = None) -> None:
        """Mark a project run as completed or failed, optionally storing metrics."""
        proj = self.get(name)
        proj.status = status
        if metrics:
            proj.metrics = metrics
        self._save_registry()
def _count_by(projects: list[Project], attr: str) -> dict[str, int]:
counts: dict[str, int] = {}
for p in projects:
val = getattr(p, attr, "unknown")
counts[val] = counts.get(val, 0) + 1
return counts
def _metric_diff(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
all_keys = set(a) | set(b)
diff: dict[str, Any] = {}
for key in sorted(all_keys):
va, vb = a.get(key), b.get(key)
if isinstance(va, (int, float)) and isinstance(vb, (int, float)):
diff[key] = {"a": va, "b": vb, "delta": round(vb - va, 6)}
else:
diff[key] = {"a": va, "b": vb}
return diff
================================================
FILE: researchclaw/project/models.py
================================================
"""Data models for multi-project management."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
@dataclass
class Project:
    """A research project managed by AutoResearchClaw."""

    name: str
    config_path: str
    run_dir: str
    # Lifecycle state: idle | running | completed | failed
    status: str = "idle"
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_run_id: str | None = None
    topic: str = ""
    metrics: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize project to a dictionary."""
        return {
            "name": self.name,
            "config_path": self.config_path,
            "run_dir": self.run_dir,
            "status": self.status,
            "created_at": self.created_at.isoformat(),
            "last_run_id": self.last_run_id,
            "topic": self.topic,
            "metrics": self.metrics,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Project:
        """Deserialize project from a dictionary."""
        raw_ts = data.get("created_at")
        if isinstance(raw_ts, str):
            # Timestamps are persisted via isoformat(); parse them back.
            parsed_ts = datetime.fromisoformat(raw_ts)
        elif raw_ts is None:
            parsed_ts = datetime.now(timezone.utc)
        else:
            parsed_ts = raw_ts
        return cls(
            name=data["name"],
            config_path=data["config_path"],
            run_dir=data["run_dir"],
            status=data.get("status", "idle"),
            created_at=parsed_ts,
            last_run_id=data.get("last_run_id"),
            topic=data.get("topic", ""),
            metrics=data.get("metrics", {}),
        )
@dataclass
class Idea:
    """A research idea that can be evaluated and converted to a project."""

    id: str
    title: str
    description: str
    # Lifecycle state: draft | evaluated | planned | running | completed
    status: str = "draft"
    feasibility: float = 0.0  # 0-1
    novelty: float = 0.0  # 0-1
    domains: list[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def score(self) -> float:
        """Composite score: weighted average of feasibility and novelty."""
        return 0.4 * self.feasibility + 0.6 * self.novelty

    def to_dict(self) -> dict[str, Any]:
        """Serialize idea to a dictionary."""
        return {
            "id": self.id,
            "title": self.title,
            "description": self.description,
            "status": self.status,
            "feasibility": self.feasibility,
            "novelty": self.novelty,
            "domains": self.domains,
            "created_at": self.created_at.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Idea:
        """Deserialize idea from a dictionary."""
        raw_ts = data.get("created_at")
        if isinstance(raw_ts, str):
            # Timestamps are persisted via isoformat(); parse them back.
            parsed_ts = datetime.fromisoformat(raw_ts)
        elif raw_ts is None:
            parsed_ts = datetime.now(timezone.utc)
        else:
            parsed_ts = raw_ts
        return cls(
            id=data["id"],
            title=data["title"],
            description=data["description"],
            status=data.get("status", "draft"),
            feasibility=float(data.get("feasibility", 0.0)),
            novelty=float(data.get("novelty", 0.0)),
            domains=data.get("domains", []),
            created_at=parsed_ts,
        )
================================================
FILE: researchclaw/project/scheduler.py
================================================
"""Project scheduler: priority queue and concurrency control for pipeline runs."""
from __future__ import annotations
import heapq
import logging
from dataclasses import dataclass, field
from typing import Any
from researchclaw.project.manager import ProjectManager
logger = logging.getLogger(__name__)
@dataclass(order=True)
class _QueueEntry:
    """Priority queue entry (lower priority number = higher priority)."""

    # Sole ordering key — heapq compares entries by this field only.
    priority: int
    # compare=False keeps the project name out of heap comparisons, so two
    # entries with equal priority never compare by name.
    project_name: str = field(compare=False)
class ProjectScheduler:
    """Schedule project pipeline runs with priority and concurrency limits."""

    def __init__(self, manager: ProjectManager, max_concurrent: int = 2) -> None:
        self.manager = manager
        self.max_concurrent = max_concurrent
        # Min-heap of _QueueEntry: lowest priority number pops first.
        self._queue: list[_QueueEntry] = []
        # Names of projects currently occupying a concurrency slot.
        self._running: set[str] = set()

    def enqueue(self, project_name: str, priority: int = 0) -> None:
        """Add a project to the run queue (no-op if already queued or running)."""
        if project_name not in self.manager.projects:
            raise KeyError(f"Unknown project: {project_name}")
        # Avoid duplicate enqueue
        if any(entry.project_name == project_name for entry in self._queue):
            logger.info("Project %s already in queue", project_name)
            return
        if project_name in self._running:
            logger.info("Project %s already running", project_name)
            return
        heapq.heappush(self._queue, _QueueEntry(priority=priority, project_name=project_name))
        logger.info("Enqueued project %s with priority %d", project_name, priority)

    def dequeue(self) -> str | None:
        """Remove and return the highest-priority project, or None when empty."""
        if not self._queue:
            return None
        return heapq.heappop(self._queue).project_name

    def next(self) -> str | None:
        """Claim the next project to run, if a concurrency slot is available."""
        if not self.can_start():
            return None
        candidate = self.dequeue()
        if candidate is not None:
            self._running.add(candidate)
        return candidate

    def can_start(self) -> bool:
        """True when a slot is free AND at least one project is queued."""
        return bool(self._queue) and len(self._running) < self.max_concurrent

    def mark_done(self, project_name: str) -> None:
        """Mark a running project as finished (frees a concurrency slot)."""
        self._running.discard(project_name)

    @property
    def queue_size(self) -> int:
        """Number of projects waiting in the queue."""
        return len(self._queue)

    @property
    def running_count(self) -> int:
        """Number of projects currently running."""
        return len(self._running)

    def get_status(self) -> dict[str, Any]:
        """Scheduler status overview (limits, running set, queue contents)."""
        return {
            "max_concurrent": self.max_concurrent,
            "running": sorted(self._running),
            "running_count": len(self._running),
            "queued": [entry.project_name for entry in sorted(self._queue)],
            "queue_size": len(self._queue),
        }
================================================
FILE: researchclaw/prompts.py
================================================
"""Prompt externalization for the ResearchClaw pipeline.
All 23 stage prompts are defined here as defaults and can be overridden
via a user-provided YAML file. Users customize prompts without touching
Python source code.
Architecture
------------
* ``_DEFAULT_STAGES`` — every LLM-facing prompt, keyed by stage name.
* ``_DEFAULT_BLOCKS`` — reusable prompt fragments (topic constraint, etc.).
* ``_DEFAULT_SUB_PROMPTS`` — secondary prompts (code repair, etc.).
* ``PromptManager`` — loads defaults → merges user overrides → renders templates.
* ``_render()`` — safe ``{variable}`` substitution that leaves unmatched
patterns (JSON schemas, curly-brace literals) untouched.
Usage
-----
::
from researchclaw.prompts import PromptManager
pm = PromptManager() # defaults only
pm = PromptManager("my_prompts.yaml") # with user overrides
sp = pm.for_stage("topic_init", topic="RL for drug discovery", domains="ml, bio")
resp = llm.chat(
[{"role": "user", "content": sp.user}],
system=sp.system,
json_mode=sp.json_mode,
max_tokens=sp.max_tokens,
)
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Template rendering
# ---------------------------------------------------------------------------
def _render(template: str, variables: dict[str, str]) -> str:
"""Replace ``{var_name}`` placeholders with *variables* values.
Only bare ``{word_chars}`` tokens are substituted — JSON schema
examples like ``{candidates:[...]}`` or ``{score_1_to_10:number}``
are left untouched because the regex requires the closing ``}``
immediately after the identifier.
"""
def _replacer(match: re.Match[str]) -> str:
key = match.group(1)
return str(variables[key]) if key in variables else match.group(0)
return re.sub(r"\{(\w+)\}", _replacer, template)
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class RenderedPrompt:
    """Fully rendered prompt ready for ``llm.chat()``."""

    # Final system prompt text (template variables already substituted).
    system: str
    # Final user prompt text, possibly with an evolution overlay appended.
    user: str
    # Whether the LLM call should request JSON-formatted output.
    json_mode: bool = False
    # Optional completion-token cap; None leaves the limit to the caller.
    max_tokens: int | None = None
# ---------------------------------------------------------------------------
# PromptManager
# ---------------------------------------------------------------------------
class PromptManager:
"""Central registry for pipeline prompts with optional YAML overrides."""
def __init__(self, overrides_path: str | Path | None = None) -> None:
# Deep-copy defaults so mutations don't leak across instances
self._stages: dict[str, dict[str, Any]] = {
k: dict(v) for k, v in _DEFAULT_STAGES.items()
}
self._blocks: dict[str, str] = dict(_DEFAULT_BLOCKS)
self._sub_prompts: dict[str, dict[str, Any]] = {
k: dict(v) for k, v in _DEFAULT_SUB_PROMPTS.items()
}
if overrides_path:
self._load_overrides(Path(overrides_path))
# -- loading ----------------------------------------------------------
def _load_overrides(self, path: Path) -> None:
if not path.exists():
logger.warning("Prompts file not found: %s — using defaults", path)
return
try:
data = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as exc:
logger.warning("Bad prompts YAML %s: %s — using defaults", path, exc)
return
for stage_name, stage_data in (data.get("stages") or {}).items():
if stage_name in self._stages and isinstance(stage_data, dict):
self._stages[stage_name].update(stage_data)
else:
logger.warning("Unknown stage in prompts file: %s", stage_name)
for block_name, block_text in (data.get("blocks") or {}).items():
if isinstance(block_text, str):
self._blocks[block_name] = block_text
for sub_name, sub_data in (data.get("sub_prompts") or {}).items():
if sub_name in self._sub_prompts and isinstance(sub_data, dict):
self._sub_prompts[sub_name].update(sub_data)
logger.info("Loaded prompt overrides from %s", path)
# -- primary API ------------------------------------------------------
def for_stage(
self,
stage: str,
*,
evolution_overlay: str = "",
**kwargs: Any,
) -> RenderedPrompt:
"""Return a fully rendered prompt for *stage* with variables filled.
If *evolution_overlay* is provided, it is appended to the user prompt
so the LLM can learn from prior run lessons.
"""
entry = self._stages[stage]
kw = {k: str(v) for k, v in kwargs.items()}
user_text = _render(entry["user"], kw)
if evolution_overlay:
user_text = f"{user_text}\n\n{evolution_overlay}"
return RenderedPrompt(
system=_render(entry["system"], kw),
user=user_text,
json_mode=entry.get("json_mode", False),
max_tokens=entry.get("max_tokens"),
)
def system(self, stage: str) -> str:
"""Return the raw system prompt template for *stage*."""
return self._stages[stage]["system"]
def user(self, stage: str, **kwargs: Any) -> str:
"""Return the rendered user prompt for *stage*."""
return _render(
self._stages[stage]["user"],
{k: str(v) for k, v in kwargs.items()},
)
def json_mode(self, stage: str) -> bool:
return self._stages[stage].get("json_mode", False)
def max_tokens(self, stage: str) -> int | None:
return self._stages[stage].get("max_tokens")
# -- blocks -----------------------------------------------------------
def block(self, name: str, **kwargs: Any) -> str:
"""Render a reusable prompt block."""
return _render(
self._blocks[name],
{k: str(v) for k, v in kwargs.items()},
)
# -- sub-prompts (code repair, etc.) ----------------------------------
def sub_prompt(self, name: str, **kwargs: Any) -> RenderedPrompt:
"""Return a rendered sub-prompt (e.g. code_repair)."""
entry = self._sub_prompts[name]
kw = {k: str(v) for k, v in kwargs.items()}
return RenderedPrompt(
system=_render(entry["system"], kw),
user=_render(entry["user"], kw),
)
# -- introspection ----------------------------------------------------
def stage_names(self) -> list[str]:
return list(self._stages.keys())
def has_stage(self, stage: str) -> bool:
return stage in self._stages
def export_yaml(self, path: Path) -> None:
    """Serialise the active prompt set (defaults + overrides) to *path*.

    The output mirrors the prompts.yaml schema: ``version``, ``blocks``,
    ``stages`` and ``sub_prompts``. Entries are shallow-copied so the
    dump never aliases live internal state.
    """
    stages = {name: dict(entry) for name, entry in self._stages.items()}
    subs = {name: dict(entry) for name, entry in self._sub_prompts.items()}
    payload: dict[str, Any] = {
        "version": "1.0",
        "blocks": dict(self._blocks),
        "stages": stages,
        "sub_prompts": subs,
    }
    text = yaml.dump(payload, default_flow_style=False, allow_unicode=True, width=120)
    path.write_text(text, encoding="utf-8")
# ========================================================================
# DEFAULT PROMPTS — edit prompts.yaml to override; do NOT edit these.
# ========================================================================
# -- Canonical section word-count targets ----------------------------------
# Single source of truth for per-section word-count ranges.
# Used by executor._validate_draft_quality() and converter.check_paper_completeness().
SECTION_WORD_TARGETS: dict[str, tuple[int, int]] = {
    # Each value is a (lower, upper) word-count range for the canonical
    # (lower-cased) section heading. Headings not listed have no target;
    # variant headings map here via _SECTION_TARGET_ALIASES.
    "abstract": (180, 220),
    "introduction": (800, 1000),
    "related work": (600, 800),
    "method": (1000, 1500),
    "experiments": (800, 1200),
    "results": (600, 800),
    "discussion": (400, 600),
    "limitations": (200, 300),
    "conclusion": (200, 300),
    "broader impact": (200, 400),
}
# Aliases mapping heading variants to canonical names in SECTION_WORD_TARGETS.
_SECTION_TARGET_ALIASES: dict[str, str] = {
    # Keys are lower-cased heading variants; values are canonical keys
    # that must exist in SECTION_WORD_TARGETS.
    # Method-section variants.
    "methods": "method",
    "methodology": "method",
    "proposed method": "method",
    "approach": "method",
    # Experiments / results variants.
    "experimental setup": "experiments",
    "experimental results": "results",
    "results and discussion": "results",
    "results and analysis": "results",
    # Conclusion variants.
    "conclusions": "conclusion",
    "conclusion and future work": "conclusion",
    "summary": "conclusion",
    # Related-work variants.
    "background": "related work",
    "literature review": "related work",
    "prior work": "related work",
    # Limitations variants.
    "limitation": "limitations",
    "limitations and future work": "limitations",
    # Broader-impact variants.
    "broader impacts": "broader impact",
    "societal impact": "broader impact",
    "ethical considerations": "broader impact",
}
# -- Reusable blocks -----------------------------------------------------
_DEFAULT_BLOCKS: dict[str, str] = {
"title_guidelines": (
"\n## TITLE RULES (Hard Constraints)\n"
"1. MAXIMUM 14 words. Ideal: 8-12 words. NEVER exceed 14.\n"
"2. Preferred structure: 'MethodName: Descriptive Phrase' (colon format)\n"
" - Create a catchy 1-3 word method name (acronym, portmanteau, or evocative word)\n"
" - Subtitle explains what it does: 'for X' / 'via Y' / 'in Z'\n"
" - Examples: 'AlphaEdit: Null-Space Knowledge Editing for LMs' (8 words)\n"
" - Examples: 'VAR: Visual Autoregressive Modeling via Next-Scale Prediction' (8 words)\n"
"3. Alternative: Bold declarative claim that surprises the reader\n"
" - 'Not All Tokens Are What You Need for Pretraining' (9 words)\n"
" - 'Vision Transformers Need Registers' (4 words)\n"
"4. FORBIDDEN patterns:\n"
" - 'Investigating...', 'An Empirical Study of...', 'Towards...'\n"
" - 'A Novel Approach to...', 'On the...' (generic academic filler)\n"
" - Repeating the full method description as title\n"
" - Weakness qualifiers: 'in Two Runs', 'Under Limited Data'\n"
"5. MUST define a short method name (2-5 chars) that serves as memorable handle.\n"
" The reader should be able to say 'Have you read the X paper?'\n"
"6. No abbreviations unless universally known (LLM, RL, GAN, NLP are OK).\n"
),
"abstract_structure": (
"\n## ABSTRACT (Hard Rules — 180-220 words, 5-7 sentences)\n"
"STRUCTURE (PMR+ format):\n"
"S1-S2: PROBLEM — What gap exists? Why does it matter? (NO method names yet)\n"
"S3-S4: METHOD — Name your system. One-sentence description of key insight.\n"
"S5-S6: RESULTS — At most 3 specific numbers. Use relative improvements\n"
" ('X% over baseline') not raw values ('0.7667'). Bold the single most\n"
" important result.\n"
"S7 (optional): IMPACT — What does this enable?\n\n"
"HARD CONSTRAINTS:\n"
"- NO \\texttt{{}} in abstract\n"
"- NO more than 3 numeric values in the entire abstract\n"
"- NO per-seed breakdowns or confidence intervals\n"
"- NO method names longer than 3 words (use the short system name)\n"
"- The abstract must be readable by a researcher who skimmed only the title\n"
"- First sentence must NOT start with 'We' or 'This paper'\n"
),
"compute_budget": (
"\n## Compute Budget Constraint\n"
"- Total execution time limit: {time_budget_sec} seconds\n"
"- You MUST design experiments that complete within this budget\n"
"- Estimate: a simple numpy loop runs ~10M iterations/sec; a nested loop over\n"
" conditions runs proportionally slower\n"
"- SCALING RULES (mandatory):\n"
" - If total conditions > 100: reduce seeds to 3-5 (not 20)\n"
" - If total conditions > 500: reduce to 2-3 representative conditions per factor\n"
" - If time_budget < 300s: limit total optimization steps to ≤5,000 per run\n"
" - If time_budget < 120s: limit total optimization steps to ≤1,000 per run\n"
" - Always print intermediate results so partial data is captured on timeout\n"
"- MANDATORY: print a 'TIME_ESTIMATE: Xs' line before the main loop,\n"
" estimating total runtime based on a small pilot (run 1 condition, extrapolate)\n"
"- MANDATORY: implement a time guard — check elapsed time periodically and\n"
" stop gracefully if approaching 80% of budget, saving all results collected so far\n"
"- MANDATORY: add NaN/divergence fast-fail guard:\n"
" - After each optimization step, check if loss is NaN or > 100\n"
" - If detected, print 'FAIL: NaN/divergence detected', save partial results, and exit\n"
" - Do NOT waste compute on a diverging run\n"
"- MINIMUM TRAINING EPOCHS (CRITICAL for meaningful results):\n"
" - CIFAR-10/100 with ResNet/CNN: minimum 50 epochs (200 recommended)\n"
" - FashionMNIST with small CNN: minimum 20 epochs\n"
" - RL environments: follow the RL STEP BUDGET below (CRITICAL)\n"
" - If time_budget is too short for minimum epochs, REDUCE model complexity\n"
" or dataset size INSTEAD of reducing epochs. 8 epochs on CIFAR-10 will\n"
" produce random-chance accuracy (~10%), making all comparisons meaningless.\n"
" - Use a SMALL model (simple CNN, few layers) to fit enough epochs into the budget.\n"
" - A converged small model is worth infinitely more than a diverged large model.\n"
"- MANDATORY: use the experiment_harness module (pre-installed in sandbox):\n"
" ```\n"
" from experiment_harness import ExperimentHarness\n"
" harness = ExperimentHarness(time_budget={time_budget_sec})\n"
" # In your experiment loop:\n"
" if harness.should_stop():\n"
" break # graceful stop at 80% of budget\n"
" if not harness.check_value(value, 'metric_name'):\n"
" print('SKIP: NaN/Inf detected') # skip invalid values\n"
" continue\n"
" harness.report_metric('metric_name', value) # validated output\n"
" # At the end of ALL experiments:\n"
" harness.finalize() # writes results.json — MUST be called\n"
" ```\n"
" The harness provides: time budget enforcement, NaN/Inf detection,\n"
" validated metric reporting, and results.json output. NOT using it\n"
" means your metrics may be lost or malformed.\n"
),
"topic_constraint": (
"\n\n=== HARD TOPIC CONSTRAINT ===\n"
"The paper MUST be about: {topic}\n"
"PROHIBITED content (unless user explicitly specifies case-study mode):\n"
"- Do NOT treat environment setup, dependency installation, or infrastructure "
"failures as a research contribution.\n"
"- Do NOT present debugging logs, system errors, or configuration issues "
"as experimental findings.\n"
"- Do NOT drift to tangential topics not directly related to the stated topic.\n"
"- Every section MUST connect back to the core research question.\n"
"- The Abstract and Introduction MUST clearly state the research problem "
"derived from: {topic}\n"
"- The Method section MUST describe a technical approach, not a workflow.\n"
"- The Results section MUST report quantitative outcomes of experiments, "
"not environment status.\n"
"=== END CONSTRAINT ===\n"
),
"pkg_hint_sandbox": (
"\nAVAILABLE PACKAGES (sandbox mode): Python stdlib, numpy, math, random, "
"statistics, json.\n"
"Do NOT use: torch, tensorflow, jax, sklearn, pandas, scipy, matplotlib, "
"or any deep learning framework.\n"
"Write the experiment using ONLY numpy and stdlib.\n"
),
"dataset_guidance": (
"\n## Standard Datasets & Real Baselines (MANDATORY when applicable)\n"
"You MUST use real benchmark datasets — NEVER synthetic torch.randn() data.\n\n"
"### Tier 1: Pre-cached (ALWAYS available, use download=False)\n"
"These datasets are already in the Docker image. Use download=False:\n"
"- `torchvision.datasets.CIFAR10(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.CIFAR100(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.MNIST(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.FashionMNIST(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.STL10(root='/opt/datasets', split='train'/'test', download=False)`\n"
"- `torchvision.datasets.SVHN(root='/opt/datasets', split='train'/'test', download=False)`\n\n"
"### Tier 2: Downloadable (use setup.py to download before main.py runs)\n"
"For any dataset NOT in Tier 1, create a `setup.py` file that downloads it.\n"
"setup.py runs WITH network access; main.py runs WITHOUT network.\n"
"- Any torchvision dataset (Caltech-101, Flowers102, etc.)\n"
"- HuggingFace datasets: `from datasets import load_dataset`\n"
" Examples: IMDB, AG News, WikiText, SST-2, SQuAD, MMLU\n"
"- OGB benchmarks: ogbg-molhiv, ogbn-arxiv, etc.\n"
"- Tiny-ImageNet (237MB, 200 classes) — good ImageNet proxy\n\n"
"### Tier 3: Too large for download (use alternatives)\n"
"These datasets are TOO LARGE to download within experiment time limits:\n"
"- ImageNet-1K (168GB) → use Tiny-ImageNet or CIFAR-100 as proxy\n"
"- LAION (>1TB) → use smaller HuggingFace image-text datasets\n"
"- Common Crawl, The Pile → use WikiText-103 or pre-tokenized subsets\n"
"NEVER generate 'ImageNet-like' synthetic data — always use a real alternative.\n\n"
"### ANTI-PATTERNS (NEVER DO THESE):\n"
"- `torch.randn(N, 3, 224, 224)` as dataset → use real datasets\n"
"- `download=True` in main.py → put downloads in setup.py\n"
"- `download=False` for non-cached datasets → will FileNotFoundError\n"
"- Random train/test splits → use official splits from dataset\n"
"- `os.makedirs('/opt/datasets/...')` → /opt/datasets is READ-ONLY\n\n"
"DATA PATH: For Tier 1 pre-cached datasets, use `/opt/datasets` as root.\n"
"For Tier 2 datasets downloaded by setup.py, use `/workspace/data` as root.\n"
"WARNING: `/opt/datasets` is READ-ONLY. NEVER call os.makedirs() on it.\n"
"Just pass `root='/opt/datasets'` directly to torchvision dataset constructors.\n\n"
"DISTRIBUTION SHIFT — use torchvision corruption transforms:\n"
"- Gaussian noise: `transforms.Lambda(lambda x: x + torch.randn_like(x) * sigma)`\n"
"- Brightness shift: `transforms.ColorJitter(brightness=0.5)`\n"
"- Contrast shift: `transforms.ColorJitter(contrast=0.5)`\n"
"- Blur: `transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))`\n"
"- For CIFAR-10-C style corruptions, apply transforms to test set only.\n\n"
"REAL BASELINES & MODERN BENCHMARKS (CRITICAL):\n"
"- Use proper train/test splits from the dataset (never split randomly in code)\n"
"- Use standard architectures (ResNet-18/50, ViT, ConvNeXt) — not toy 2-layer MLPs\n"
"- CIFAR INPUT SIZE (IMPORTANT): CIFAR images are 32×32. Two valid approaches:\n"
" 1. PRETRAINED models (ImageNet weights): Use `transforms.Resize(224)` — "
"pretrained models require 224×224 inputs.\n"
" 2. TRAINING FROM SCRATCH (most experiments): Modify the model for 32×32 "
"inputs instead of resizing. For ResNet: use `nn.Conv2d(3,64,3,1,1)` as "
"first conv (not 7×7/stride-2) and REMOVE the initial MaxPool. This is 49× "
"more memory-efficient and trains faster than Resize(224). Use the `timm` "
"library's CIFAR variants or build a custom `get_resnet18_cifar()` helper.\n"
"- Report standard metrics (top-1 accuracy for classification tasks)\n"
"- Compare against published baselines where available\n"
"- BASELINES MUST BE CURRENT: Use baselines from recent top-venue papers "
"(2023-2026). Do NOT use outdated methods as the primary comparison.\n"
" * AlexNet, VGG-16 → use ResNet-50, ViT, ConvNeXt instead\n"
" * Vanilla SGD → use AdamW, SGD+momentum+cosine LR\n"
" * Simple RNN/LSTM for NLP → use Transformer-based models\n"
"- Include at LEAST one strong, modern baseline (near-SOTA).\n"
"- BENCHMARKS MUST BE STANDARD and actively used in the community.\n\n"
"WHEN TO USE SYNTHETIC DATA (required for these domains):\n"
"- **PDE / Scientific computing**: Generate synthetic PDE data (Burgers "
"equation, Darcy flow, heat equation, Navier-Stokes). Use numerical solvers "
"(scipy.integrate, finite differences) to create ground truth.\n"
"- **Combinatorial optimization** (TSP, graph coloring, scheduling): Generate "
"random problem instances (random TSP cities, Erdos-Renyi graphs).\n"
"- **Theoretical analysis**: Synthetic optimization landscapes, toy problems.\n"
"- **Domain with no standard dataset**: Novel combinatorial or mathematical domains.\n"
"For these domains, do NOT use CIFAR/MNIST/ImageNet — they are irrelevant. "
"Generate problem-specific synthetic data in main.py.\n\n"
"DOMAIN-DATASET MATCHING (CRITICAL):\n"
"- Image classification → CIFAR-10/100, MNIST, ImageNet variants\n"
"- NLP → IMDB, AG News, SST-2, WikiText\n"
"- Graph learning → Cora, CiteSeer, ogbn-arxiv\n"
"- PDE/Physics → SYNTHETIC (Burgers, Darcy, Navier-Stokes)\n"
"- Combinatorial optimization → SYNTHETIC (random TSP, graph instances)\n"
"- RL → Gymnasium environments (CartPole, LunarLander, HalfCheetah)\n"
"NEVER use image datasets for non-image problems.\n"
),
"setup_script_guidance": (
"\n## Setup Script (setup.py) — Dataset Download & Preparation\n"
"If your experiment needs datasets NOT in the pre-cached list, generate "
"a SEPARATE file called `setup.py` that downloads and prepares them.\n"
"The setup.py runs WITH NETWORK ACCESS before main.py (which runs WITHOUT network).\n\n"
"IMPORTANT: All download logic MUST be in setup.py, NOT in main.py.\n"
"main.py should only load pre-cached data from /opt/datasets (download=False) "
"or downloaded data from /workspace/data.\n\n"
"Example setup.py:\n"
"```python\n"
"import os\n"
"DATA_DIR = '/workspace/data'\n"
"os.makedirs(DATA_DIR, exist_ok=True)\n\n"
"# Download torchvision datasets\n"
"import torchvision\n"
"torchvision.datasets.Caltech101(root=DATA_DIR, download=True)\n\n"
"# Download HuggingFace datasets\n"
"from datasets import load_dataset\n"
"ds = load_dataset('imdb', cache_dir=os.path.join(DATA_DIR, 'hf'))\n\n"
"# Download OGB benchmarks\n"
"# from ogb.graphproppred import PygGraphPropPredDataset\n"
"# dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root=DATA_DIR)\n\n"
"print('[setup] Dataset download complete.')\n"
"```\n\n"
"IMPORT ANTI-PATTERN (NEVER DO THIS):\n"
"```python\n"
"from datasets import load_dataset\n"
"datasets.load_dataset('imdb', ...) # WRONG — NameError!\n"
"```\n"
"If you write `from datasets import load_dataset`, call `load_dataset(...)` directly.\n"
"If you write `import datasets`, call `datasets.load_dataset(...)` with module prefix.\n"
"NEVER mix the two styles.\n\n"
"If ALL your datasets are pre-cached (CIFAR-10/100, MNIST, FashionMNIST, "
"STL-10, SVHN), you do NOT need setup.py — just use download=False in main.py.\n\n"
"You may also include a `requirements.txt` file listing any additional "
"pip packages your experiment needs beyond the pre-installed set.\n"
),
"network_disabled_guidance": (
"\n## ⚠️ NO NETWORK ACCESS — CRITICAL CONSTRAINT ⚠️\n"
"This experiment runs with network_policy='none'. There is NO network access\n"
"at ANY phase (no pip install, no dataset downloads, no HTTP requests).\n\n"
"### ONLY these pre-cached datasets are available:\n"
"- `torchvision.datasets.CIFAR10(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.CIFAR100(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.MNIST(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.FashionMNIST(root='/opt/datasets', train=True/False, download=False)`\n"
"- `torchvision.datasets.STL10(root='/opt/datasets', split='train'/'test', download=False)`\n"
"- `torchvision.datasets.SVHN(root='/opt/datasets', split='train'/'test', download=False)`\n\n"
"### FORBIDDEN (will cause runtime failure):\n"
"- Do NOT create setup.py (it cannot run without network)\n"
"- Do NOT create requirements.txt (pip install is unavailable)\n"
"- Do NOT use `download=True` on any dataset\n"
"- Do NOT use `urllib`, `requests`, `httpx`, or any HTTP library\n"
"- Do NOT use `datasets.load_dataset()` from HuggingFace (requires download)\n"
"- Do NOT import packages not pre-installed in the Docker image\n\n"
"### Available pre-installed packages:\n"
"torch, torchvision, torchaudio, numpy, scipy, sklearn, matplotlib, seaborn,\n"
"pandas, tqdm, gymnasium, networkx, PyYAML, Pillow, timm, einops, torchmetrics,\n"
"h5py, transformers, datasets, accelerate, peft, bitsandbytes.\n\n"
"If your research topic requires a dataset NOT in the pre-cached list,\n"
"you MUST adapt to use one of the 6 pre-cached datasets instead.\n"
),
"network_full_guidance": (
"\n## Network Access: Full\n"
"This experiment runs with network_policy='full'. Network access is available\n"
"throughout ALL execution phases (setup, pip install, and main experiment).\n"
"You may download datasets, install packages, and make HTTP requests at any time.\n"
),
"hp_reporting": (
"\n## Hyperparameter Reporting (MANDATORY)\n"
"At the TOP of main.py, define a HYPERPARAMETERS dictionary containing ALL "
"tunable hyperparameters used in your experiment:\n"
"```python\n"
"HYPERPARAMETERS = {\n"
" 'learning_rate': 0.001,\n"
" 'batch_size': 64,\n"
" 'num_epochs': 50,\n"
" 'hidden_dim': 256,\n"
" # ... all other hyperparameters\n"
"}\n"
"```\n"
"At the end of main.py, save hyperparameters to results.json:\n"
"```python\n"
"import json\n"
"results = {'hyperparameters': HYPERPARAMETERS, 'metrics': collected_metrics}\n"
"with open('results.json', 'w') as f:\n"
" json.dump(results, f, indent=2)\n"
"```\n"
"EVERY hyperparameter must be used in the code — no dead parameters.\n"
"The paper MUST include a hyperparameter table — this data feeds into it.\n"
),
"rl_step_guidance": (
"\n## RL Training Step Budget (MANDATORY for RL experiments)\n"
"Reinforcement learning requires MANY more training steps than supervised learning.\n"
"Under-trained RL agents produce random-chance performance, making ALL comparisons\n"
"meaningless and the paper unpublishable.\n\n"
"### Environment Availability:\n"
"#### Always available (classic control — no extra dependencies):\n"
"- CartPole-v1, Pendulum-v1, MountainCar-v0, MountainCarContinuous-v0,\n"
" Acrobot-v1, LunarLander-v3\n"
"- These are lightweight and fast — PREFER these unless MuJoCo is specifically required.\n\n"
"#### MuJoCo environments (pre-installed in Docker image):\n"
"- HalfCheetah-v5, Hopper-v5, Walker2d-v5, Ant-v5, Humanoid-v5,\n"
" Swimmer-v5, Reacher-v5, InvertedPendulum-v5, InvertedDoublePendulum-v5\n"
"- Require MuJoCo runtime — available in Docker but NOT in basic sandbox mode.\n\n"
"#### RULE: If the research topic says 'MuJoCo-free', 'without MuJoCo',\n"
" or 'classic control only' → you MUST use classic control environments ONLY.\n"
" Do NOT import or reference MuJoCo in any way.\n\n"
"#### DEFAULT RECOMMENDATION: Prefer classic control environments unless the\n"
" research topic specifically requires MuJoCo locomotion tasks.\n\n"
"### ALGORITHM-ENVIRONMENT COMPATIBILITY (HARD RULE — violation = crash):\n"
"- DQN is ONLY for DISCRETE action spaces (CartPole, LunarLander, Acrobot, Atari).\n"
" DQN will CRASH on Pendulum, HalfCheetah, Hopper, Walker2d, etc.\n"
"- For CONTINUOUS action spaces: use SAC, TD3, or PPO.\n"
"- PPO works for both discrete and continuous.\n"
"- NEVER combine DQN + any continuous environment.\n\n"
"### TIME BUDGET RULES FOR RL:\n"
"- If time_budget ≤ 3600s → ONLY classic control "
"(CartPole, Pendulum, MountainCar, Acrobot, LunarLander)\n"
"- If time_budget ≤ 1800s → ONLY CartPole or Pendulum (simplest)\n"
"- MuJoCo requires >5000s for meaningful results.\n\n"
"### Minimum Steps by Algorithm Family:\n"
"| Algorithm | Environment | Min Steps | Recommended |\n"
"|-----------|-------------|-----------|-------------|\n"
"| PPO | MuJoCo (Ant, HalfCheetah, Humanoid) | 500K | 1M-3M |\n"
"| PPO | Simple control (CartPole, Pendulum) | 100K | 500K |\n"
"| SAC/TD3 | MuJoCo locomotion | 300K | 1M |\n"
"| SAC/TD3 | Simple control | 50K | 200K |\n"
"| DQN/Rainbow | Atari | 1M | 10M |\n"
"| A2C/A3C | Any continuous | 500K | 2M |\n"
"| REINFORCE | Any | 200K | 1M |\n\n"
"### Step Budget Allocation Strategy:\n"
"1. Compute pilot_time = time for 1000 steps of 1 condition.\n"
"2. steps_per_sec = 1000 / pilot_time.\n"
"3. max_steps_per_condition = (time_budget * 0.7) / num_conditions * steps_per_sec.\n"
"4. If max_steps < min_steps for the algorithm, REDUCE num_seeds to 3 (not steps).\n"
"5. If STILL under min_steps, use a simpler environment (e.g., Pendulum instead of Ant).\n"
"6. NEVER reduce steps below the minimum — it wastes compute on meaningless results.\n\n"
"### Evaluation Protocol for RL:\n"
"- Evaluate every N_eval steps (e.g., every 10K steps) using deterministic policy.\n"
"- Run 10 evaluation episodes per checkpoint.\n"
"- Report: mean return, std return, success rate (if applicable).\n"
"- Plot learning curves (return vs steps) — this is EXPECTED by reviewers.\n"
"- Final metric = mean over last 10 evaluation checkpoints (NOT last episode).\n\n"
"### Gymnasium Environment Version (CRITICAL):\n"
"- Use v5 environments (NOT v4): `gym.make('HalfCheetah-v5')`, `gym.make('Hopper-v5')`\n"
"- v4 environments are deprecated and will produce warnings.\n"
"- Available MuJoCo v5 envs: HalfCheetah-v5, Hopper-v5, Walker2d-v5, Ant-v5,\n"
" Humanoid-v5, Swimmer-v5, Reacher-v5, InvertedPendulum-v5, InvertedDoublePendulum-v5\n"
"- For simple/fast experiments: use Pendulum-v1, CartPole-v1, MountainCarContinuous-v0\n\n"
"### Gymnasium API (CRITICAL — common crash source):\n"
"- `env.reset()` returns `(obs, info)` — ALWAYS unpack both:\n"
" `obs, info = env.reset(seed=seed)`\n"
"- `env.step(action)` returns `(obs, reward, terminated, truncated, info)` — 5 values:\n"
" `obs, reward, terminated, truncated, info = env.step(action)`\n"
" `done = terminated or truncated`\n"
"- DO NOT use old `done = env.step(action)[2]` — this is the Gym (v0.26-) API.\n"
"- `reward` is a scalar float, NOT an array. Do NOT index it: use `reward` directly.\n"
"- `obs` shape depends on env: discrete envs give 1D array, image envs give 3D.\n"
" Always check `env.observation_space.shape` and handle accordingly.\n\n"
"### Learning Curve Logging (MANDATORY for RL papers):\n"
"- Print evaluation metrics at regular intervals: every N_eval steps\n"
" `EVAL: step= condition= seed= return=`\n"
"- This enables plotting learning curves (return vs training steps)\n"
"- Learning curves are EXPECTED by RL reviewers — a paper without them\n"
" will be rejected regardless of final performance.\n"
"- At the end, print the full curve:\n"
" `LEARNING_CURVE: condition= seed= steps=[...] returns=[...]`\n"
),
"multi_seed_enforcement": (
"\n## Multi-Seed Experiment Requirement (MANDATORY — NO EXCEPTIONS)\n"
"Running each condition with only 1 seed is NEVER acceptable. Results from\n"
"a single seed cannot distinguish signal from noise and reviewers will reject.\n\n"
"### HARD REQUIREMENT:\n"
"- You MUST use exactly seeds = [0, 1, 2] (3 seeds minimum).\n"
"- Each condition MUST loop over ALL seeds.\n"
"- Print per-seed: `condition=X seed=S {metric_key}: V`\n"
"- Print aggregated: `condition=X {metric_key}_mean: M {metric_key}_std: S`\n"
"- Tables MUST show mean ± std, NEVER single-run values.\n\n"
"### Implementation Pattern (copy this structure):\n"
"```python\n"
"SEEDS = [0, 1, 2] # EXACTLY 3 seeds — mandatory minimum\n"
"all_results = {} # {condition_name: {seed: metric_value}}\n\n"
"for condition_name, ConditionClass in conditions.items():\n"
" all_results[condition_name] = {}\n"
" for seed in SEEDS:\n"
" set_all_seeds(seed) # torch, numpy, random\n"
" result = run_single(ConditionClass, seed=seed)\n"
" all_results[condition_name][seed] = result\n"
" print(f'condition={condition_name} seed={seed} metric: {result}')\n"
" values = list(all_results[condition_name].values())\n"
" print(f'condition={condition_name} metric_mean: {np.mean(values):.4f} '\n"
" f'metric_std: {np.std(values):.4f}')\n"
"```\n\n"
"### Reporting Requirements:\n"
"- Print per-seed results: `condition=X seed=S metric: V`\n"
"- Print aggregated: `condition=X metric_mean: M metric_std: S`\n"
"- Tables in the paper MUST show mean ± std, NEVER single-run values.\n"
"- If time budget forces < 5 seeds, use EXACTLY 3 seeds (minimum).\n"
" Print: `SEED_WARNING: only 3 seeds used due to time budget`.\n"
),
"writing_structure": (
"\n## Paper Section Writing Rules\n"
"MARKDOWN FORMATTING (CRITICAL):\n"
"- Use `# Title` (H1) for the paper title\n"
"- Use `# Abstract`, `# Introduction`, `# Method`, etc. (H1) for MAIN sections\n"
"- Use `## Subsection Name` (H2) for subsections WITHIN a main section\n"
"- NEVER use `##` for main sections — that produces wrong LaTeX heading levels\n"
"- Each main section (H1) MUST contain subsections (H2) when it exceeds 3 paragraphs\n"
"- NEVER place sub-topics (e.g., 'Knowledge Distillation for Compact Models') "
"at the same heading level as main sections (e.g., 'Related Work')\n"
"- NEVER wrap the paper in ```markdown fences\n"
"- NEVER use raw variable names (e.g., `method_name/metric_key = 0.85`) — "
"always use human-readable text\n\n"
"ABSTRACT (150-200 words, 5-sentence structure):\n"
"- (1) Problem and significance (2) Prior approaches and gaps\n"
"- (3) Your approach and novelty (4) Key results with 2-3 specific numbers\n"
"- (5) Implication/takeaway\n"
"- Do NOT list per-seed ranges (e.g., '0.71-0.73 across seeds') — use mean +/- std\n"
"- Do NOT repeat numbers that appear in the Results section — pick the 2-3 most impactful\n\n"
"INTRODUCTION (4 paragraphs, 800-1000 words, cite 8-12 references):\n"
"Paragraph 1: Problem motivation (why this matters). "
"Paragraph 2: What exists and why it falls short. "
"Paragraph 3: Your approach and key insight. "
"Paragraph 4: Contributions (2-3 bullet points allowed here ONLY).\n\n"
"RELATED WORK:\n"
"Organize by sub-topic, not chronologically. "
"End each paragraph with how YOUR work differs from the cited work. "
"Cite at least 15 references, all directly relevant.\n\n"
"METHOD:\n"
"Write as flowing narrative prose (NOT bullet points). "
"Include full algorithm description with pseudocode or step-by-step. "
"State all hyperparameters with values and justification. "
"Provide architecture details sufficient for reproduction.\n\n"
"RESULTS:\n"
"- Do NOT repeat the same number more than twice across the paper\n"
"- Each number in a table should be discussed AT MOST once in text\n"
"- Tables: mean +/- std with 95% CI in parentheses\n"
"- Bold the best result in each column\n"
"- Every comparison claim must cite a p-value or note multiple seeds\n"
"- Report the number of random seeds/runs used\n\n"
"FIGURES AND TABLES:\n"
"- Every figure MUST be referenced in the text (e.g., 'As shown in Figure 1')\n"
"- Every table MUST be referenced in the text (e.g., 'Table 2 summarizes')\n"
"- Figure captions: 1-2 descriptive sentences (not just 'Results comparison')\n"
"- Table captions go ABOVE the table; figure captions go BELOW the figure\n"
"- Axis labels must include units where applicable\n"
"- Use consistent font sizes across all figures\n\n"
"DISCUSSION (if applicable, can be merged into Results):\n"
"- Paragraph 1: Summarize key findings and their significance\n"
"- Paragraph 2: Compare with prior work — explain WHY results differ\n"
"- Paragraph 3: Discuss unexpected or negative results honestly\n"
"- Paragraph 4: Broader implications and practical applications\n\n"
"LIMITATIONS (3-5 points):\n"
"- State each limitation ONCE, here only — not scattered throughout\n"
"- No disclaimers like 'due to computational constraints'\n"
"- Include compute resources used (GPU type, training time)\n\n"
"CONCLUSION:\n"
"- Summarize findings (match actual results, no aspirational claims)\n"
"- 2-3 sentences of future work\n\n"
"PROSE QUALITY (CRITICAL — violation = desk reject):\n"
"- Write FLOWING ACADEMIC PARAGRAPHS, not bullet-point lists.\n"
"- Each paragraph must have 4-8 sentences with smooth transitions.\n"
"- Introduction, Related Work, and Method must each be >=3 paragraphs.\n"
"- FORBIDDEN: starting 3+ consecutive paragraphs with the same word.\n"
"- FORBIDDEN: bullet-point lists in Introduction or Related Work sections.\n"
"- Use varied sentence structures: mix simple, compound, and complex sentences.\n"
"- Connect paragraphs with transition phrases: 'Building on this insight...', "
"'In contrast to prior work...', 'To address this limitation...'.\n"
"- Each Related Work paragraph must COMPARE your approach to cited work, "
"not merely summarize what each paper does.\n"
"- FORBIDDEN AI-BOILERPLATE phrases (instant credibility loss):\n"
" 'delves into', 'it is worth noting', 'plays a crucial role',\n"
" 'leverages the power of', 'paves the way', 'a myriad of',\n"
" 'paradigm shift', 'groundbreaking', 'in the realm of',\n"
" 'holistic approach', 'multifaceted', 'navigate the complexities'.\n"
" Replace ALL such phrases with precise, specific academic language.\n"
),
"llm_training_guidance": (
"\n## LLM Fine-Tuning Guidance (when topic involves language model training)\n"
"AVAILABLE FRAMEWORKS (pre-installed in Docker):\n"
"- transformers (AutoModelForCausalLM, AutoTokenizer, Trainer)\n"
"- peft (LoraConfig, get_peft_model, PeftModel)\n"
"- trl (SFTTrainer, DPOTrainer, GRPOTrainer)\n"
"- datasets (load_dataset, Dataset)\n"
"- accelerate (Accelerator)\n"
"- bitsandbytes (4-bit/8-bit quantization)\n\n"
"GPU MEMORY GUIDELINES (RTX 6000 Ada, 49GB VRAM):\n"
"- Full fine-tune: <=3B parameters\n"
"- LoRA (16-bit): <=14B parameters\n"
"- QLoRA (4-bit): <=72B parameters (practical limit ~14B for training)\n"
"- Optimal: 7B-14B model with QLoRA (rank 16-64)\n\n"
"RECOMMENDED TRAINING PATTERN:\n"
"```python\n"
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n"
"from peft import LoraConfig, get_peft_model, TaskType\n"
"from trl import SFTTrainer, SFTConfig\n"
"from datasets import load_dataset\n\n"
"# 4-bit quantization for memory efficiency\n"
"bnb_config = BitsAndBytesConfig(\n"
" load_in_4bit=True,\n"
" bnb_4bit_quant_type='nf4',\n"
" bnb_4bit_compute_dtype=torch.bfloat16,\n"
")\n"
"model = AutoModelForCausalLM.from_pretrained(\n"
" model_name, quantization_config=bnb_config, device_map='auto'\n"
")\n"
"lora_config = LoraConfig(\n"
" r=16, lora_alpha=32, target_modules='all-linear',\n"
" lora_dropout=0.05, task_type=TaskType.CAUSAL_LM,\n"
")\n"
"model = get_peft_model(model, lora_config)\n"
"```\n\n"
"KEY HYPERPARAMETERS:\n"
"- learning_rate: 1e-4 to 2e-4 (LoRA), 5e-5 to 1e-4 (full FT)\n"
"- lora_r: 8 (minimal) to 64 (high-capacity)\n"
"- lora_alpha: typically 2x lora_r\n"
"- batch_size: 1-4 per device (use gradient_accumulation_steps for effective batch)\n"
"- gradient_accumulation_steps: 4-16 (effective_batch = per_device * accum)\n"
"- max_seq_length: 512 (short), 1024-2048 (standard), 4096 (long)\n"
"- warmup_ratio: 0.03-0.1\n"
"- weight_decay: 0.01-0.1\n\n"
"DATA FORMAT (use datasets library):\n"
"- Instruction tuning: {'instruction': '...', 'output': '...'}\n"
"- Chat format: {'messages': [{'role': 'user', 'content': '...'}, ...]}\n"
"- DPO: {'prompt': '...', 'chosen': '...', 'rejected': '...'}\n"
"- Use load_dataset('json', data_files='train.json') for local data\n"
"- Use load_dataset('HuggingFace/dataset_name') for HF Hub datasets\n\n"
"EVALUATION:\n"
"- Use evaluate library for standard metrics\n"
"- Common: perplexity, ROUGE (summarization), BLEU (translation), accuracy\n"
"- LLM benchmarks: MMLU, ARC, HellaSwag, TruthfulQA\n"
"- Generate sample outputs for qualitative comparison\n\n"
"MODEL DOWNLOAD:\n"
"- Models will be downloaded from HuggingFace Hub at runtime\n"
"- Use 'trust_remote_code=True' for custom model architectures\n"
"- Cache directory: default HF cache (~/.cache/huggingface)\n"
"- Common models: Qwen/Qwen2.5-7B, meta-llama/Llama-3.1-8B, "
"microsoft/Phi-4, google/gemma-2-9b\n\n"
"CRITICAL — NO SIMULATION:\n"
"- You MUST load and train a REAL model from HuggingFace Hub.\n"
"- NEVER simulate training with synthetic utility functions or random scores.\n"
"- NEVER replace model training with np.random/torch.randn mock results.\n"
"- A real experiment loads a model, tokenizes data, runs optimizer steps, "
"and measures real loss/perplexity/accuracy on held-out data.\n"
"- If compute budget is tight, use a SMALLER model (Qwen2.5-0.5B or 1.5B) "
"with fewer training steps rather than simulating.\n"
),
"llm_eval_guidance": (
"\n## LLM Evaluation Guidance\n"
"STANDARD BENCHMARKS:\n"
"- Reasoning: MMLU, ARC-Challenge, HellaSwag, WinoGrande\n"
"- Math: GSM8K, MATH, MathVista\n"
"- Coding: HumanEval, MBPP, LiveCodeBench\n"
"- Safety: TruthfulQA, BBQ, CrowS-Pairs\n"
"- Instruction following: MT-Bench, AlpacaEval, IFEval\n"
"- Multimodal: MMBench, POPE, MathVista, MMMU\n\n"
"EVALUATION FRAMEWORKS:\n"
"- lm-eval-harness: Standard eval framework, run via CLI or Python API\n"
"- vllm: Fast inference engine for throughput-focused evaluation\n"
"- lighteval: HuggingFace's lightweight eval framework\n\n"
"EVALUATION PROTOCOL:\n"
"- Report on at least 3 benchmarks relevant to the task\n"
"- Compare with published baselines from model cards/leaderboards\n"
"- Report both zero-shot and few-shot results where applicable\n"
"- Include perplexity on held-out test set\n"
),
# IMP-20: Academic writing style guide (from NeurIPS/ICLR/ICML 2024-2025 best papers)
"academic_style_guide": (
"\n## ACADEMIC WRITING STANDARDS (from NeurIPS/ICLR/ICML 2024-2025 best papers)\n\n"
"### Title Standards\n"
"- Target 8-14 words. Median of award-winning papers: ~10 words.\n"
"- Preferred format: 'SystemName: Descriptive Subtitle' (35% of best papers)\n"
" e.g., 'AlphaEdit: Null-Space Constrained Knowledge Editing for Language Models'\n"
"- Alternative: Declarative statement that surprises\n"
" e.g., 'Not All Tokens Are What You Need for Pretraining'\n"
"- Give your method a memorable, catchy name (VAR, Genie, PRISM, SEDD).\n"
"- NEVER exceed 18 words. NEVER use 'A Novel Approach to...' or 'Investigating...'\n\n"
"### Abstract Standards (PMR+ Structure, 180-220 words)\n"
"S1-S2: PROBLEM — State the gap. Open with a challenge or status-quo critique.\n"
"S3-S4: METHOD — Name your system by sentence 3. Describe the key insight.\n"
"S5-S6: RESULTS — At least 2-3 concrete quantitative claims:\n"
" - One relative improvement ('36.7% boost over baseline')\n"
" - One absolute benchmark score ('FID of 1.01 on ImageNet')\n"
"AVOID: Per-seed ranges, excessive texttt, defensive hedging.\n\n"
"### Section Writing Standards\n"
"INTRODUCTION (800-1000 words, 4 paragraphs):\n"
" - Para 1: Motivation; Para 2: Gap (cite 3-5 papers); Para 3: Your approach;\n"
" Para 4: Contributions (bullet list of 3-4 specific contributions)\n"
" - MUST cite 8-12 references throughout Introduction\n\n"
"RELATED WORK (600-800 words):\n"
" - Organize by sub-topic (2-3 subsections), NOT as a flat list\n"
" - End each subsection with how YOUR work differs\n"
" - Target >= 15 unique references in this section alone\n\n"
"METHOD (1000-1500 words):\n"
" - Start with problem formulation (notation, objective function)\n"
" - Use algorithm environment for pseudocode (not verbatim)\n"
" - Write as a flowing narrative, NOT bullet points\n\n"
"EXPERIMENTS (800-1200 words):\n"
" - Experimental setup as subsection (datasets, baselines, metrics, hardware)\n"
" - Hyperparameter table (Table 1 always)\n"
" - MUST reference figures: 'As shown in Figure 1, our method...'\n"
" - MUST cite baseline method papers (not just name them)\n\n"
"RESULTS (600-800 words):\n"
" - Main results table with descriptive caption\n"
" - Ablation study table\n"
" - Analysis paragraphs connecting numbers to insights\n"
" - DO NOT repeat the same numbers from Experiments section\n"
" - Reference figures for visual evidence\n\n"
"DISCUSSION (400-600 words):\n"
" - Compare findings with prior work (cite papers here!)\n"
" - Explain surprising results; broader implications\n\n"
"LIMITATIONS (200-300 words): 3-5 specific, concrete limitations. ALL caveats go HERE.\n\n"
"CONCLUSION: Summarize in 2-3 sentences, future work in 2-3 sentences.\n\n"
"### Writing Quality Rules\n"
"- Write as FLOWING PROSE, not bullet points or enumerated lists\n"
"- Each paragraph: topic sentence, evidence, analysis, transition\n"
"- Use transitions: 'Building on this insight...', 'In contrast to...'\n"
"- Academic tone: confident but precise\n"
"- Vary sentence structure: mix short declarative with longer analytical\n"
"- AVOID: Starting 3+ consecutive sentences with 'We', 'The', 'Our'\n"
"- AVOID: 'It is worth noting that', 'It should be mentioned that' (filler)\n"
"- Citations belong in EVERY section, not just Introduction and Related Work\n"
),
# IMP-25: Narrative writing requirements
"narrative_writing_rules": (
"\n## NARRATIVE WRITING REQUIREMENTS\n\n"
"You are writing a paper for human reviewers at a top AI conference. The paper\n"
"must read like a cohesive academic story, NOT a technical report or bullet list.\n\n"
"### Structure of Each Paragraph\n"
"Every paragraph MUST follow this pattern:\n"
"1. TOPIC SENTENCE — states the main claim or finding\n"
"2. EVIDENCE — data, citations, or reasoning that supports the claim\n"
"3. ANALYSIS — what the evidence means, why it matters\n"
"4. TRANSITION — connects to the next paragraph's topic\n\n"
"### FORBIDDEN Writing Patterns\n"
"- Bullet-point lists in the main body (ONLY allowed in Contributions paragraph\n"
" of Introduction and Limitations section)\n"
"- Numbered lists of findings or results\n"
"- Starting a paragraph with 'Table X shows...' without context first\n"
"- Consecutive short sentences without analysis between them\n"
"- Repeating the same sentence structure 3+ times in a row\n\n"
"### REQUIRED Writing Patterns\n"
"- Transition phrases: 'Building on this observation...', 'In contrast to prior work...'\n"
"- Vary sentence length: alternate between short impactful and longer analytical\n"
"- Ground every claim in evidence: '[Result] because [mechanism] (cite)'\n"
"- Discuss implications: 'This X% improvement indicates that [mechanism Y]\n"
" is more effective than [mechanism Z] for [context]'\n"
"- For temporal data: describe trends in prose rather than bullet-point lists\n\n"
"### Example: BAD vs GOOD Method Description\n"
"BAD (bullet-list style):\n"
" 'Our method has three components:\n"
" - Component A\n"
" - Component B\n"
" - Component C'\n\n"
"GOOD (narrative style):\n"
" 'Our method builds on the insight that [core problem] stems from\n"
" [root cause identified in Section 2]. To address this, we introduce\n"
" [MethodName], a [N]-stage framework. First, [Stage 1] maps inputs\n"
" to [representation]. These representations feed into [Stage 2],\n"
" enabling [benefit] without [drawback of prior approaches].\n"
" Crucially, we augment this with [Stage 3] based on [technical\n"
" foundation] (cite original paper), triggering [mechanism] when\n"
" [condition is met].'\n"
" NOTE: Replace all [placeholders] with YOUR actual method details.\n"
" Do NOT copy this template verbatim.\n"
),
# IMP-31: Anti-hedging rules
"anti_hedging_rules": (
"\n## ANTI-HEDGING RULES (MANDATORY)\n"
"1. The following phrases are BANNED from the paper body:\n"
" - 'we do not claim' / 'we cannot claim'\n"
" - 'we intentionally frame this conservatively'\n"
" - 'the evidence does not support' (unless followed by what it DOES support)\n"
" - 'only N seeds/runs' (belongs ONLY in Limitations, stated ONCE)\n"
" - 'this paper is not' / 'we do not' as paragraph openers\n"
"2. Limitations and caveats MUST be consolidated in the Limitations section.\n"
" They may NOT appear in Introduction, Method, Results, or Conclusion.\n"
"3. Confidence framing: Instead of 'we cannot prove X', write 'our results\n"
" provide evidence for X' or 'X is supported by [metrics]'.\n"
"4. If you have a negative result, frame it as an INSIGHT:\n"
" BAD: 'Our method failed to outperform the baseline, we do not claim...'\n"
" GOOD: 'Surprisingly, the standard baseline proved competitive, suggesting\n"
" that [insight about why] — an observation with practical implications for...'\n"
),
# IMP-24: Anti-repetition rules
"anti_repetition_rules": (
"\n## ANTI-REPETITION RULE\n"
"Each specific number (e.g., '0.7667', '36.7%') may appear in AT MOST 2 sections:\n"
" - Once in Results/Experiments (where it is first reported)\n"
" - Once in Abstract (as a summary highlight)\n"
"The Introduction, Discussion, and Conclusion MUST refer to results qualitatively\n"
"('significantly outperformed', 'X% improvement') WITHOUT repeating exact numbers\n"
"from the Results section. Violation of this rule will result in desk rejection.\n"
),
}
# -- Debate role prompts (multi-perspective generation) -------------------
# Persona prompts for multi-perspective ("debate") hypothesis generation.
# Maps role name -> {"system": persona instruction, "user": task template}.
# Each "user" template carries {topic} and {synthesis} placeholders that the
# caller substitutes before issuing the LLM request.
DEBATE_ROLES_HYPOTHESIS: dict[str, dict[str, str]] = {
    # High-risk/high-reward ideation: bold, cross-domain hypotheses, but every
    # hypothesis must still be feasible (30 min, single GPU) and falsifiable.
    "innovator": {
        "system": (
            "You are a bold, creative researcher who thinks outside the box. "
            "You pursue high-risk high-reward ideas, draw cross-domain analogies, "
            "and propose counter-intuitive hypotheses that challenge mainstream thinking."
        ),
        "user": (
            "Generate at least 2 novel, unconventional hypotheses from the synthesis below.\n"
            "CRITICAL REQUIREMENTS for EVERY hypothesis:\n"
            "1. NOVELTY: Must go beyond incremental combination of existing methods.\n"
            "2. FEASIBILITY: Must be testable within 30 minutes of compute on a single GPU.\n"
            "3. FALSIFIABILITY: Must define a specific metric threshold that would reject it.\n"
            "For each hypothesis provide:\n"
            "- A bold claim that pushes boundaries\n"
            "- Cross-domain inspiration (if applicable)\n"
            "- Rationale grounded in the literature gaps\n"
            "- Measurable prediction and failure condition\n"
            "- Estimated risk level (low/medium/high)\n\n"
            "Topic: {topic}\n"
            "Synthesis:\n{synthesis}"
        ),
    },
    # Engineering-minded counterweight: feasible, incremental hypotheses
    # grounded in proven techniques and limited compute.
    "pragmatist": {
        "system": (
            "You are a practical ML engineer focused on what actually works. "
            "You prioritize computational feasibility, engineering simplicity, "
            "reliable baselines, and incremental but solid improvements."
        ),
        "user": (
            "Generate at least 2 feasible, well-grounded hypotheses from the synthesis below.\n"
            "For each hypothesis provide:\n"
            "- A concrete, testable claim with clear methodology\n"
            "- Why this is achievable with limited compute\n"
            "- Rationale based on proven techniques\n"
            "- Measurable prediction and failure condition\n"
            "- Resource requirements estimate\n\n"
            "Topic: {topic}\n"
            "Synthesis:\n{synthesis}"
        ),
    },
    # Devil's advocate: challenges widely-held assumptions and surfaces
    # informative negative results the other two roles might miss.
    "contrarian": {
        "system": (
            "You are a rigorous devil's advocate who challenges assumptions. "
            "You find blind spots, hidden failure modes, and counter-evidence. "
            "Your value is in finding problems others ignore. Be provocative "
            "but always grounded in evidence."
        ),
        "user": (
            "Critically examine the synthesis and generate at least 2 contrarian hypotheses.\n"
            "For each hypothesis provide:\n"
            "- A challenge to a widely-held assumption in this area\n"
            "- Evidence or reasoning for why the mainstream view may be wrong\n"
            "- An alternative hypothesis that accounts for overlooked factors\n"
            "- Measurable prediction and failure condition\n"
            "- Potential negative results that would be informative\n\n"
            "Topic: {topic}\n"
            "Synthesis:\n{synthesis}"
        ),
    },
}
# Persona prompts for multi-perspective analysis of experiment results.
# Same shape as DEBATE_ROLES_HYPOTHESIS: role name -> {"system", "user"}.
# The "user" templates carry {preamble}, {data_context}, and {context}
# placeholders that the caller substitutes before the LLM call.
DEBATE_ROLES_ANALYSIS: dict[str, dict[str, str]] = {
    # Highlights successes, unexpected positives, and promising follow-ups.
    "optimist": {
        "system": (
            "You highlight positive findings, promising extensions, and silver linings "
            "in experimental results. You identify what worked well and why, "
            "and suggest how to build on successes."
        ),
        "user": (
            "Analyze the experiment results from an optimistic perspective.\n"
            "Cover:\n"
            "- What worked well and why\n"
            "- Unexpected positive findings\n"
            "- Promising extensions and next steps\n"
            "- Silver linings in any negative results\n\n"
            "{preamble}\n{data_context}\n"
            "Run context:\n{context}"
        ),
    },
    # Stress-tests claims: statistical validity, confounds, missing controls.
    "skeptic": {
        "system": (
            "You question the significance of results with maximum rigor. "
            "You check statistical validity, identify confounds, and demand "
            "stronger evidence. Every claim must earn its place."
        ),
        "user": (
            "Critically scrutinize the experiment results.\n"
            "Cover:\n"
            "- Statistical concerns (significance, sample size, multiple comparisons)\n"
            "- Potential confounds and alternative explanations\n"
            "- Missing evidence or controls\n"
            "- Whether metrics truly capture the intended phenomenon\n\n"
            "{preamble}\n{data_context}\n"
            "Run context:\n{context}"
        ),
    },
    # Audits HOW the experiment was run: validity, reproducibility, fairness.
    "methodologist": {
        "system": (
            "You scrutinize HOW experiments were conducted. You audit "
            "internal/external validity, reproducibility, baseline fairness, "
            "and evaluation protocols."
        ),
        "user": (
            "Audit the experimental methodology.\n"
            "Cover:\n"
            "- Baseline fairness and completeness\n"
            "- Metric appropriateness for the research question\n"
            "- Evaluation protocol (data leakage, contamination risks)\n"
            "- Ablation completeness\n"
            "- Reproducibility assessment\n"
            "- Specific methodology improvements needed\n\n"
            "{preamble}\n{data_context}\n"
            "Run context:\n{context}"
        ),
    },
}
# -- Sub-prompts (secondary LLM calls within a stage) --------------------
_DEFAULT_SUB_PROMPTS: dict[str, dict[str, Any]] = {
"hypothesis_synthesize": {
"system": (
"You are a senior research director synthesizing multiple perspectives "
"into a decisive research proposal. The best synthesis is not a "
"compromise but takes the strongest elements from each viewpoint. "
"Preserve genuine disagreements — do not flatten controversy."
),
"user": (
"Below are hypotheses generated from three different research perspectives.\n"
"Synthesize them into a final set of 2-4 hypotheses that:\n"
"1. Take the strongest, most novel ideas\n"
"2. Address critical concerns raised by the contrarian\n"
"3. Ensure feasibility (pragmatist's input)\n"
"4. Note unresolved disagreements between perspectives\n"
"5. For each final hypothesis: rationale, measurable prediction, "
"failure condition\n\n"
"{perspectives}"
),
},
"analysis_synthesize": {
"system": (
"You are a senior research director synthesizing multiple analytical "
"perspectives into a comprehensive assessment. Find the truth — if "
"the skeptic or methodologist raise valid concerns, acknowledge them. "
"Do not suppress criticism."
),
"user": (
"Below are analyses from three different perspectives (optimist, "
"skeptic, methodologist).\n"
"Produce a unified analysis that:\n"
"1. Identifies consensus points (high-confidence conclusions)\n"
"2. Resolves conflicts with evidence-based judgment\n"
"3. Rates result quality (1-10 with justification)\n"
"4. Lists 3-5 key findings\n"
"5. Notes methodology gaps that need addressing\n"
"6. Gives a clear PROCEED/PIVOT/REFINE recommendation\n\n"
"Required sections: Metrics Summary, Consensus Findings, "
"Contested Points, Statistical Checks, Methodology Audit, "
"Limitations, Conclusion.\n\n"
"{perspectives}"
),
"max_tokens": 8192,
},
"code_repair": {
"system": "You fix Python code validation errors while preserving functionality.",
"user": (
"The file `{fname}` in the experiment project has validation errors. "
"Fix ALL issues and return ONLY the corrected file.\n\n"
"## Validation Issues in {fname}\n{issues_text}\n\n"
"## All Project Files\n{all_files_ctx}\n\n"
"IMPORTANT: Do NOT use subprocess, os.system, eval, exec, or any "
"network/shell calls.\n"
"NUMPY 2.x: np.trapz→np.trapezoid, np.erfinv→scipy.special.erfinv, "
"np.bool/int/float→Python builtins.\n"
"Return ONLY the corrected code for `{fname}`."
),
},
"iterative_improve": {
"system": (
"You improve experiment projects and return valid executable Python code. "
"Use ```filename:xxx.py format for each file."
),
"user": (
"Improve the experiment code based on prior run results.\n"
"Return the improved files using ```filename:xxx.py format for each file.\n"
"Primary metric key: {metric_key}\n"
"Metric direction: {metric_direction}\n"
"Do not use subprocess, os.system, eval, exec, or any network/shell calls.\n"
"NUMPY 2.x: np.trapz→np.trapezoid, np.erfinv→scipy.special.erfinv, "
"np.bool/int/float→Python builtins, np.math→math.\n\n"
"EXPERIMENT PLAN ANCHOR (CRITICAL — read before making changes):\n"
"The research topic is: {topic}\n"
"{exp_plan_anchor}"
"RULES FOR REFINEMENT:\n"
"- NEVER rename, remove, or replace existing condition names. "
"The condition names in the code MUST match the experiment plan.\n"
"- NEVER add new conditions that are not in the experiment plan.\n"
"- ONLY improve the IMPLEMENTATION of existing conditions "
"(fix bugs, tune hyperparameters, improve training loops).\n"
"- If the code has fundamental issues (wrong algorithm, missing "
"components), fix the implementation but keep the same condition "
"names and class hierarchy.\n\n"
"{condition_coverage_hint}"
"SEED ENFORCEMENT (MANDATORY — BUG-183):\n"
"- You MUST use exactly seeds = [0, 1, 2] (3 seeds minimum).\n"
"- Each condition MUST loop over ALL seeds.\n"
"- Print per-seed: condition=X seed=S {metric_key}: V\n"
"- Print aggregated: condition=X {metric_key}_mean: M {metric_key}_std: S\n"
"- If 3 seeds × all conditions exceeds the time budget, REDUCE training "
"epochs or conditions — NEVER reduce seed count below 3.\n\n"
"CONDITION COUNT LIMIT (HARD RULE):\n"
"- MAXIMUM 8 total conditions (baselines + methods + ablations).\n"
"- If the previous code had >8 conditions, consolidate ablations to 2-3 values.\n\n"
"DOCKER MOUNT TOPOLOGY (for fixing PermissionError/path issues):\n"
"- WRITABLE: /workspace/ (project files), /tmp/, /workspace/data/\n"
"- READ-ONLY: /opt/datasets/ (pre-cached CIFAR-10/100, MNIST, etc)\n"
"- If you see PermissionError on /opt/datasets, do NOT call "
"os.makedirs() there. Use root='/opt/datasets' with download=False.\n"
"- For new data downloads, use /workspace/data/ as root.\n\n"
"Current project files:\n{files_context}\n"
"Run summaries (JSON):\n{run_summaries}"
),
"max_tokens": 8192,
},
"iterative_repair": {
"system": "You fix Python validation issues without adding unsafe behavior.",
"user": (
"Fix all validation issues in main.py and return corrected Python code only.\n\n"
"## Validation Issues\n{issue_text}\n\n"
"## Common RL Stability Fixes (apply if NaN/divergence detected):\n"
"- Add gradient clipping: `torch.nn.utils.clip_grad_norm_(params, 1.0)`\n"
"- Lower learning rate to 1e-4 or 3e-4\n"
"- Add reward normalization/clipping: `reward = np.clip(reward, -10, 10)`\n"
"- Add NaN guard: `if torch.isnan(loss): continue`\n"
"- Use float32 (not float16) for RL value functions\n"
"- NUMPY 2.x: np.trapz→np.trapezoid, np.erfinv→scipy.special.erfinv, "
"np.bool/int/float→Python builtins\n\n"
"## All Project Files\n{all_files_ctx}"
),
},
# ── Advanced Code Agent sub-prompts ──────────────────────────────────
"architecture_planning": {
"system": (
"You are a senior software architect who designs implementation "
"blueprints for scientific experiment codebases. You produce detailed, "
"directly-implementable specifications with pseudocode for every "
"class method and explicit tensor shape annotations. You emphasize "
"separation of concerns: data loading, model definition, training "
"loop, and evaluation are distinct components. You understand ML "
"training deeply and design for correctness: proper .detach(), "
"consistent tensor shapes, and correct gradient flow.\n\n"
"NUMPY 2.x COMPATIBILITY (CRITICAL):\n"
"- np.trapz is REMOVED → use np.trapezoid\n"
"- np.erfinv does NOT exist → use scipy.special.erfinv\n"
"- np.bool, np.int, np.float, np.complex are REMOVED → use Python builtins\n"
"- np.str, np.object are REMOVED → use str, object\n"
"- np.math is REMOVED → use math module"
),
"user": (
"Create a detailed IMPLEMENTATION BLUEPRINT for an experiment codebase.\n\n"
"## Research Context\n"
"TOPIC: {topic}\n"
"PRIMARY METRIC: {metric}\n\n"
"## Experiment Plan\n{exp_plan}\n\n"
"## Requirements\n"
"1. `main.py` MUST be the entry point — runs ALL conditions sequentially.\n"
"2. Each condition MUST be a SEPARATE class with DISTINCT implementation.\n"
"3. Data loading and model definitions in separate modules.\n"
"4. No more than 5 Python files total.\n"
"5. Every class must have at least 20 lines of effective code.\n"
"6. Child classes MUST override at least one core method with DIFFERENT logic.\n"
"7. NEVER override nn.Module.train/eval with different signatures.\n"
"8. Design child classes as STRATEGY variants, not PARAMETER variants.\n\n"
"## Blueprint Format (YAML)\n"
"The blueprint MUST include ALL of the following for EACH file:\n"
"- `generation_order`: integer (1=first to generate, higher=later)\n"
"- `dependencies`: list of other files this file imports from\n"
"- `classes` or `functions`: with pseudocode for each method\n"
"- For neural network classes: input/output tensor shapes\n\n"
"```yaml\n"
"files:\n"
" - name: config.py\n"
" generation_order: 1\n"
" dependencies: []\n"
" purpose: Hyperparameter configuration\n"
" classes:\n"
" - name: Config\n"
" fields:\n"
" - lr: 0.01\n"
" - batch_size: 128\n"
" - epochs: 20\n"
" - hidden_dim: 128\n\n"
" - name: data.py\n"
" generation_order: 2\n"
" dependencies: [config.py]\n"
" purpose: Dataset loading and preprocessing\n"
" functions:\n"
" - name: get_dataloaders\n"
" signature: (config) -> (train_loader, val_loader, test_loader)\n"
" pseudocode: |\n"
" 1. Load dataset from torchvision/disk\n"
" 2. Apply standard transforms (normalize, augment)\n"
" 3. Split train into train/val (90/10)\n"
" 4. Return DataLoaders with config.batch_size\n\n"
" - name: models.py\n"
" generation_order: 3\n"
" dependencies: [config.py]\n"
" purpose: All model implementations\n"
" classes:\n"
" - name: BaseModel(nn.Module)\n"
" input_shape: [B, 3, 32, 32]\n"
" output_shape: [B, 10]\n"
" methods:\n"
" - name: __init__\n"
" pseudocode: Define layers (conv/linear/attention)\n"
" - name: forward\n"
" pseudocode: |\n"
" 1. x = self.encoder(x) # [B,3,32,32] -> [B, hidden]\n"
" 2. logits = self.classifier(x) # [B, hidden] -> [B, 10]\n"
" 3. return logits\n"
" - name: ProposedMethod(BaseModel)\n"
" differentiator: Uses novel component X\n"
" overrides: [forward]\n"
" methods:\n"
" - name: forward\n"
" pseudocode: |\n"
" 1. x = self.encoder(x)\n"
" 2. x = self.novel_component(x) # KEY DIFFERENCE\n"
" 3. logits = self.classifier(x)\n"
" 4. return logits\n"
" - name: compute_special_loss\n"
" pseudocode: |\n"
" 1. Compute task loss: CE(logits, labels)\n"
" 2. Compute novel regularizer\n"
" 3. return task_loss + lambda * reg\n\n"
" - name: training.py\n"
" generation_order: 4\n"
" dependencies: [config.py, data.py, models.py]\n"
" purpose: Training loop and evaluation\n"
" functions:\n"
" - name: train_one_epoch\n"
" signature: (model, loader, optimizer, device) -> float\n"
" pseudocode: |\n"
" 1. model.train()\n"
" 2. For each batch: forward, loss, backward, step\n"
" 3. Return average loss\n"
" - name: evaluate\n"
" signature: (model, loader, device) -> dict\n"
" pseudocode: |\n"
" 1. model.eval() with torch.no_grad()\n"
" 2. For each batch: forward, argmax predictions\n"
" 3. Return {accuracy, loss}\n\n"
" - name: main.py\n"
" generation_order: 5\n"
" dependencies: [config.py, data.py, models.py, training.py]\n"
" purpose: Entry point — runs ALL conditions\n"
" contract:\n"
" prints_metric_def: true\n"
" prints_registered_conditions: true\n"
" runs_all_conditions: true\n"
" per_seed_reporting: true\n"
" time_budget_guard: true\n"
" functions:\n"
" - name: main\n"
" pseudocode: |\n"
" 1. Print METRIC_DEF line\n"
" 2. Print REGISTERED_CONDITIONS\n"
" 3. Setup time budget guard\n"
" 4. For each condition:\n"
" a. Create model instance\n"
" b. For each seed:\n"
" - Set random seed\n"
" - Train model\n"
" - Evaluate and print per-seed metrics\n"
" c. Print mean/std across seeds\n"
" 5. Print SUMMARY comparison\n\n"
"verification_criteria:\n"
" - All condition classes have DIFFERENT forward/step implementations\n"
" - Input/output tensor shapes are consistent across data->model->loss\n"
" - Time budget guard exists in main training loop\n"
" - Per-seed random state isolation\n"
" - All .detach() calls present for values used across iterations\n\n"
"conditions:\n"
" - name: ConditionName\n"
" class: ClassName\n"
" description: What makes it different\n"
"```\n\n"
"Output ONLY the YAML specification wrapped in ```yaml``` fences.\n"
"Be SPECIFIC in pseudocode — include tensor shapes, loss formulas, "
"and algorithmic details from the experiment plan.\n"
"Every class must have detailed pseudocode showing HOW it differs "
"from others, not just THAT it differs."
),
"max_tokens": 8192,
},
"generate_single_file": {
"system": (
"You are an expert ML engineer who writes production-quality Python code "
"for scientific experiments. You follow implementation blueprints precisely, "
"ensuring tensor shapes match, gradients flow correctly, and all imports "
"resolve. You write complete, runnable code — never stubs or placeholders."
),
"user": (
"Generate the Python file `{file_name}` for an ML experiment project.\n\n"
"## File Specification\n{file_spec}\n\n"
"## Full Project Blueprint\n{blueprint}\n\n"
"## Already Generated Files (summaries)\n{dependency_summaries}\n\n"
"## Already Generated Files (full code of direct dependencies)\n"
"{dependency_code}\n\n"
"## Research Topic\n{topic}\n\n"
"## Experiment Plan\n{exp_plan}\n\n"
"## Environment\n{pkg_hint}\n\n"
"## CRITICAL Rules\n"
"1. Follow the blueprint specification EXACTLY — implement every class "
"and function listed for this file.\n"
"2. Tensor shapes MUST match the blueprint annotations.\n"
"3. Imports from dependency files MUST use the exact class/function names "
"from the already-generated code.\n"
"4. Every method must have a REAL implementation — no `pass`, no `...`, "
"no `raise NotImplementedError`.\n"
"5. NEVER use random numbers as fake metrics.\n"
"6. For RL code: .detach() ALL values from previous iterations before "
"using in current loss.\n"
"7. For neural networks: create layers in __init__, not in forward().\n"
"8. METHOD RICHNESS: Every non-trivial method should be >=5 lines of "
"real logic. If a method only calls super() or returns a constant, "
"add the actual computation it should perform. Training methods should "
"include proper gradient handling, metric logging, and error checks.\n"
"9. ABLATION DIFFERENTIATION: If this file contains ablation/variant "
"classes, each MUST differ in actual algorithm logic — not just in "
"parameter values or by removing a line. Ablations should clearly "
"implement a different computational path.\n"
"10. NO CLI CONDITION ARGS: If this is main.py, NEVER add argparse "
"arguments like --condition or --method. All conditions must be "
"iterated inside main.py with a for-loop. The harness runs "
"`python main.py` with no arguments.\n"
"11. NUMPY 2.x COMPATIBILITY: np.trapz→np.trapezoid, "
"np.erfinv→scipy.special.erfinv, np.bool/np.int/np.float→Python builtins, "
"np.str/np.object→str/object, np.math→math.\n\n"
"Output ONLY the Python code for `{file_name}` — no markdown fences, "
"no explanations, just the code."
),
"max_tokens": 8192,
},
"code_exec_fix": {
"system": (
"You are a debugging expert who fixes runtime errors in Python "
"experiment code. You preserve the original experiment design and "
"scientific methodology while fixing the specific error. You fix "
"the ROOT CAUSE, not just the symptom."
),
"user": (
"The following experiment code crashed during execution.\n\n"
"## Error Output (stderr, last 3000 chars)\n"
"```\n{stderr}\n```\n\n"
"## Standard Output (last 50 lines)\n"
"```\n{stdout_tail}\n```\n\n"
"## Return Code: {returncode}\n\n"
"## Current Code Files\n{files_context}\n\n"
"## Instructions\n"
"1. Identify the ROOT CAUSE of the error.\n"
"2. Fix it while preserving the experiment design.\n"
"3. Check for similar potential issues in ALL files.\n"
"4. Do NOT simplify or remove experiment logic — fix the bug.\n"
"5. Do NOT add subprocess, os.system, eval, exec, or network calls.\n"
"6. COMMON BUG: If error is about `train()` missing arguments, it means "
"a class overrode nn.Module.train() with a custom signature. Fix by "
"renaming the custom method to `fit()` or `run_training()` and updating "
"all callers. Never override nn.Module.train/eval with extra args.\n"
"7. NUMPY 2.x: np.trapz→np.trapezoid, np.erfinv→scipy.special.erfinv, "
"np.bool/int/float/complex→Python builtins, np.str/object→str/object.\n\n"
"Output ALL files in ```filename:xxx.py``` format, including files "
"that don't need changes."
),
"max_tokens": 16384,
},
"code_reviewer": {
"system": (
"You are a meticulous experiment code reviewer focused on "
"scientific correctness, statistical rigor, and code quality. "
"You catch bugs that static analysis cannot: incorrect algorithm "
"implementations, missing controls, wrong metric computation, "
"and experimental design flaws."
),
"user": (
"Review this experiment code for correctness and quality.\n\n"
"## Research Context\n"
"TOPIC: {topic}\n"
"PRIMARY METRIC: {metric}\n\n"
"## Experiment Plan\n{exp_plan}\n\n"
"## Code Files\n{files_context}\n\n"
"## Review Criteria\n"
"1. **CORRECTNESS**: Does the code correctly implement the "
"experiment plan? Are algorithms implemented properly?\n"
"2. **COMPLETENESS**: Are all conditions/ablations implemented "
"with DISTINCT logic? (Not just renamed copies of baseline.)\n"
"3. **STATISTICAL RIGOR**: Multiple seeds? Results averaged and "
"reported with std? Paired comparisons?\n"
"4. **METRIC REPORTING**: Is {metric} correctly computed and "
"printed in the required format?\n"
"5. **ROBUSTNESS**: Shape mismatches? Missing imports? Type "
"errors? Division by zero? GPU/CPU device conflicts?\n"
"6. **CLASS DEPTH**: Each experimental condition class must have "
"at least 20 lines of effective code with distinct logic. Classes "
"that only override __init__ to change parameters are CRITICAL "
"issues — they indicate the condition is not truly different.\n\n"
"## Output Format (JSON)\n"
"```json\n"
'{{\n'
' "verdict": "APPROVE or REVISE",\n'
' "score": 1-10,\n'
' "critical_issues": ["issue1", "issue2"],\n'
' "suggestions": ["suggestion1", "suggestion2"]\n'
'}}\n'
"```\n\n"
"Only use verdict REVISE if there are critical issues that would "
"cause the code to crash or produce scientifically invalid results."
),
"json_mode": True,
"max_tokens": 4096,
},
}
# -- Stage prompts (one entry per LLM-calling stage) ---------------------
_DEFAULT_STAGES: dict[str, dict[str, Any]] = {
# ── Phase A: Research Scoping ────────────────────────────────────────
"topic_init": {
"system": (
"You are a rigorous research planner who identifies NOVEL, TIMELY "
"research angles. You follow recent trends from top venues in the "
"relevant domain and propose research that advances "
"the frontier rather than repeating known results.\n\n"
"NOVELTY PRINCIPLES:\n"
"- A good research angle addresses a GAP not yet covered by existing work.\n"
"- Avoid pure benchmark/comparison studies unless the methodology is novel.\n"
"- Prefer angles that combine existing techniques in new ways, apply methods "
"to underexplored domains, or challenge common assumptions.\n"
"- The research must be FEASIBLE with limited compute (single GPU, hours not days).\n"
"- Check: would a reviewer say 'this is already well-known'? If so, find a sharper angle."
),
"user": (
"Create a SMART research goal in markdown.\n"
"Topic: {topic}\n"
"Domains: {domains}\n"
"Project: {project_name}\n"
"Quality threshold: {quality_threshold}\n\n"
"Required sections:\n"
"- **Topic**: The broad area\n"
"- **Novel Angle**: What specific aspect has NOT been well-studied? "
"Why is this timely NOW (2024-2026)? What recent development creates "
"an opportunity? How does this differ from standard approaches?\n"
"- **Scope**: Focused enough for a single paper\n"
"- **SMART Goal**: Specific, Measurable, Achievable, Relevant, Time-bound\n"
"- **Constraints**: Compute budget, available tools, data access\n"
"- **Success Criteria**: What results would make this publishable?\n"
"- **Generated**: Timestamp\n\n"
"IMPORTANT: The 'Novel Angle' section must convincingly argue why this "
"specific research direction is NOT already covered by existing work. "
"If the topic is well-studied (e.g., 'comparing optimizers'), you MUST "
"find a specific unexplored aspect (e.g., 'under distribution shift with "
"noisy gradients', 'in the few-shot regime', 'with modern architectures').\n\n"
"TREND VALIDATION (MANDATORY):\n"
"- Identify 2-3 recent papers (2024-2026) that establish the relevance "
"of this research direction.\n"
"- Name the specific benchmark/dataset that will be used for evaluation.\n"
"- If no standard benchmark exists, explain how results will be measured.\n"
"- State whether SOTA results exist on this benchmark and what they are.\n"
"- Add a 'Benchmark' subsection listing: name, source, metrics, "
"current SOTA (if known)."
),
},
"problem_decompose": {
"system": "You are a senior research strategist.",
"user": (
"Decompose this research problem into at least 4 prioritized "
"sub-questions.\n"
"Topic: {topic}\n"
"Output markdown with sections: Source, Sub-questions, Priority "
"Ranking, Risks.\n"
"Goal context:\n{goal_text}"
),
},
# ── Phase B: Literature Discovery ────────────────────────────────────
"search_strategy": {
"system": (
"You design literature retrieval strategies and source verification plans."
),
"user": (
"Create a merged search strategy package.\n"
"Return a JSON object with keys: search_plan_yaml, sources.\n"
"search_plan_yaml must be valid YAML text.\n"
"sources must include id,name,type,url,status,query,verified_at.\n"
"Topic: {topic}\n"
"Problem tree:\n{problem_tree}"
),
"json_mode": True,
},
"literature_collect": {
"system": "You are a literature mining assistant.",
"user": (
"Generate candidate papers from the search plan.\n"
"Return JSON: {candidates:[...]} with >=8 rows.\n"
"Each candidate must include id,title,source,url,year,abstract,"
"collected_at.\n"
"Topic: {topic}\n"
"Search plan:\n{plan_text}"
),
"json_mode": True,
},
"literature_screen": {
"system": (
"You are a strict domain-aware reviewer with zero tolerance for "
"cross-domain false positives. You MUST reject papers that are "
"from unrelated fields, even if they share superficial keyword "
"overlap. A paper about 'normalization in database systems' is "
"NOT relevant to 'normalization in deep learning'. A paper about "
"'graph theory in social networks' is NOT relevant to 'graph "
"neural networks for molecular property prediction'."
),
"user": (
"Perform merged relevance+quality screening and return shortlist.\n"
"Return JSON: {shortlist:[...]} each with title, cite_key "
"(if present), relevance_score (0-1), quality_score (0-1), "
"keep_reason.\n"
"Preserve all original fields (paper_id, doi, arxiv_id, cite_key, "
"etc.) from the input.\n"
"Topic: {topic}\n"
"Domains: {domains}\n"
"Threshold: {quality_threshold}\n\n"
"SCREENING RULES (apply strictly):\n"
"1. DOMAIN MATCH: The paper's actual research domain must match "
"the topic's domain. Shared keywords across domains do NOT count.\n"
"2. METHOD RELEVANCE: The paper must discuss methods, benchmarks, "
"or findings directly applicable to the research topic.\n"
"3. CROSS-DOMAIN REJECTION: Reject papers from unrelated fields "
"(e.g., wireless communications, database systems, social science) "
"even if they use similar terminology.\n"
"4. RECENCY PREFERENCE: Prefer papers from 2020+ for methodology, "
"but accept foundational papers (pre-2020) if they introduced key "
"techniques still in use today.\n"
"5. SEMINAL PAPERS: Papers marked as source='seminal_library' are "
"pre-vetted foundational references — keep them if their keywords "
"match the topic (relevance_score >= 0.7).\n"
"6. QUALITY FLOOR: Reject papers with no abstract, no venue, and "
"no citation count (likely not real papers).\n"
"Candidates JSONL:\n{candidates_text}"
),
"json_mode": True,
},
"knowledge_extract": {
"system": "You extract high-signal evidence cards from papers.",
"user": (
"Extract structured knowledge cards from shortlist.\n"
"Return JSON: {cards:[{card_id,title,cite_key,problem,method,"
"data,metrics,findings,limitations,citation}]}.\n"
"IMPORTANT: If the input contains cite_key fields, preserve them "
"exactly in the output.\n"
"Shortlist:\n{shortlist}"
),
"json_mode": True,
},
# ── Phase C: Knowledge Synthesis ─────────────────────────────────────
"synthesis": {
"system": "You are a synthesis specialist for literature reviews.",
"user": (
"Produce merged synthesis output (topic clusters + research gaps).\n"
"Output markdown with sections: Cluster Overview, Cluster 1..N, "
"Gap 1..N, Prioritized Opportunities.\n"
"Topic: {topic}\n"
"Cards context:\n{cards_context}"
),
"max_tokens": 8192,
},
"hypothesis_gen": {
"system": (
"You formulate testable scientific hypotheses that address gaps "
"NOT covered by existing literature. Your hypotheses must be:\n"
"1. NOVEL: Not simply replicating known results or testing obvious things.\n"
"2. GAP-FILLING: Address specific weaknesses or blind spots identified "
"in the literature synthesis.\n"
"3. FEASIBLE: Testable with limited compute (single GPU, <1 day runtime).\n"
"4. FALSIFIABLE: Have clear failure conditions that would definitively "
"reject the hypothesis.\n"
"5. SURPRISING: At least one hypothesis should challenge conventional "
"wisdom or test a counter-intuitive prediction."
),
"user": (
"Generate at least 2 falsifiable hypotheses from the synthesis below.\n"
"For each hypothesis provide:\n"
"- **Hypothesis statement**: A clear, testable claim\n"
"- **Novelty argument**: Why this has NOT been tested before, citing "
"specific gaps from the synthesis\n"
"- **Rationale**: Theoretical or empirical basis for expecting this result\n"
"- **Measurable prediction**: Specific quantitative outcome expected\n"
"- **Failure condition**: What result would reject this hypothesis?\n"
"- **Required baselines**: What modern, state-of-the-art methods must be "
"compared against to make the finding meaningful?\n\n"
"AVOID:\n"
"- Hypotheses that are trivially obvious (e.g., 'more data improves accuracy')\n"
"- Hypotheses that replicate well-known results already in the literature\n"
"- Hypotheses that cannot be tested within the compute budget\n\n"
"Synthesis:\n{synthesis}"
),
},
# ── Phase D: Experiment Design ───────────────────────────────────────
"experiment_design": {
"system": "You are a principal investigator designing rigorous research experiments.",
"user": (
"{preamble}\n\n"
"Design an experiment plan as YAML.\n"
"Required keys: objectives,datasets,baselines,proposed_methods,"
"ablations,metrics,risks,compute_budget.\n\n"
"NAMING REQUIREMENT (CRITICAL for paper quality):\n"
"- Every condition name in baselines, proposed_methods, and ablations MUST be "
"a DESCRIPTIVE algorithm name DERIVED FROM THE HYPOTHESES ABOVE, NOT a generic label.\n"
"- WRONG: baseline_1, baseline_2, method_variant_1, method_variant_2\n"
"- WRONG: random_search, bayesian_optimization, ppo_policy, curiosity_driven_rl "
"(these are generic defaults — NEVER use them unless they are actually what "
"the hypotheses call for)\n"
"- RIGHT: names that reflect the specific methods/architectures/algorithms in "
"the hypotheses (e.g., rim_agent, monolithic_gru, ewc_baseline, sleep_consolidation, "
"no_sleep_ablation, coarse_routing, fine_routing)\n"
"- The name should immediately tell a reader WHAT algorithm or strategy is used.\n"
"- This is critical because these names appear directly in the paper.\n\n"
"BASELINE & BENCHMARK MODERNITY (CRITICAL for acceptance):\n"
"- Baselines MUST be modern, widely-adopted methods from recent top-venue "
"papers (2023-2026). Beating only outdated or weak baselines is NOT a valid "
"contribution and will result in desk rejection.\n"
"- Include at LEAST one strong baseline that represents current SOTA or "
"near-SOTA in the specific sub-area. Check recent NeurIPS/ICML/ICLR papers "
"to identify appropriate baselines.\n"
"- Benchmarks MUST be standard and actively used. If a benchmark has been "
"superseded, use the newer version.\n"
"- For each baseline, cite the original paper and note why it is a fair "
"and competitive comparison.\n\n"
"HYPOTHESIS ALIGNMENT (CRITICAL — most common failure mode):\n"
"- Your experiment plan MUST directly test the hypotheses listed above.\n"
"- Each hypothesis should map to at least one comparison between conditions.\n"
"- Baselines must be the specific alternatives named in the hypotheses, NOT "
"generic optimization methods like random_search or bayesian_optimization.\n"
"- If a hypothesis says 'X outperforms Y', then X must be a proposed_method "
"and Y must be a baseline.\n"
"- Ablations must isolate the specific components claimed to matter in the "
"hypotheses (e.g., if hypothesis claims routing helps, ablate routing).\n\n"
"STABILITY & REPRODUCIBILITY (CRITICAL for RL-based methods):\n"
"- Under `proposed_methods`, specify key hyperparameters (learning rate, "
"gradient clip threshold, entropy coefficient, etc.).\n"
"- Under `risks`, explicitly list numerical stability concerns "
"(NaN/divergence, reward explosion, policy collapse) and mitigations "
"(gradient clipping, reward normalization, early stopping on NaN).\n"
"- Under `metrics`, include:\n"
" * Primary metric: `{metric_key}` with direction: `{metric_direction}` "
"and units\n"
" * IMPORTANT: The metric direction MUST be `{metric_direction}` — do "
"NOT use a different direction. If {metric_direction}=='minimize', lower "
"is better. If {metric_direction}=='maximize', higher is better.\n"
" * `success_rate`: fraction of seeds that complete without NaN/crash\n"
" * At least ONE discovery-aligned endpoint (e.g., identification "
"accuracy, time-to-discovery, final posterior mass on true hypothesis) "
"in addition to any proxy metric\n"
"{dataset_guidance}\n\n"
"- Under `datasets`, specify AT LEAST 2 regime factors to stratify by "
"(e.g., noise_level: [low, high], hypothesis_space_size: [small, large]). "
"Results MUST be reported per-regime. A single-regime experiment cannot "
"support generality claims and will be rejected by reviewers.\n"
"- FACTORIAL DESIGN PREFERRED: If you vary multiple factors (e.g., scale AND "
"noise), design a factorial grid (e.g., small+low, small+high, large+low, "
"large+high) so each factor's effect can be isolated. Bundling factors "
"(e.g., easy=small+low, hard=large+high) is a confounder and reviewers will "
"flag it. If computational budget limits the grid, at minimum acknowledge "
"that factors are bundled and limit claims accordingly.\n"
"- Under `compute_budget`, plan for minimum 10 seeds per condition to "
"ensure valid statistical comparisons.\n\n"
"STATISTICAL POWER REQUIREMENTS (CRITICAL for publishability):\n"
"- Use AT LEAST 5 random seeds per condition (10 preferred)\n"
"- Use AT LEAST 30 episodes per seed for RL methods\n"
"- Report: mean ± std, 95% bootstrap CI, per-seed raw values\n"
"- For method comparisons: use paired bootstrap or Wilcoxon signed-rank test "
"(NOT paired t-test with n < 10)\n"
"- Report effect sizes (Cohen's d or rank-biserial correlation)\n"
"- 3 seeds is INSUFFICIENT — reviewers will reject papers with n=3\n\n"
"HARDWARE ENVIRONMENT (your experiments run on THIS exact machine):\n"
"{hardware_profile}\n"
"- You have exactly ONE GPU. No distributed training. No multi-GPU. No multi-node.\n"
"- Design experiments that fit this single GPU.\n\n"
"COMPUTE BUDGET CONSTRAINT (CRITICAL — experiments MUST fit time budget):\n"
"- Total experiment time budget: {time_budget_sec} seconds.\n"
"- Per-condition budget: ~{per_condition_budget_sec} seconds "
"(= time_budget × 0.7 / 6 conditions).\n"
"- Pre-cached datasets (instant, no download): {available_tier1_datasets}\n"
"- DO NOT plan experiments requiring multiple GPUs or more than "
"{time_budget_sec}s.\n"
"- HARD CONDITION LIMIT: The total number of conditions (baselines + "
"proposed_methods + ablations) MUST NOT exceed 8 for budgets ≤ 3600s.\n"
" * Recommended: 2-3 baselines + 1-2 proposed methods + 2-3 ablations = 5-8 total.\n"
" * Generating 10+ conditions guarantees most will time out and data will be wasted.\n"
" * Quality over quantity: 6 well-run conditions with 5 seeds each >> 20 conditions "
"with 1 seed each.\n"
"- Each run needs AT LEAST 60 seconds for RL (environment setup + "
"training + evaluation). For deep learning with GPU, at least 120 seconds.\n"
"- HARD CAP: total_conditions × num_seeds × seconds_per_run MUST be < "
"{time_budget_sec} × 0.8 (leave 20% margin for overhead).\n"
"- If total would exceed the budget, you MUST reduce by:\n"
" 1. First: reduce conditions (merge similar ablations, keep strongest baselines)\n"
" 2. Then: reduce seeds to 5 (minimum for statistical validity)\n"
" 3. Then: reduce regimes/environments to 1\n"
"- Example: {time_budget_sec}s budget with 120s/condition/seed, 5 seeds → "
"max {time_budget_sec} / (120 * 5) ≈ 4 conditions.\n\n"
"IMPLEMENTATION SPECIFICATION (CRITICAL for code generation):\n"
"For each proposed method AND each baseline, you MUST include an "
"'implementation_spec' key with:\n"
" - class_name: the Python class name for this method\n"
" - key_methods: list of methods the class must implement "
"(e.g., [__init__, forward, train_step, predict])\n"
" - algorithm_steps: pseudocode-level description of the core algorithm "
"(3-10 steps), e.g.:\n"
" 1. Encode input via encoder network (MLP: input_dim -> hidden_dim)\n"
" 2. Compute attention weights over memory buffer\n"
" 3. Aggregate attended features with learned gate\n"
" 4. Decode to output via decoder network\n"
" - loss_function: the mathematical formula for the training loss "
"(e.g., 'L = CE(y_pred, y_true) + lambda * KL(q||p)')\n"
" - key_hyperparameters: dict of hyperparameter name -> default value\n"
" - differentiator: what makes THIS method different from others "
"(must be an algorithmic difference, not just a hyperparameter change)\n\n"
"For each ablation, you MUST specify:\n"
" - what_is_removed: the specific component being ablated\n"
" - how_it_differs: concrete code-level description of the change "
"(e.g., 'replace attention layer with mean pooling', 'set routing "
"weight to uniform 1/N', 'remove skip connection in block 3')\n"
" - expected_effect: why removing this should change results\n\n"
"This specification is MANDATORY — without it, the code generation "
"stage cannot produce correct implementations.\n\n"
"Hypotheses:\n{hypotheses}"
),
},
"code_generation": {
"system": (
"You are a computational scientist who writes real, runnable "
"experiments. Your code implements actual algorithms with real "
"mathematical operations. You NEVER fake results with random number "
"generators. Always use the ```filename:xxx.py format for each file. "
"Use numpy for numerical computation. Keep code self-contained "
"and deterministic."
),
"user": (
"Generate a Python experiment project for the following research "
"topic:\n"
"TOPIC: {topic}\n\n"
"CRITICAL REQUIREMENTS — your code MUST satisfy ALL of these:\n"
"1. Implement the ACTUAL experiment described in the topic and "
"plan below.\n"
" If the topic is about simulation (e.g., multi-agent systems, "
"network dynamics),\n"
" write simulation code. If about optimization, write "
"optimization code.\n"
" Match the code to the topic — do NOT default to generic "
"gradient descent.\n"
"2. Use proper mathematical models appropriate to the research "
"question.\n"
" Examples: agent-based simulation, graph algorithms, "
"statistical analysis,\n"
" optimization, Monte Carlo methods — whatever fits the topic.\n"
"3. Run REAL computational experiments with meaningful "
"parameters.\n"
"4. Collect REAL metrics that directly answer the research "
"question.\n"
"5. The code must be scientifically meaningful — a reviewer should "
"see\n"
" actual implementations relevant to the TOPIC, not a generic "
"optimizer.\n\n"
"OUTPUT FORMAT — return multiple files using this exact format:\n"
"```filename:main.py\n"
"# entry point code\n"
"```\n\n"
"```filename:models.py\n"
"# model/algorithm implementations\n"
"```\n\n"
"Only create additional files (optimizers.py, data_utils.py, etc.) "
"if they contain substantial logic (>20 lines). Do NOT create stub "
"files with only imports or pass statements.\n\n"
"CODE STRUCTURE:\n"
"- main.py: entry point that runs experiments and prints metrics\n"
"- main.py MUST begin with a docstring specifying:\n"
" (a) Dataset used and how it is loaded\n"
" (b) Distribution shift / corruption definition (if applicable)\n"
" (c) Model architecture (layers, dimensions, activation)\n"
" (d) Training protocol (optimizer, epochs, batch size, LR schedule)\n"
" (e) Evaluation protocol (train/test split, metrics computed)\n"
"- Additional modules for algorithms, objective functions, "
"utilities\n"
"- Primary metric key: {metric}\n"
"- main.py must print metric lines as `name: value` (one per "
"line)\n"
"- Use deterministic seeds (numpy.random.seed or random.seed)\n"
"- No external data files, no network calls, no GPU required\n"
"- FORBIDDEN: subprocess, os.system, eval, exec, shutil, socket\n"
"{pkg_hint}\n"
"ANTI-PATTERNS (do NOT do these):\n"
"- Do NOT generate random numbers and pretend they are experiment "
"results\n"
"- Do NOT use `random.uniform()` to simulate a decreasing loss "
"curve\n"
"- Do NOT hardcode metric values or use trivial arithmetic as "
"metrics\n\n"
"MULTI-CONDITION REQUIREMENT (CRITICAL):\n"
"The experiment plan below specifies multiple conditions, treatments, "
"or strategies to compare. Your code MUST:\n"
"1. Implement ALL conditions/treatments listed in the experiment plan "
"— not just one baseline.\n"
"2. Run each condition independently with the same controlled setup "
"(same seeds, same initialization, same budget).\n"
" IMPORTANT: All conditions MUST be iterated INSIDE main.py using a "
"for-loop or dispatch table. NEVER use argparse --condition or any CLI "
"argument to select a condition. The harness calls `python main.py` "
"with NO arguments — if you add a required --condition arg it will crash.\n"
"3. Print metrics with condition labels: "
"`condition= {metric}: ` for EACH condition.\n"
"4. After all conditions, print a summary comparison line: "
"`SUMMARY: condition1=, condition2=, ...`\n"
"5. If the plan has N conditions, the output MUST contain N separate "
"labeled metric streams. Running only one condition is NOT acceptable.\n"
"6. BREADTH-FIRST ORDERING: Run ONE representative configuration per "
"condition FIRST (e.g., default parameters), so that ALL conditions "
"produce at least one result. Only AFTER all conditions have results, "
"run additional parameter sweeps if time remains. This prevents the "
"time budget from being exhausted on condition 1's parameter sweep "
"while conditions 2..N never execute.\n"
"7. CONDITION COMPLETENESS: After code generation, mentally verify that "
"EVERY condition in the experiment plan below has a corresponding code "
"path. If the plan lists conditions A, B, C, D — your code must handle "
"all four, not just A, B, C. Missing conditions invalidate the experiment.\n"
"8. CRASH RESILIENCE: Wrap each condition's execution in a try/except "
"block so that if one condition crashes (e.g., NaN, timeout, config error), "
"the remaining conditions still execute. Print `CONDITION_FAILED: "
"` on failure and continue to the next condition. A partial result "
"set is far more valuable than a complete crash.\n"
"9. CONDITION REGISTRY VALIDATION: At startup (before running experiments), "
"enumerate all condition names and verify each has a valid code path. Print "
"`REGISTERED_CONDITIONS: , , ...` at the top of output. If "
"any condition is unrecognized, print `MISSING_CONDITION: ` and skip "
"it gracefully rather than raising an exception.\n"
"10. TOTAL CONDITIONS LIMIT (HARD RULE): Your code MUST NOT register more "
"than 8 total conditions. If the experiment plan lists ablations with many "
"parameter values (e.g., 'test decay rates 0.9, 0.99, 0.995, 0.999, 0.9999'), "
"pick the 2-3 most informative values — do NOT create a separate condition for "
"each value. 8 conditions × 3 seeds × budget ÷ conditions = tight timing. "
"Quality of each condition matters more than quantity.\n\n"
"METRIC DEFINITION REQUIREMENT (CRITICAL):\n"
"- At the top of main.py, include a docstring or comment block that defines:\n"
" * METRIC NAME: the exact key printed as `{metric}: `\n"
" * DIRECTION: {metric_direction_hint}\n"
" * UNITS/SCALE: what the number represents (e.g., MSE in log scale, "
"accuracy 0-1, discovery rate per episode)\n"
" * FORMULA: how the metric is computed from raw experiment outputs\n"
" * AGGREGATION: how per-step/per-episode values are reduced to a scalar\n"
"- Print this definition at runtime: `METRIC_DEF: {metric} | direction= "
"| desc=`\n"
"- Without this definition, the metric is UNINTERPRETABLE and the paper cannot "
"make any claims about which method is better.\n\n"
"STATISTICAL RIGOR REQUIREMENT:\n"
"- Run each condition with at least 5 different random seeds (10+ preferred "
"if time budget allows). Minimum 3 seeds is MANDATORY.\n"
"- Print per-seed results: `condition= seed= {metric}: `\n"
"- Print mean and std across seeds: "
"`condition= {metric}_mean: {metric}_std: `\n"
"- If time budget is tight, reduce per-seed iterations rather than "
"reducing seed count. Minimum 3 seeds is non-negotiable.\n"
"- SEED COUNT IS FIXED AT 3 MINIMUM. Do NOT compute seed count dynamically.\n"
" Hardcode `SEEDS = [0, 1, 2]`. If 3 seeds × all conditions exceeds the time "
"budget, REDUCE the number of conditions or training epochs — NEVER reduce seeds.\n"
" Print: `SEED_COUNT: 3 (fixed minimum, budget={time_budget}s, conditions=N)`.\n"
"- Report bootstrap 95% confidence intervals when n >= 5.\n\n"
"FAILURE-AWARE REPORTING (CRITICAL for RL/unstable methods):\n"
"- Track how many seeds succeed vs fail (NaN, divergence, crash) per "
"condition. Print: `condition= success_rate: /`\n"
"- Compute UNCONDITIONAL metrics: treat failed seeds as worst-case "
"(e.g., metric=0 or metric=worst_baseline). Print: "
"`condition= unconditional_{metric}_mean: `\n"
"- This prevents survivorship bias where a method looks good only "
"because failed runs are excluded.\n"
"- For RL methods, add STABILITY SAFEGUARDS in the code:\n"
" * Gradient clipping (max norm 1.0)\n"
" * Reward normalization/clipping to [-10, 10]\n"
" * NaN checks on loss/gradients with graceful early stop (not crash)\n"
" * Learning rate warmup or conservative initial learning rate\n"
" These safeguards should PREVENT most NaN/divergence, not just catch "
"them after the fact.\n\n"
"PYTORCH RL IMPLEMENTATION BUGS (CRITICAL — these cause 100% crash rate):\n"
"- 'Trying to backward through the graph a second time' is the #1 crash.\n"
" CAUSE: reusing a computed tensor across multiple backward() calls.\n"
" FIX: Always .detach() values used in the next iteration:\n"
" ```\n"
" # WRONG:\n"
" old_log_prob = policy.log_prob(action) # still attached to graph\n"
" # ... later in update loop:\n"
" ratio = new_log_prob / old_log_prob # backward crashes\n"
" \n"
" # CORRECT:\n"
" old_log_prob = policy.log_prob(action).detach() # detach!\n"
" # ... later in update loop:\n"
" ratio = new_log_prob / old_log_prob.detach() # safe\n"
" ```\n"
"- For PPO: old_log_probs MUST be .detach()ed when stored for later ratio computation.\n"
"- For value functions: target values MUST be .detach()ed (don't backprop through targets).\n"
"- For curiosity/intrinsic reward: prediction errors used as reward MUST be .detach()ed.\n"
"- General rule: any tensor from a PREVIOUS forward pass that is used in the CURRENT "
"loss computation MUST be .detach()ed.\n"
"- When in doubt, add .detach() — it never causes crashes, but missing it always does.\n\n"
"NEURAL NETWORK DIMENSION CONSISTENCY (CRITICAL — #2 crash cause):\n"
"- 'input and weight.T shapes cannot be multiplied' means obs_dim != network input_dim.\n"
"- When the environment observation size VARIES across regimes (e.g., easy=6, hard=8), "
"the neural network's input layer MUST match EACH regime's obs_dim.\n"
"- FIX: Create the network INSIDE the per-regime loop, or parameterize input_dim:\n"
" ```\n"
" # WRONG: fixed input_dim for all regimes\n"
" policy = PolicyNet(input_dim=10) # breaks if obs_dim != 10\n"
" for regime in regimes:\n"
" obs = env.reset() # obs.shape may vary!\n"
" \n"
" # CORRECT: dynamic input_dim per regime\n"
" for regime in regimes:\n"
" obs = env.reset()\n"
" obs_dim = obs.shape[-1] # or len(obs)\n"
" policy = PolicyNet(input_dim=obs_dim) # fresh network per regime\n"
" ```\n"
"- ALWAYS initialize neural networks AFTER knowing the observation dimension.\n\n"
"KNOWLEDGE DISTILLATION (KD) STABILITY (if applicable):\n"
"- Teacher network MUST be frozen: `teacher.eval()` and "
"`for p in teacher.parameters(): p.requires_grad = False`\n"
"- Temperature parameter T: typical range 1-20. Use T=4 as default. "
"NEVER use T<1 (causes sharp distributions → NaN gradients).\n"
"- Loss balance: `loss = alpha * kd_loss + (1-alpha) * task_loss` — "
"set alpha=0.5-0.9. If kd_loss scale >> task_loss, val_loss becomes NaN.\n"
"- PROJECTION LAYERS: If teacher and student have different intermediate "
"dimensions (e.g., teacher_dim=768, student_dim=256), you MUST add "
"`nn.Linear(student_dim, teacher_dim)` to align features before computing "
"distillation loss. Without projection layers, tensor shape mismatch WILL crash.\n"
"- Common KD NaN causes: (1) no temperature scaling on logits, "
"(2) missing gradient clipping, (3) learning rate too high (use ≤1e-3), "
"(4) teacher not frozen → unstable targets.\n\n"
"PAIRED STATISTICAL ANALYSIS (CRITICAL for publishable results):\n"
"- Use the SAME random seeds across all conditions so results are paired.\n"
"- After collecting per-seed results for all conditions, compute paired "
"differences: for each seed s, diff(s) = method(s) - baseline(s).\n"
"- Print paired analysis: "
"`PAIRED: vs mean_diff= std_diff= "
"t_stat= p_value=`\n"
"- Also print bootstrap 95% CI of the paired difference.\n"
"- This is FAR more powerful than independent comparisons because it "
"controls for seed-to-seed variance.\n\n"
"MULTI-REGIME REQUIREMENT (CRITICAL for generality claims):\n"
"- The experiment MUST test at least 2 different difficulty/noise regimes "
"(e.g., low noise vs high noise, small hypothesis space vs large).\n"
"- Report results per-regime, not just aggregated across regimes.\n"
"- Print regime labels: "
"`condition= regime= {metric}: `\n"
"- This prevents conclusions that only hold in one setting from being "
"presented as general findings.\n\n"
"DIMENSION CONSISTENCY CHECK (CRITICAL for RL/neural methods):\n"
"- Before passing observations/states to neural networks or policy "
"parameters, VERIFY that dimensions match. Common bug: environment "
"state has dimension D1 but network expects D2.\n"
"- At the start of each condition, print the state/observation "
"dimension and the network input dimension. If they mismatch, "
"reshape or adjust the network before proceeding.\n"
"- Test EVERY condition with a single dry-run step before the full "
"loop to catch shape mismatches early.\n\n"
"TIME-TO-EVENT METRIC BUG PREVENTION (CRITICAL — common silent bug):\n"
"- If the primary metric is a 'time-to-X' measure (e.g., time-to-discovery, "
"steps-to-convergence, episodes-to-threshold), you MUST check the success "
"criterion at EVERY step inside the loop, not only at the end.\n"
"- WRONG pattern (produces degenerate ceiling data):\n"
" ```\n"
" for t in range(horizon):\n"
" obs, r, done, info = env.step(a)\n"
" success = check(info) # only checked ONCE at end\n"
" time_to_X = horizon if not success else t + 1 # t+1 = horizon always!\n"
" ```\n"
"- CORRECT pattern (captures actual first-success time):\n"
" ```\n"
" time_to_X = horizon # default: never succeeded\n"
" for t in range(horizon):\n"
" obs, r, done, info = env.step(a)\n"
" if check(info) and time_to_X == horizon: # first success\n"
" time_to_X = t + 1\n"
" if done: break\n"
" ```\n"
"- This bug causes ALL methods to return the same ceiling value, making "
"the entire experiment useless. Every method looks identical at the cap.\n"
"- APPLY THIS TO ALL CONDITIONS: RandomSearch, BO, RL — every single "
"condition must check at every step. If even one condition uses the wrong "
"pattern, the comparison is invalid.\n\n"
"METRIC DISCRIMINATION VALIDATION (CRITICAL):\n"
"- After running all conditions, check if all conditions produce the SAME "
"mean metric value. If they do, the metric is NOT discriminative and the "
"experiment is scientifically useless.\n"
"- Common causes: ceiling/floor effects, too-easy or too-hard tasks, "
"time-to-event bug above, metric that doesn't capture real differences.\n"
"- If all conditions have identical means, print "
"`WARNING: DEGENERATE_METRICS all conditions have same mean=` "
"and you MUST take corrective action:\n"
" (a) If all means = 1.0 or max: increase task difficulty (reduce budget, "
"increase noise, enlarge hypothesis space)\n"
" (b) If all means = 0.0: decrease difficulty\n"
" (c) Re-run after adjustment and verify means now differ\n"
" (d) If adjustments don't help, switch to a different primary metric\n"
"- A degenerate experiment CANNOT produce a publishable paper. Fix it.\n\n"
"DIFFICULTY CALIBRATION (CRITICAL for meaningful results):\n"
"- After running a pilot (3-5 seeds, 2 conditions: random_search + one RL), "
"check BOTH success rate AND metric discrimination.\n"
"- TWO things must be true for the experiment to be informative:\n"
" 1. Success rate between 30-80% (not too hard, not too easy)\n"
" 2. Primary metric varies across conditions (not all methods score the same)\n"
"- CEILING DETECTION (CRITICAL): If primary_metric is 1.0 (or max possible) "
"for ALL pilot seeds in ALL pilot conditions, the task is TRIVIALLY EASY. "
"You MUST increase difficulty until the metric varies. Options:\n"
" * Reduce experiment budget/horizon (fewer steps to find solution)\n"
" * Increase hypothesis space size\n"
" * Increase observation noise\n"
" * Tighten the success criterion (e.g., require closer match)\n"
" * Reduce the number of allowed experiments per episode\n"
"- FLOOR DETECTION: If primary_metric is 0.0 for all conditions, task is "
"too hard. Reduce noise, enlarge budget, simplify.\n"
"- Print `CALIBRATION: regime= pilot_success_rate= "
"pilot_primary_metric_std=` after calibration.\n"
"- If std=0, the metric is NOT discriminative — adjust until std > 0.\n"
"- Run a calibration loop: pilot → check → adjust → re-pilot (max 3 iterations).\n\n"
"ALGORITHM IMPLEMENTATION INTEGRITY (CRITICAL — mismatch = academic fraud):\n"
"1. If you name a method 'Bayesian Optimization', you MUST implement:\n"
" - A surrogate model (e.g., Gaussian Process or random forest)\n"
" - An acquisition function (e.g., Expected Improvement, UCB)\n"
" - Surrogate model updates after each observation\n"
" DO NOT implement UCB1 bandit and call it 'Bayesian Optimization'.\n"
"2. If you name a method 'PPO', you MUST implement:\n"
" - A clipped surrogate objective: min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)\n"
" - A learned value function baseline\n"
" - The clip_eps parameter MUST be used in the policy update\n"
" DO NOT implement vanilla REINFORCE and call it 'PPO'.\n"
"3. Every declared hyperparameter MUST be used in the algorithm:\n"
" - If you declare clip_eps, it must appear in the loss computation\n"
" - If you declare entropy_coef, it must be added to the policy loss\n"
" - Dead parameters (declared but never used) are strictly forbidden\n"
"4. Ablation conditions MUST produce different behavior:\n"
" - Two conditions that differ only in a parameter that is never read are IDENTICAL\n"
" - Verify: if two conditions produce identical outputs on the same seed, "
"the ablation is broken and MUST be fixed\n"
" ABLATION DESIGN PATTERN (CRITICAL — #1 cause of broken ablations):\n"
" - 'no_key_component': Must REMOVE a core algorithmic component "
"(e.g., disable the graph structure by zeroing the adjacency, or remove "
"the contrastive loss, or disable the RL policy and use random actions). "
"The removal MUST change the forward() / step() computation.\n"
" - 'reduced_capacity': Must REDUCE model capacity by at least 2x "
"(e.g., halve hidden dimensions, reduce layers, shrink embedding size). "
"This MUST create a new model with different architecture, NOT just "
"rename a parameter with the same value.\n"
" - SELF-TEST: After implementing ablations, add a startup check that "
"runs one forward pass per condition on the SAME input and asserts outputs "
"differ. Print: `ABLATION_CHECK: vs outputs_differ=True`.\n"
" - If outputs are identical, the ablation is BROKEN — do not proceed.\n\n"
"CODE IMPLEMENTATION DEPTH (CRITICAL — shallow code = reject):\n"
"- Each algorithm/method MUST be a separate Python class with genuine logic.\n"
"- Each class MUST have at least: __init__(), and one core method "
"(forward/predict/train_step/step) with non-trivial implementation.\n"
"- The core method of the MAIN proposed method MUST be at least 20 lines "
"of effective code (excluding comments, blanks, imports).\n"
"- FORBIDDEN patterns that will be detected and rejected:\n"
" * `class MethodB(MethodA): pass` — empty subclass\n"
" * Two classes with identical method bodies but different names\n"
" * nn.Linear/nn.Conv2d created inside forward() instead of __init__()\n"
" * Variables defined only inside an if-branch but used after the branch\n"
" * Using np.erf() (doesn't exist — use scipy.special.erf or math.erf)\n"
" * Using ndarray.ptp() (removed in NumPy 2.0 — use np.ptp(arr) or arr.max()-arr.min())\n"
" * Using np.bool, np.int, np.float, np.complex (removed in NumPy 2.0 — use np.bool_, np.int64, etc.)\n"
" * Replacing real model training with synthetic utility functions or random scores\n"
" * Using dict[key] without ensuring key exists — use dict.get(key, default) "
"or verify key is in dict before access\n"
"- If the experiment plan includes 'implementation_spec', you MUST follow "
"the pseudocode steps exactly. Each algorithm_step should correspond to "
"1-3 lines of code in the class.\n"
"- Ablation variants MUST modify the forward() or step() logic, not just "
"change a hyperparameter value.\n\n"
"MINIMUM SEED COUNT (CRITICAL — 3 seeds = unpublishable):\n"
"- Use AT LEAST 5 random seeds per condition (10 preferred if time permits)\n"
"- Use AT LEAST 30 episodes per seed for RL methods\n"
"- When computing bootstrap CIs, use at least 1000 bootstrap samples\n"
"- For method comparisons: use paired bootstrap or Wilcoxon signed-rank test\n"
"- Report effect sizes (Cohen's d) alongside p-values\n\n"
"Experiment plan:\n{exp_plan}"
),
"max_tokens": 8192,
},
"resource_planning": {
"system": "You are an experiment scheduler.",
"user": (
"Create schedule JSON with GPU/time estimates.\n"
"Schema: {tasks:[{id,name,depends_on,gpu_count,estimated_minutes,"
"priority}], total_gpu_budget, generated}.\n"
"Experiment plan:\n{exp_plan}"
),
"json_mode": True,
},
# ── Phase F: Analysis & Decision ─────────────────────────────────────
"result_analysis": {
"system": (
"You are a quantitative research analyst. Always cite exact numbers "
"from the provided data."
),
"user": (
"{preamble}\n\n"
"{data_context}\n\n"
"Analyze run metrics and produce markdown report with statistical "
"interpretation.\n"
"Use the ACTUAL quantitative values provided above — do NOT invent "
"numbers.\n\n"
"SANITY CHECKS (perform BEFORE interpreting results):\n"
"1. MONOTONICITY: If a condition scales a parameter (e.g., N agents, "
"model size), check whether metrics move in the expected direction. "
"If accuracy *decreases* when adding more agents under majority voting, "
"flag this as a likely implementation bug (vote parsing, normalization, "
"or aggregation issue).\n"
"2. BASELINE PLAUSIBILITY: Random-chance baselines should match "
"theoretical expectations (e.g., 1/K for K-class classification).\n"
"3. CROSS-CONDITION CONSISTENCY: Results across datasets or conditions "
"should be internally coherent — wildly different patterns may indicate "
"confounds or bugs.\n"
"4. REPLICATION: If results are from a single seed (n=1), explicitly "
"note that no statistical significance claims can be made.\n"
"5. ABLATION ISOLATION: Compare per-seed values across conditions. If "
"two conditions produce IDENTICAL values for the same seed, this is a "
"RED FLAG — the ablation/variant may not have actually changed the code "
"path (e.g., config not applied, caching, shared state). Flag this "
"explicitly and recommend a config/registry audit.\n"
"6. METRIC DEFINITION CHECK: Look for a `METRIC_DEF:` line in the output. "
"If absent, flag that the primary metric is UNDEFINED — direction, units, "
"and formula are unknown, making all comparisons uninterpretable. This is "
"a critical methodology gap.\n"
"7. CONDITION COMPLETENESS CHECK: Look for `REGISTERED_CONDITIONS:` in "
"the output. Compare against the experiment plan. If conditions are missing "
"or failed (look for `CONDITION_FAILED:`), list them explicitly and assess "
"whether the remaining conditions can still answer the research question.\n"
"8. DEGENERATE METRICS CHECK: If ALL conditions (or all but one) produce "
"the SAME mean primary metric value, flag this as DEGENERATE — the metric "
"is NOT discriminative. Common causes: (a) time-to-event metric that only "
"checks success at the final step (returns horizon for all methods), "
"(b) ceiling/floor effects from too-easy or too-hard tasks, "
"(c) metric capped at a budget value. This makes the experiment "
"scientifically useless — recommend REFINE with a note to fix the metric "
"computation or task difficulty. Look for `WARNING: DEGENERATE_METRICS` "
"in stdout. Even if not printed, check the numbers yourself.\n\n"
"Required sections: Metrics Summary (with real values), "
"Consensus Findings (high confidence), "
"Contested Points (with evidence-based resolution), "
"Statistical Checks, Methodology Audit, Limitations, Conclusion.\n"
"In the Conclusion, include:\n"
"- Result quality rating (1-10)\n"
"- Key findings (3-5)\n"
"- Methodology gaps to address next\n"
"- Recommendation: PROCEED / REFINE / PIVOT\n\n"
"Run context:\n{context}"
),
"max_tokens": 8192,
},
"research_decision": {
"system": "You are a research program lead making go/no-go decisions.",
"user": (
"Based on the analysis, make one of three decisions:\n"
"- **PROCEED** — results are sufficient, move to paper writing\n"
"- **PIVOT** — hypotheses are fundamentally flawed, generate new ones\n"
"- **REFINE** — hypotheses are sound but experiments need re-tuning\n\n"
"MINIMUM QUALITY CRITERIA for PROCEED (ALL must be met):\n"
"1. At least 2 baselines AND the proposed method have results\n"
"2. The primary metric is defined (direction, units known)\n"
"3. Each condition has results from ≥3 seeds\n"
"4. No identical per-seed values across different conditions (ablation integrity)\n"
"5. The analysis quality rating is ≥4/10\n"
"If ANY criterion is not met, you MUST choose REFINE (not PROCEED).\n\n"
"Output markdown with sections:\n"
"## Decision\n"
"State exactly one of: PROCEED, PIVOT, or REFINE\n\n"
"## Justification\n"
"Why this decision is warranted based on evidence.\n\n"
"## Evidence\n"
"Key data points supporting the decision.\n\n"
"## Next Actions\n"
"Concrete steps for the chosen path.\n\n"
"Analysis:\n{analysis}"
),
},
# ── Phase G: Paper Writing ───────────────────────────────────────────
"paper_outline": {
"system": "You are an academic writing planner for top-tier AI conferences.",
"user": (
"{preamble}\n\n"
"{academic_style_guide}\n\n"
"Create a detailed paper outline in markdown.\n"
"Include per-section goals, word count targets, and evidence links.\n"
"The outline MUST include a catchy method name (2-5 chars) for the paper title.\n"
"Propose 3 candidate titles following the 'MethodName: Subtitle' format "
"(each <= 14 words). Rate each on memorability (1-5), specificity (1-5), "
"and novelty signal (1-5).\n"
"{topic_constraint}"
"{feedback}"
"Analysis:\n{analysis}\n\nDecision:\n{decision}"
),
"max_tokens": 8192,
},
"paper_draft": {
"system": (
"You are a top-tier academic paper author writing for leading venues.\n\n"
"KEY PRINCIPLES (from accepted paper analyses):\n"
"1. NOVELTY: A good paper has 1-2 key ideas and keeps the rest simple.\n"
"2. NARRATIVE: A short, rigorous, evidence-based technical story with a takeaway.\n"
"3. STRONG BASELINES: Invest real effort in making baselines competitive.\n"
"4. ABLATIONS: Remove one component at a time and measure the effect.\n"
"5. HONESTY: Acknowledge limitations explicitly.\n"
"6. REPRODUCIBILITY: Include all details needed to reproduce results.\n\n"
"EVIDENCE-BOUNDING RULES (CRITICAL — violation = reject):\n"
"7. EVERY claim in the title, abstract, and conclusion MUST be directly "
"supported by specific experimental metrics provided below.\n"
"8. If the experiment only covers partial conditions, the title MUST NOT "
"make global causal claims. Use 'Toward...', 'Investigating...', or "
"'An Empirical Study of...' instead of 'X Dominates Y'.\n"
"9. BEFORE writing the title, list the conditions actually tested and "
"their metric values. The title must only claim what those numbers show.\n"
"10. If a metric is a single scalar without condition labels, do NOT "
"claim comparative results between strategies/methods.\n"
"11. Distinguish between 'we propose and validate' (has full results) vs "
"'we propose and present preliminary evidence' (partial results).\n\n"
"You ONLY use real experimental data — never fabricate or approximate numbers.\n\n"
"METHOD SECTION REQUIREMENTS:\n"
"12. The Method section MUST include ALL implementation details needed "
"for reproduction: algorithm pseudocode or step-by-step description, "
"hyperparameters (learning rate, clipping, discount factor, etc.), "
"state/observation representation, reward definition, and baseline "
"configurations.\n"
"13. For learning-based methods: specify model architecture, training procedure "
"(iterations, epochs, batch handling), and any stability "
"mechanisms (regularization, normalization).\n"
"14. For baselines: specify the exact algorithm/method configuration "
"and any tuning performed to make baselines competitive.\n\n"
"FAILURE-AWARE REPORTING REQUIREMENTS:\n"
"15. If any method has a success rate < 100%, the Results section "
"MUST report success rates per method and explain inclusion/exclusion "
"criteria.\n"
"16. Report BOTH conditional metrics (successful runs only) AND "
"unconditional metrics (treating failures as worst-case). Without "
"both, comparative claims are biased by survivorship.\n"
"17. The Limitations section MUST discuss stability/reliability "
"if any method showed NaN/divergence/crashes.\n\n"
"BENCHMARK & ENVIRONMENT SPECIFICATION:\n"
"18. The Experiments section MUST fully specify the evaluation "
"environment: state/observation space, action space, hypothesis space, "
"noise model, episode length, and any randomization procedures.\n"
"19. Report results PER REGIME (e.g., per noise level, per problem "
"size) with separate tables or sub-sections. Aggregated-only results "
"cannot support claims about robustness or generality.\n"
"20. Include a table comparing all methods across all regimes with "
"paired statistical tests (bootstrap CI of paired differences, or "
"paired t-test p-values). Without this, comparative claims lack "
"statistical grounding.\n\n"
"METHOD NAMING RULES:\n"
"21. NEVER use generic labels like 'baseline_1', 'method_variant_1', "
"'method_variant_2' in the paper. Use descriptive algorithm/method names "
"that reflect what the method actually does. Generic labels make the paper "
"scientifically uninterpretable.\n"
"22. Each method MUST have a full description: architecture, "
"training procedure, key hyperparameters, and implementation details. "
"A reader should be able to reimplement every method from the paper alone.\n\n"
"STATISTICAL REPORTING (MANDATORY for acceptance):\n"
"23. EVERY result table MUST include 95% confidence intervals "
"(mean +/- CI or [low, high]).\n"
"24. EVERY comparison claim ('A outperforms B') MUST cite p-value. "
"If p >= 0.05, write: 'The difference is not statistically significant.'\n"
"25. If the proposed method does NOT statistically significantly "
"outperform a baseline, do NOT claim superiority. Reframe as "
"'comparable', 'competitive', or 'negative result'.\n\n"
"WRITING STYLE RULES:\n"
"26. DO NOT repeat disclaimers like 'due to computational constraints, "
"this analysis was not conducted' more than once. State each limitation "
"ONCE in the Limitations section.\n"
"27. The Limitations section should be concise (200-400 words) listing "
"3-5 key limitations. Do NOT scatter limitation disclaimers throughout "
"every section.\n"
"28. Focus 80% of the paper on WHAT YOU DID and WHAT YOU FOUND, not "
"on what you could not do. Positive scientific contribution should "
"dominate the paper.\n"
"29. Cite 25-40 unique references in the paper body. The Related Work "
"section alone should cite at least 15 references. Cite only directly "
"relevant work — do NOT pad with tangentially related papers.\n"
"30. CITE ORIGINAL PAPERS: When discussing a technique (e.g., Batch "
"Normalization, ResNet, Adam, PPO), ALWAYS cite the original paper that "
"introduced it. Do NOT cite a survey or follow-up instead of the original. "
"The available references list includes foundational papers — use them.\n"
"31. BASELINE MODERNITY: When discussing baselines and comparisons, ensure "
"the paper acknowledges whether the baselines represent current practice. "
"If baselines are older methods, explicitly discuss why they were chosen "
"and acknowledge stronger modern alternatives exist."
),
"user": (
"{preamble}\n\n"
"{academic_style_guide}\n"
"{narrative_writing_rules}\n"
"{anti_hedging_rules}\n"
"{anti_repetition_rules}\n"
"Write a full paper draft section by section in markdown.\n"
"Required sections: Title, Abstract, Introduction, Related Work, "
"Method, Experiments, Results, Discussion, Limitations, Broader Impact, "
"Conclusion, References.\n"
"The Broader Impact section (2-3 paragraphs) MUST discuss: "
"(1) potential positive societal impacts of this work, "
"(2) potential negative societal impacts or risks, "
"(3) ethical considerations specific to this research area. "
"This section is MANDATORY for top ML venues and recommended for all research papers.\n"
"{writing_structure}\n"
"{topic_constraint}"
"{exp_metrics_instruction}"
"{citation_instruction}"
"All experimental results MUST be presented in LaTeX tables or inline prose. "
"Raw metric path formats like 'method/env/step/metric: value' are FORBIDDEN "
"in the paper text. Convert all data to clean, formatted presentation.\n"
"The paper MUST fit within 10 pages (excluding references and appendix). "
"Aim for 8-9 pages of main content. Be concise.\n"
"FIGURE RULES: When referencing figures, use ONLY \\ref{fig:label} cross-references. "
"NEVER add bold standalone paragraphs like '**Figure 1.**' after figure environments. "
"Do NOT add \\clearpage before or after figures/tables unless absolutely necessary.\n"
"TABLE RULES: Tables MUST use standard LaTeX tabular syntax with bare braces: "
"\\begin{tabular}{lcc}, NOT \\begin{tabular}\\{lcc\\}. "
"NEVER use '--' as placeholder values in table cells. "
"If a metric is unavailable, write 'N/A' or omit the row entirely.\n"
"Outline:\n{outline}"
),
"max_tokens": 16384,
},
"peer_review": {
"system": "You are a balanced conference reviewer.",
"user": (
"Simulate peer review from at least 3 reviewer perspectives.\n"
"Output markdown with Reviewer A (methodology expert), "
"Reviewer B (domain expert), and Reviewer C (statistics/rigor expert), "
"each including strengths, weaknesses, and actionable revisions.\n\n"
"Check specifically:\n"
"1. TOPIC ALIGNMENT: Does the paper stay on topic ({topic})? "
"Flag any sections where the paper drifts to unrelated topics or "
"presents environment issues as contributions.\n"
"2. CLAIM-EVIDENCE ALIGNMENT: For EACH claim in the title, abstract, "
"and conclusion, verify there is a specific metric/table/figure in "
"the Results section supporting it. Flag unsupported claims.\n"
"3. STATISTICAL VALIDITY: Are confidence intervals or error bars "
"reported? Is n>1 (multiple seeds)? Are significance tests appropriate?\n"
"4. COMPLETENESS: Does the paper have all required sections with "
"sufficient depth? A NeurIPS paper body should be 5,000-6,500 words.\n"
"5. REPRODUCIBILITY: Are hyperparameters, random seeds, compute "
"resources, and dataset details fully specified?\n"
"6. WRITING QUALITY: Is the paper written in flowing prose or bullet lists? "
"Flag any bullet-point lists in Method/Results/Discussion. Check for "
"excessive hedging ('we do not claim'). Verify title is <= 14 words.\n"
"7. FIGURES: Does the paper include at least 2 figures? Zero figures = desk reject.\n"
"8. CITATION DISTRIBUTION: Are citations only in Intro/Related Work? "
"Method, Experiments, and Discussion MUST also cite relevant papers.\n\n"
"Paper draft:\n{draft}\n\n"
"Experiment evidence for verification:\n{experiment_evidence}"
),
"max_tokens": 8192,
},
"paper_revision": {
"system": (
"You are a paper revision expert.\n\n"
"TITLE AND ABSTRACT ALIGNMENT (CRITICAL):\n"
"- After reviewing experimental evidence, UPDATE the title if results "
"do not support the original claim.\n"
"- If the proposed method does NOT beat baselines, use a title like "
"'An Empirical Study of...', 'When X Falls Short: ...', or "
"'Investigating ... : Negative Results and Insights'.\n"
"- Rewrite the abstract to accurately reflect what was FOUND, not "
"what was hoped. The abstract must match actual numbers.\n"
"- The conclusion MUST match actual results — no aspirational claims.\n\n"
"IMPORTANT WRITING RULES:\n"
"- Do NOT add disclaimers like 'due to computational constraints' "
"or 'this analysis was not conducted'. If a limitation exists, "
"mention it ONCE in the Limitations section only.\n"
"- Focus 80% of the paper on what was DONE and what was FOUND.\n"
"- Do NOT add hedging language that was not in the original draft.\n"
"- Keep Limitations to 200-400 words with 3-5 concise points.\n"
"- Ensure every comparison claim cites a p-value or states that "
"the difference is not statistically significant.\n"
),
"user": (
"{academic_style_guide}\n"
"{narrative_writing_rules}\n"
"{anti_hedging_rules}\n"
"{anti_repetition_rules}\n"
"Revise the paper draft to address all review comments.\n"
"Return revised markdown only.\n\n"
"CRITICAL REVISION RULES:\n"
"- Transform any remaining bullet-point lists in the body into flowing "
"prose paragraphs. The only allowed lists are in the Introduction's contribution "
"paragraph and the Limitations section.\n"
"- The title MUST be <= 14 words with a catchy method name.\n"
"- MANDATORY: The revised paper MUST contain at least 2 markdown image references\n"
" (). If the draft has zero figures, ADD them in the Results\n"
" section using the chart files. A paper with zero figures will be desk-rejected.\n"
"- Consolidate ALL hedging/caveats into Limitations section only.\n"
"- The final paper body MUST be <= 6,500 words (standard 9-page conference limit).\n"
" If the current draft exceeds this, compress by removing redundant restatements.\n"
"- If the paper exceeds 10 pages, aggressively cut redundant content, "
"merge similar sections, and tighten prose. Target 8-9 pages of main content.\n"
"- Do NOT add '**Figure N.**' bold paragraphs after figure environments — "
"use only \\ref{fig:label} cross-references. Do NOT add \\clearpage "
"before figures or tables.\n"
"- NEVER use '--' placeholder values in tables. Replace with actual values or 'N/A'.\n"
"- CITATION FORMAT (CRITICAL): All citations MUST remain in [cite_key] bracket "
"format exactly as they appear in the draft, e.g. [smith2024transformer]. "
"Do NOT convert them to author-year format like [Smith et al., 2024] or "
"(Smith et al., 2024). The downstream LaTeX converter relies on the "
"[cite_key] format to generate \\cite{{}} commands. Changing the format "
"will break all references in the final PDF.\n"
"- CITATION KEYS (CRITICAL): Do NOT invent or add new citation keys that "
"are not already present in the draft. If you want to reference additional "
"prior work, describe it in prose WITHOUT a citation bracket. Every "
"[cite_key] you write MUST already exist in the bibliography. Adding "
"hallucinated keys like [smith2020method] creates broken [?] references "
"in the final PDF.\n"
"{writing_structure}\n"
"{topic_constraint}"
"Draft:\n{draft}\n\nReviews:\n{reviews}"
),
"max_tokens": 16384,
},
# ── Phase H: Finalization ────────────────────────────────────────────
"quality_gate": {
"system": "You are a final quality gate evaluator.",
"user": (
"Evaluate revised paper quality and return JSON.\n"
"Schema: {score_1_to_10:number, verdict:string, strengths:[...], "
"weaknesses:[...], required_actions:[...]}.\n"
"Threshold: {quality_threshold}\n"
"Paper:\n{revised}"
),
"json_mode": True,
},
"knowledge_archive": {
"system": "You produce reproducibility-focused research retrospectives.",
"user": (
"{preamble}\n\n"
"Write retrospective archive markdown with lessons, "
"reproducibility notes, and future work.\n"
"Decision:\n{decision}\n\nAnalysis:\n{analysis}\n\n"
"Revised paper:\n{revised}"
),
"max_tokens": 8192,
},
"export_publish": {
"system": "You are a publication formatting editor.",
"user": (
"Format revised paper into clean final markdown for publication "
"export.\n"
"Preserve content quality and readability.\n"
"CITATION FORMAT (CRITICAL): All citations MUST remain in [cite_key] bracket "
"format, e.g. [smith2024transformer]. Do NOT convert to author-year "
"format like [Smith et al., 2024]. The [cite_key] format is required "
"for downstream LaTeX \\cite{{}} generation.\n"
"Input paper:\n{revised}"
),
"max_tokens": 16384,
},
}
================================================
FILE: researchclaw/quality.py
================================================
"""Content quality assessment — template detection and metrics.
Detects placeholder/template content in LLM-generated text and provides
quality metrics for pipeline outputs.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
logger = logging.getLogger(__name__)
# Regex patterns that indicate placeholder/template text in LLM-generated
# documents.  Each entry is (pattern, human-readable description); patterns
# are applied per-line to stripped line content by detect_template_content()
# and compute_template_ratio() below.
_TEMPLATE_PATTERNS: list[tuple[str, str]] = [
    (
        # e.g. "Template Abstract" / "template introduction" headers
        r"(?i)template\s+(abstract|introduction|method|methodology|conclusion|discussion|results|related\s+work)",
        "Template section header",
    ),
    # Bracketed fill-in markers such as [INSERT ...], [TODO: ...], [PLACEHOLDER ...]
    (r"(?i)\[INSERT\s+.*?\]", "Insert placeholder"),
    (r"(?i)\[TODO\s*:?\s*.*?\]", "TODO placeholder"),
    (r"(?i)\[PLACEHOLDER\s*:?\s*.*?\]", "Explicit placeholder"),
    (r"(?i)lorem\s+ipsum", "Lorem ipsum filler"),
    (
        # "This section will describe..." — outline text left in the draft
        r"(?i)this\s+section\s+will\s+(describe|discuss|present|outline|explain)",
        "Future-tense placeholder",
    ),
    (
        r"(?i)we\s+will\s+(describe|discuss|present|outline|explain)\s+in\s+this\s+section",
        "Future-tense placeholder",
    ),
    (
        r"(?i)add\s+(your|the)\s+(content|text|description)\s+here",
        "Add content placeholder",
    ),
    (r"(?i)replace\s+this\s+(text|content|section)", "Replace placeholder"),
    # A heading line consisting only of "Section <number>"
    (r"(?i)^#+\s*section\s+\d+\s*$", "Generic section header"),
    (
        r"(?i)your\s+(abstract|introduction|method|results)\s+goes?\s+here",
        "Content placeholder",
    ),
    (r"(?i)sample\s+(abstract|introduction|text|content)", "Sample content marker"),
]
@dataclass(frozen=True)
class TemplateMatch:
    """A single template/placeholder detection."""

    # Human-readable description of the matched pattern (the second element
    # of the corresponding _TEMPLATE_PATTERNS entry).
    pattern_desc: str
    # 1-based line number in the scanned text where the match occurred.
    line_number: int
    # The matched text itself, truncated to at most 100 characters.
    excerpt: str
@dataclass(frozen=True)
class QualityReport:
    """Quality assessment for a text document.

    Captures document size, any template/placeholder matches found, and
    the estimated fraction of template content.
    """

    total_lines: int
    total_chars: int
    template_matches: tuple[TemplateMatch, ...] = ()
    template_ratio: float = 0.0

    @property
    def has_template_content(self) -> bool:
        """True when at least one template/placeholder match was found."""
        return bool(self.template_matches)

    @property
    def match_count(self) -> int:
        """Number of template/placeholder matches in the document."""
        return len(self.template_matches)

    def to_dict(self) -> dict[str, object]:
        """Serialize the report to a JSON-friendly dictionary.

        The template ratio is rounded to four decimal places; matches are
        flattened into plain dicts with pattern/line/excerpt keys.
        """
        rows: list[dict[str, object]] = []
        for match in self.template_matches:
            rows.append(
                {
                    "pattern": match.pattern_desc,
                    "line": match.line_number,
                    "excerpt": match.excerpt,
                }
            )
        return {
            "total_lines": self.total_lines,
            "total_chars": self.total_chars,
            "template_matches": rows,
            "template_ratio": round(self.template_ratio, 4),
            "has_template_content": self.has_template_content,
            "match_count": self.match_count,
        }
def detect_template_content(text: str) -> list[TemplateMatch]:
    """Scan text for template/placeholder patterns.

    Splits the text into lines, skips blank ones, and records a
    TemplateMatch (pattern description, 1-based line number, excerpt
    capped at 100 characters) for every regex hit, in document order.
    """
    found: list[TemplateMatch] = []
    for idx, raw_line in enumerate(text.split("\n"), start=1):
        content = raw_line.strip()
        if not content:
            continue
        for regex, description in _TEMPLATE_PATTERNS:
            found.extend(
                TemplateMatch(
                    pattern_desc=description,
                    line_number=idx,
                    excerpt=hit.group(0)[:100],
                )
                for hit in re.finditer(regex, content)
            )
    return found
def compute_template_ratio(text: str) -> float:
    """Estimate what fraction of the text is template/placeholder content.

    Returns 0.0 (fully original) through 1.0 (fully template).  Heuristic:
    characters on non-blank lines that match any template pattern, divided
    by characters on all non-blank lines.
    """
    if not text.strip():
        return 0.0
    non_blank = [ln.strip() for ln in text.split("\n") if ln.strip()]
    total_chars = sum(len(ln) for ln in non_blank)
    if total_chars == 0:
        return 0.0
    template_chars = sum(
        len(ln)
        for ln in non_blank
        if any(re.search(pattern, ln) for pattern, _ in _TEMPLATE_PATTERNS)
    )
    return min(template_chars / total_chars, 1.0)
def assess_quality(text: str) -> QualityReport:
    """Run the full quality assessment over a text document.

    Combines template detection and the template-ratio heuristic into a
    single QualityReport, logging a debug summary before returning it.
    """
    matches = detect_template_content(text)
    report = QualityReport(
        total_lines=len(text.split("\n")),
        total_chars=len(text),
        template_matches=tuple(matches),
        template_ratio=compute_template_ratio(text),
    )
    logger.debug(
        "quality assessed lines=%d chars=%d matches=%d ratio=%.4f",
        report.total_lines,
        report.total_chars,
        report.match_count,
        report.template_ratio,
    )
    return report
def check_strict_quality(text: str, *, threshold: float = 0.05) -> tuple[bool, str]:
    """Check whether text passes the strict quality gate.

    Returns (passed, message).  The gate fails when the estimated
    template ratio exceeds ``threshold``; the failure message includes
    up to five example matches.
    """
    report = assess_quality(text)
    if report.template_ratio <= threshold:
        return True, f"Quality check passed: template_ratio={report.template_ratio:.2%}"
    examples = "; ".join(
        f"L{match.line_number}: {match.excerpt}"
        for match in report.template_matches[:5]
    )
    failure = (
        f"Template content detected: ratio={report.template_ratio:.2%}, "
        f"{report.match_count} matches. Examples: {examples}"
    )
    return False, failure
================================================
FILE: researchclaw/report.py
================================================
"""Generate human-readable run reports from pipeline artifacts."""
# pyright: basic
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
def generate_report(run_dir: Path) -> str:
    """Generate a Markdown report from a pipeline run directory.

    Args:
        run_dir: Path to the run artifacts directory (e.g., artifacts/rc-xxx/)

    Returns:
        Markdown string with the report content.

    Raises:
        FileNotFoundError: If run_dir doesn't exist.
        ValueError: If run_dir has no pipeline_summary.json.
    """
    if not run_dir.exists():
        raise FileNotFoundError(f"Run directory not found: {run_dir}")
    summary_path = run_dir / "pipeline_summary.json"
    if not summary_path.exists():
        raise ValueError(f"No pipeline_summary.json found in {run_dir}")
    loaded = json.loads(summary_path.read_text(encoding="utf-8"))
    # Tolerate a summary file whose top level is not a JSON object.
    summary = loaded if isinstance(loaded, dict) else {}
    parts = [
        _header(summary, run_dir),
        _paper_section(run_dir),
        _experiment_section(run_dir),
        _citation_section(run_dir),
        _warnings_section(summary),
    ]
    # Empty sections (e.g. no warnings) are dropped from the output.
    return "\n\n".join(part for part in parts if part)
def _header(summary: dict[str, Any], run_dir: Path) -> str:
run_id = summary.get("run_id", "unknown")
stages_done = summary.get("stages_done", 0)
stages_total = summary.get("stages_executed", 0)
status = summary.get("final_status", "unknown")
generated = summary.get("generated", "unknown")
status_icon = "✅" if status == "done" else "❌" if status == "failed" else "⚠️"
lines = [
"# ResearchClaw Run Report",
"",
f"**Run ID**: {run_id}",
f"**Date**: {generated}",
f"**Status**: {status_icon} {status} ({stages_done}/{stages_total} stages done)",
f"**Artifacts**: `{run_dir}`",
]
return "\n".join(lines)
def _paper_section(run_dir: Path) -> str:
lines = ["## Paper"]
draft_path = run_dir / "stage-17" / "paper_draft.md"
if draft_path.exists():
text = draft_path.read_text(encoding="utf-8")
word_count = len(text.split())
lines.append(
f"- Draft: `{draft_path.relative_to(run_dir)}` (~{word_count} words)"
)
else:
lines.append("- Draft: not generated")
final_path = run_dir / "stage-22" / "paper_final.md"
if final_path.exists():
lines.append(f"- Final: `{final_path.relative_to(run_dir)}`")
tex_path = run_dir / "stage-22" / "paper.tex"
if tex_path.exists():
lines.append(f"- LaTeX: `{tex_path.relative_to(run_dir)}`")
rev_path = run_dir / "stage-19" / "paper_revised.md"
if rev_path.exists():
lines.append(f"- Revised: `{rev_path.relative_to(run_dir)}`")
return "\n".join(lines)
def _experiment_section(run_dir: Path) -> str:
lines = ["## Experiments"]
code_path = run_dir / "stage-10" / "experiment_code.py"
if code_path.exists():
lines.append(f"- Code: `{code_path.relative_to(run_dir)}`")
results_path = run_dir / "stage-12" / "experiment_results.json"
if results_path.exists():
try:
loaded = json.loads(results_path.read_text(encoding="utf-8"))
if isinstance(loaded, dict):
data = loaded
runs_default: list[Any] = []
iterations = data.get("iterations", data.get("runs", runs_default))
if isinstance(iterations, list):
lines.append(f"- Runs: {len(iterations)} iterations")
best = data.get("best_metric") or data.get("best_result")
if best is not None:
lines.append(f"- Best metric: {best}")
except (json.JSONDecodeError, TypeError):
lines.append("- Results: present (parse error)")
else:
lines.append("- Results: not available")
# BUG-215: Also search stage-14* versioned dirs when stage-14/ is missing.
analysis_path = run_dir / "stage-14" / "analysis.md"
if not analysis_path.exists():
for _s14 in sorted(run_dir.glob("stage-14*"), reverse=True):
_alt = _s14 / "analysis.md"
if _alt.exists():
analysis_path = _alt
break
if analysis_path.exists():
lines.append(f"- Analysis: `{analysis_path.relative_to(run_dir)}`")
return "\n".join(lines)
def _citation_section(run_dir: Path) -> str:
lines = ["## Citations"]
bib_path = run_dir / "stage-22" / "references.bib"
if not bib_path.exists():
bib_path = run_dir / "stage-04" / "references.bib"
if bib_path.exists():
text = bib_path.read_text(encoding="utf-8")
entries = re.findall(r"@\w+\{", text)
lines.append(f"- References: {len(entries)} BibTeX entries")
else:
lines.append("- References: not available")
verify_path = run_dir / "stage-23" / "verification_report.json"
if verify_path.exists():
try:
loaded = json.loads(verify_path.read_text(encoding="utf-8"))
vdata = loaded if isinstance(loaded, dict) else {}
total = int(vdata.get("total_references", 0))
verified = int(vdata.get("verified_count", 0))
suspicious = int(vdata.get("suspicious_count", 0))
hallucinated = int(vdata.get("hallucinated_count", 0))
pct = f"{verified / total * 100:.1f}%" if total > 0 else "N/A"
lines.append(f"- Verified: {verified}/{total} ({pct})")
if suspicious:
lines.append(f"- Suspicious: {suspicious}")
if hallucinated:
lines.append(f"- Hallucinated: {hallucinated}")
except (json.JSONDecodeError, TypeError, ZeroDivisionError):
lines.append("- Verification: present (parse error)")
else:
lines.append("- Verification: not run")
return "\n".join(lines)
def _warnings_section(summary: dict[str, Any]) -> str:
warnings: list[str] = []
stages_failed = summary.get("stages_failed", 0)
if stages_failed:
warnings.append(f"- ⚠️ {stages_failed} stage(s) failed during execution")
content_metrics = summary.get("content_metrics", {})
if isinstance(content_metrics, dict):
template_ratio = content_metrics.get("template_ratio")
if isinstance(template_ratio, (int, float)) and template_ratio > 0.1:
warnings.append(
f"- ⚠️ Template content detected: {template_ratio:.1%} of paper may be template text"
)
degraded = content_metrics.get("degraded_sources", [])
if isinstance(degraded, list) and degraded:
warnings.append(f"- ⚠️ Degraded sources: {', '.join(degraded)}")
if not warnings:
return ""
return "## Warnings\n" + "\n".join(warnings)
def print_report(run_dir: Path) -> None:
    """Generate the Markdown report for *run_dir* and write it to stdout."""
    report_text = generate_report(run_dir)
    print(report_text)
def write_report(run_dir: Path, output_path: Path) -> None:
    """Generate the Markdown report for *run_dir* and save it to *output_path*."""
    output_path.write_text(generate_report(run_dir), encoding="utf-8")
================================================
FILE: researchclaw/server/__init__.py
================================================
"""ResearchClaw Web server package."""
================================================
FILE: researchclaw/server/app.py
================================================
"""FastAPI application factory."""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from researchclaw.config import RCConfig
from researchclaw.server.middleware.auth import TokenAuthMiddleware
from researchclaw.server.websocket.manager import ConnectionManager
from researchclaw.server.websocket.events import Event, EventType
logger = logging.getLogger(__name__)
# Shared application state accessible by routes
_app_state: dict[str, Any] = {}
def create_app(
    config: RCConfig,
    *,
    dashboard_only: bool = False,
    monitor_dir: str | None = None,
) -> FastAPI:
    """Create and configure the FastAPI application.

    Wires up CORS, optional bearer-token auth, the REST routers, the
    /ws/events WebSocket stream, the static frontend mount, and the
    startup background tasks (heartbeat + optional dashboard loop).

    Args:
        config: ResearchClaw configuration.
        dashboard_only: If True, only mount dashboard routes.
        monitor_dir: Specific run directory to monitor.

    Returns:
        The fully configured FastAPI instance.
    """
    app = FastAPI(
        title="ResearchClaw",
        description="Autonomous Research Pipeline — Web Interface",
        version="0.5.0",
    )
    # Store config in shared state so route modules can reach it via _app_state.
    _app_state["config"] = config
    _app_state["monitor_dir"] = monitor_dir
    # --- CORS ---
    app.add_middleware(
        CORSMiddleware,
        allow_origins=list(config.server.cors_origins),
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # --- Token auth (only installed when a token is configured) ---
    if config.server.auth_token:
        app.add_middleware(TokenAuthMiddleware, token=config.server.auth_token)
    # --- WebSocket manager shared by routes and background broadcasters ---
    event_manager = ConnectionManager()
    _app_state["event_manager"] = event_manager
    # --- Health endpoint ---
    @app.get("/api/health")
    async def health() -> dict[str, Any]:
        # Liveness probe; also reports the live WebSocket client count.
        return {
            "status": "ok",
            "version": "0.5.0",
            "active_connections": event_manager.active_count,
        }

    @app.get("/api/config")
    async def config_summary() -> dict[str, Any]:
        # Small, non-sensitive slice of the configuration for the frontend.
        return {
            "project": config.project.name,
            "topic": config.research.topic,
            "mode": config.experiment.mode,
            "server": {
                "voice_enabled": config.server.voice_enabled,
                "dashboard_enabled": config.dashboard.enabled,
            },
        }
    # --- Routes ---
    from researchclaw.server.routes.pipeline import router as pipeline_router
    from researchclaw.server.routes.projects import router as projects_router
    app.include_router(pipeline_router)
    app.include_router(projects_router)
    if not dashboard_only:
        # Chat (and optionally voice) are only mounted in full-server mode.
        from researchclaw.server.routes.chat import router as chat_router, set_chat_manager
        set_chat_manager(event_manager)
        app.include_router(chat_router)
        if config.server.voice_enabled:
            from researchclaw.server.routes.voice import router as voice_router
            app.include_router(voice_router)
    # --- WebSocket events endpoint ---
    from fastapi import WebSocket, WebSocketDisconnect
    import uuid

    @app.websocket("/ws/events")
    async def events_ws(websocket: WebSocket) -> None:
        """Real-time event stream for dashboard."""
        client_id = f"evt-{uuid.uuid4().hex[:8]}"
        await event_manager.connect(websocket, client_id)
        try:
            while True:
                # Keep connection alive; client can send pings
                await websocket.receive_text()
        except WebSocketDisconnect:
            event_manager.disconnect(client_id)
    # --- Static files (frontend) ---
    # Three .parent hops from researchclaw/server/app.py reach the repo root.
    frontend_dir = Path(__file__).resolve().parent.parent.parent / "frontend"
    if frontend_dir.is_dir():
        app.mount("/static", StaticFiles(directory=str(frontend_dir)), name="static")
        # Serve index.html at root
        from fastapi.responses import FileResponse

        @app.get("/")
        async def index() -> FileResponse:
            return FileResponse(str(frontend_dir / "index.html"))
    # --- Background tasks ---
    @app.on_event("startup")
    async def startup() -> None:
        # NOTE(review): on_event("startup") is FastAPI's legacy hook; the
        # lifespan API is its successor — confirm before migrating.
        asyncio.create_task(event_manager.heartbeat_loop(interval=15.0))
        if config.dashboard.enabled:
            from researchclaw.dashboard.broadcaster import start_dashboard_loop
            asyncio.create_task(
                start_dashboard_loop(
                    event_manager,
                    interval=config.dashboard.refresh_interval_sec,
                    monitor_dir=monitor_dir,
                )
            )
        logger.info("ResearchClaw Web server started")
    return app
================================================
FILE: researchclaw/server/dialog/__init__.py
================================================
"""Dialog / conversational research modules."""
================================================
FILE: researchclaw/server/dialog/intents.py
================================================
"""Intent classification for conversational research."""
from __future__ import annotations
import re
from enum import Enum
from typing import Any
class Intent(str, Enum):
    """Research chat intents.

    str-valued so intent names serialize directly into JSON payloads.
    """

    TOPIC_SELECTION = "topic_selection"  # brainstorm / choose a research topic
    START_PIPELINE = "start_pipeline"    # launch a pipeline run
    CHECK_STATUS = "check_status"        # ask about run progress / current stage
    MODIFY_CONFIG = "modify_config"      # change settings or hyperparameters
    DISCUSS_RESULTS = "discuss_results"  # ask about metrics / outcomes
    EDIT_PAPER = "edit_paper"            # edit or review the paper draft
    GENERAL_CHAT = "general_chat"        # fallback for everything else
    HELP = "help"                        # usage / capability help
# Keyword patterns for fast classification.
# Scanned in order by classify_intent(); the first matching pattern wins,
# which is why the narrow HELP pattern precedes the broader ones.  Each
# regex mixes English keywords with Chinese equivalents.
_INTENT_PATTERNS: list[tuple[Intent, re.Pattern[str]]] = [
    (Intent.HELP, re.compile(
        r"(?:^\s*help\s*$|\bhow\s+to\b|\busage\b|帮助|怎么用)", re.IGNORECASE
    )),
    (Intent.START_PIPELINE, re.compile(
        r"(?:\b(?:start|run|begin|launch)\b|开始|启动|跑|运行)",
        re.IGNORECASE,
    )),
    (Intent.CHECK_STATUS, re.compile(
        r"(?:\b(?:status|progress|stage|current)\b|阶段|进度|到哪|第几|哪一步)", re.IGNORECASE
    )),
    (Intent.TOPIC_SELECTION, re.compile(
        r"(?:\b(?:topic|idea|direction)\b|research\s+direction|研究方向|选题|研究主题|想法)",
        re.IGNORECASE,
    )),
    (Intent.MODIFY_CONFIG, re.compile(
        r"(?:\b(?:config|setting|parameter|batch|epoch)\b|learning\s+rate|学习率|修改|设置)",
        re.IGNORECASE,
    )),
    (Intent.DISCUSS_RESULTS, re.compile(
        r"(?:\b(?:results?|metrics?|accuracy|loss|performance)\b|结果|指标|效果|怎么样)",
        re.IGNORECASE,
    )),
    (Intent.EDIT_PAPER, re.compile(
        r"(?:\b(?:paper|abstract|introduction|draft)\b|论文|摘要|改一下|写)",
        re.IGNORECASE,
    )),
]
def classify_intent(message: str) -> tuple[Intent, float]:
    """Classify user intent from message text.

    Returns (intent, confidence) with confidence in [0, 1].  Regex keyword
    matching keeps this fast; an LLM classifier could replace it later.
    """
    normalized = message.strip().lower()
    if not normalized:
        # Nothing to classify — lowest possible confidence.
        return Intent.GENERAL_CHAT, 0.0
    matched = next(
        (intent for intent, pattern in _INTENT_PATTERNS if pattern.search(normalized)),
        None,
    )
    if matched is not None:
        return matched, 0.8
    return Intent.GENERAL_CHAT, 0.5
================================================
FILE: researchclaw/server/dialog/router.py
================================================
"""Dialog router — routes messages to appropriate handlers."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
from researchclaw.server.dialog.intents import Intent, classify_intent
from researchclaw.server.dialog.session import ChatSession, SessionManager
logger = logging.getLogger(__name__)
# Module-wide registry: one ChatSession per client id, shared by every
# call to route_message().
_session_manager = SessionManager()
async def route_message(raw_message: str, client_id: str) -> str:
    """Route an incoming chat message and return the assistant response.

    The payload may be plain text or a JSON object carrying the text under
    a "message" (or legacy "text") key.  Anything else falls back to the
    raw string.

    Args:
        raw_message: Raw WebSocket payload from the client.
        client_id: Stable identifier used to look up the chat session.

    Returns:
        The handler's textual response (also recorded in session history).
    """
    # Parse message (could be plain text or JSON)
    try:
        msg_data = json.loads(raw_message)
    except (json.JSONDecodeError, TypeError):
        msg_data = None
    # json.loads can legally yield non-dict values (str, list, number, null);
    # only a dict envelope carries a message field.  Guarding here fixes an
    # uncaught AttributeError on payloads such as '"hello"'.
    if isinstance(msg_data, dict):
        text = msg_data.get("message", msg_data.get("text", raw_message))
    else:
        text = raw_message
    session = _session_manager.get_or_create(client_id)
    session.add_message("user", text)
    intent, confidence = classify_intent(text)
    logger.debug("Intent: %s (%.2f) for: %s", intent.value, confidence, text[:50])
    handler = _HANDLERS.get(intent, _handle_general)
    response = await handler(text, session)
    session.add_message("assistant", response)
    return response
async def _handle_help(text: str, session: ChatSession) -> str:
return (
"I can help you with:\n"
"- **Select a research topic**: describe your area of interest\n"
"- **Start a pipeline run**: say 'start experiment' or 'run pipeline'\n"
"- **Check progress**: ask 'what stage are we at?'\n"
"- **View results**: ask about metrics, accuracy, or results\n"
"- **Modify settings**: change learning rate, epochs, etc.\n"
"- **Edit paper**: suggest changes to abstract, introduction, etc.\n\n"
"Just type naturally — I'll figure out what you need!"
)
async def _handle_status(text: str, session: ChatSession) -> str:
    """Summarize the active pipeline run, or the most recent one."""
    from researchclaw.dashboard.collector import DashboardCollector

    runs = DashboardCollector().collect_all()
    if not runs:
        return "No pipeline runs found. Start one with 'start pipeline'."
    # Report the first active run when one exists.
    for run in runs:
        if run.is_active:
            return (
                f"**Active run**: {run.run_id}\n"
                f"- Stage: {run.current_stage}/23 ({run.current_stage_name})\n"
                f"- Status: {run.status}\n"
                f"- Topic: {run.topic or '(not set)'}"
            )
    latest = runs[0]
    return (
        f"**Latest run**: {latest.run_id}\n"
        f"- Stage: {latest.current_stage}/23\n"
        f"- Status: {latest.status}\n"
        f"- Stages completed: {len(latest.stages_completed)}"
    )
async def _handle_start(text: str, session: ChatSession) -> str:
return (
"To start a pipeline run, use the dashboard or API:\n"
"```\n"
"POST /api/pipeline/start\n"
'{"topic": "your research topic", "auto_approve": true}\n'
"```\n"
"Or run from CLI: `researchclaw run -c config.yaml`\n\n"
"Would you like me to help you set up the configuration?"
)
async def _handle_topic(text: str, session: ChatSession) -> str:
return (
"Let me help you find a research direction!\n\n"
"Please tell me:\n"
"1. Your research **domain** (e.g., CV, NLP, RL, AI4Science)\n"
"2. Any **specific interests** (e.g., robustness, efficiency, fairness)\n"
"3. Your **target venue** (e.g., NeurIPS, ICML, ICLR)\n\n"
"I'll suggest novel, timely research angles based on recent trends."
)
async def _handle_config(text: str, session: ChatSession) -> str:
return (
"You can modify the configuration through:\n"
"1. Edit `config.yaml` directly\n"
"2. Use the wizard: `researchclaw wizard`\n"
"3. Pass overrides when starting: "
'`POST /api/pipeline/start {"config_overrides": {...}}`\n\n'
"What setting would you like to change?"
)
async def _handle_results(text: str, session: ChatSession) -> str:
    """Report scalar metrics from the most recent run, if any exist."""
    from researchclaw.dashboard.collector import DashboardCollector

    runs = DashboardCollector().collect_all()
    if not runs:
        return "No runs found yet. Start a pipeline first."
    latest = runs[0]
    if not latest.metrics:
        return f"Run {latest.run_id} has no metrics yet (stage {latest.current_stage}/23)."
    scalar_lines = [
        f"- {name}: {value}"
        for name, value in latest.metrics.items()
        if isinstance(value, (int, float))
    ]
    if not scalar_lines:
        # No scalar metrics to itemize; dump the raw mapping instead.
        return f"Metrics: {latest.metrics}"
    return "\n".join([f"**Results for {latest.run_id}**:\n", *scalar_lines])
async def _handle_paper(text: str, session: ChatSession) -> str:
return (
"Paper editing is available after Stage 17 (Paper Draft).\n\n"
"I can help with:\n"
"- Review and suggest improvements to the abstract\n"
"- Check the introduction structure\n"
"- Verify experiment descriptions match actual results\n"
"- Improve related work coverage\n\n"
"Which section would you like to work on?"
)
async def _handle_general(text: str, session: ChatSession) -> str:
return (
"I'm your ResearchClaw assistant. I can help with:\n"
"- Selecting research topics\n"
"- Running experiments\n"
"- Monitoring progress\n"
"- Analyzing results\n"
"- Editing papers\n\n"
"What would you like to do?"
)
# Intent → coroutine handler dispatch table used by route_message();
# unmatched intents fall back to _handle_general there.
_HANDLERS = {
    Intent.HELP: _handle_help,
    Intent.CHECK_STATUS: _handle_status,
    Intent.START_PIPELINE: _handle_start,
    Intent.TOPIC_SELECTION: _handle_topic,
    Intent.MODIFY_CONFIG: _handle_config,
    Intent.DISCUSS_RESULTS: _handle_results,
    Intent.EDIT_PAPER: _handle_paper,
    Intent.GENERAL_CHAT: _handle_general,
}
================================================
FILE: researchclaw/server/dialog/session.py
================================================
"""Conversation session management."""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class ChatMessage:
    """One utterance in a chat transcript."""

    role: str  # "user" or "assistant"
    content: str
    timestamp: float = field(default_factory=time.time)  # creation time, unix seconds

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain JSON-friendly dict."""
        return dict(role=self.role, content=self.content, timestamp=self.timestamp)
@dataclass
class ChatSession:
    """Per-client chat session state."""

    client_id: str
    history: list[ChatMessage] = field(default_factory=list)
    current_project: str = ""
    current_run: str = ""
    created_at: float = field(default_factory=time.time)
    MAX_HISTORY: int = 50  # hard cap on retained messages

    def add_message(self, role: str, content: str) -> ChatMessage:
        """Append a message, trimming history to the newest MAX_HISTORY."""
        entry = ChatMessage(role=role, content=content)
        self.history.append(entry)
        # Trim to prevent unbounded growth
        if len(self.history) > self.MAX_HISTORY:
            self.history = self.history[-self.MAX_HISTORY:]
        return entry

    def get_context(self, last_n: int = 10) -> list[dict[str, str]]:
        """Return up to *last_n* most recent messages as role/content dicts."""
        recent = self.history[-last_n:]
        return [{"role": msg.role, "content": msg.content} for msg in recent]

    def to_dict(self) -> dict[str, Any]:
        """Serialize the full session (including history) for persistence."""
        payload: dict[str, Any] = {
            "client_id": self.client_id,
            "current_project": self.current_project,
            "current_run": self.current_run,
            "history": [msg.to_dict() for msg in self.history],
            "created_at": self.created_at,
        }
        return payload
class SessionManager:
    """Registry of in-memory chat sessions with optional JSON persistence."""

    def __init__(self, persist_dir: str = ".researchclaw/sessions") -> None:
        self._sessions: dict[str, ChatSession] = {}
        self._persist_dir = Path(persist_dir)

    def get_or_create(self, client_id: str) -> ChatSession:
        """Return the session for *client_id*, creating it on first use."""
        session = self._sessions.get(client_id)
        if session is None:
            session = ChatSession(client_id=client_id)
            self._sessions[client_id] = session
        return session

    def remove(self, client_id: str) -> None:
        """Forget a session (no-op when the id is unknown)."""
        self._sessions.pop(client_id, None)

    def save(self, client_id: str) -> None:
        """Best-effort persistence of one session to <persist_dir>/<id>.json."""
        session = self._sessions.get(client_id)
        if session is None:
            return
        self._persist_dir.mkdir(parents=True, exist_ok=True)
        target = self._persist_dir / f"{client_id}.json"
        try:
            with target.open("w", encoding="utf-8") as fh:
                json.dump(session.to_dict(), fh, ensure_ascii=False, indent=2)
        except Exception:
            # Persistence is best-effort; never let it break the chat flow.
            logger.debug("Failed to persist session %s", client_id)

    def load(self, client_id: str) -> ChatSession | None:
        """Restore a session from disk; None when absent or unreadable."""
        source = self._persist_dir / f"{client_id}.json"
        if not source.exists():
            return None
        try:
            with source.open() as fh:
                data = json.load(fh)
            restored = ChatSession(
                client_id=data["client_id"],
                current_project=data.get("current_project", ""),
                current_run=data.get("current_run", ""),
                created_at=data.get("created_at", time.time()),
            )
            restored.history.extend(
                ChatMessage(
                    role=item["role"],
                    content=item["content"],
                    timestamp=item.get("timestamp", 0),
                )
                for item in data.get("history", [])
            )
            self._sessions[client_id] = restored
            return restored
        except Exception:
            logger.debug("Failed to load session %s", client_id)
            return None
================================================
FILE: researchclaw/server/middleware/__init__.py
================================================
"""Server middleware modules."""
================================================
FILE: researchclaw/server/middleware/auth.py
================================================
"""Basic token authentication middleware."""
from __future__ import annotations
from typing import Callable, Awaitable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
class TokenAuthMiddleware(BaseHTTPMiddleware):
    """Optional bearer-token authentication.

    If *token* is empty, all requests are allowed (no-op).
    """

    # Paths that never require auth (health probe + OpenAPI docs UI)
    EXEMPT_PATHS = frozenset({"/api/health", "/docs", "/openapi.json"})

    def __init__(self, app: object, token: str = "") -> None:
        # token: shared secret compared verbatim against the client value.
        super().__init__(app)  # type: ignore[arg-type]
        self._token = token

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Reject requests whose token does not match; otherwise forward.

        HTTP requests present the token as an "Authorization: Bearer <...>"
        header; /ws paths are expected to carry it as a ?token= query
        parameter instead.
        NOTE(review): Starlette's BaseHTTPMiddleware presumably dispatches
        only HTTP requests, so the /ws branch may never run for actual
        WebSocket upgrades — confirm how WebSockets get authenticated.
        """
        # No-op when token is unset
        if not self._token:
            return await call_next(request)
        # Skip auth for exempt paths and static files
        path = request.url.path
        if path in self.EXEMPT_PATHS or path.startswith("/static"):
            return await call_next(request)
        # WebSocket connections carry token as query param
        if path.startswith("/ws"):
            token = request.query_params.get("token", "")
        else:
            auth_header = request.headers.get("authorization", "")
            # removeprefix leaves the raw header untouched when there is no
            # "Bearer " prefix, so bare-token headers are also accepted.
            token = auth_header.removeprefix("Bearer ").strip()
        if token != self._token:
            return JSONResponse(
                {"detail": "Unauthorized"}, status_code=401
            )
        return await call_next(request)
================================================
FILE: researchclaw/server/routes/__init__.py
================================================
"""API route modules."""
================================================
FILE: researchclaw/server/routes/chat.py
================================================
"""Chat WebSocket endpoint for conversational research."""
from __future__ import annotations
import logging
import uuid
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from researchclaw.server.websocket.events import Event, EventType
from researchclaw.server.websocket.manager import ConnectionManager
logger = logging.getLogger(__name__)
router = APIRouter(tags=["chat"])
# Global connection manager (initialized by app.py via set_chat_manager());
# None until the application factory wires it in.
_chat_manager: ConnectionManager | None = None
def set_chat_manager(manager: ConnectionManager) -> None:
    """Install the shared connection manager (called once by app.py)."""
    global _chat_manager
    _chat_manager = manager
def get_chat_manager() -> ConnectionManager:
    """Return the shared connection manager.

    Raises:
        RuntimeError: if set_chat_manager() has not been called yet.
    """
    manager = _chat_manager
    if manager is None:
        raise RuntimeError("Chat manager not initialized")
    return manager
@router.websocket("/ws/chat")
async def chat_websocket(websocket: WebSocket) -> None:
    """WebSocket endpoint for conversational research chat.

    Each connection gets a short random client id.  Every received frame is
    routed through the dialog router and answered with a CHAT_RESPONSE
    event, or an ERROR event when the handler raises — the socket stays
    open either way until the client disconnects.
    """
    manager = get_chat_manager()
    client_id = str(uuid.uuid4())[:8]
    await manager.connect(websocket, client_id)
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                # Imported lazily per message (NOTE(review): presumably to
                # avoid an import cycle with the dialog package — confirm).
                from researchclaw.server.dialog.router import route_message
                response = await route_message(raw, client_id)
                await manager.send_to(
                    client_id,
                    Event(
                        type=EventType.CHAT_RESPONSE,
                        data={"message": response, "client_id": client_id},
                    ),
                )
            except Exception as exc:
                # Report the failure to the client instead of dropping the socket.
                logger.exception("Chat error for %s", client_id)
                await manager.send_to(
                    client_id,
                    Event(
                        type=EventType.ERROR,
                        data={"error": str(exc), "client_id": client_id},
                    ),
                )
    except WebSocketDisconnect:
        manager.disconnect(client_id)
================================================
FILE: researchclaw/server/routes/pipeline.py
================================================
"""Pipeline control API routes."""
from __future__ import annotations
import asyncio
import json
import logging
from pathlib import Path
from typing import Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
import re as _re
# Expected run-id shape: rc-<YYYYMMDD>-<HHMMSS>-<lowercase hex>.  Anchored
# on both ends so ids cannot smuggle path separators into artifact paths.
_RUN_ID_RE = _re.compile(r"^rc-\d{8}-\d{6}-[a-f0-9]+$")
def _validated_run_dir(run_id: str) -> Path:
    """Validate *run_id* and map it to its artifacts/ directory.

    Raises:
        HTTPException: 400 when the id does not match the rc-<date>-<time>-<hex>
            pattern, or when the resolved path escapes artifacts/.
    """
    if _RUN_ID_RE.match(run_id) is None:
        raise HTTPException(status_code=400, detail=f"Invalid run_id format: {run_id}")
    artifacts_root = Path("artifacts")
    run_dir = artifacts_root / run_id
    # Defense in depth: the regex already forbids path separators, but still
    # verify the resolved location sits inside artifacts/.
    if not run_dir.resolve().is_relative_to(artifacts_root.resolve()):
        raise HTTPException(status_code=400, detail=f"Invalid run_id: {run_id}")
    return run_dir
router = APIRouter(prefix="/api", tags=["pipeline"])
class PipelineStartRequest(BaseModel):
    """Request body for starting a pipeline run."""

    # Research topic override; the server config's topic is used when None.
    topic: str | None = None
    # NOTE(review): accepted but not applied by start_pipeline() in this
    # module — confirm whether overrides are consumed elsewhere.
    config_overrides: dict[str, Any] | None = None
    # Forwarded to execute_pipeline(auto_approve_gates=...).
    auto_approve: bool = True
class PipelineStartResponse(BaseModel):
    """Response after starting a pipeline."""

    run_id: str      # generated rc-<timestamp>-<topic hash> identifier
    status: str      # always "running" at start time
    output_dir: str  # artifacts/<run_id> path where stage outputs land
# In-memory tracking of the active run (single-tenant MVP).
# _active_run is the mutable status dict shared with the background task
# (None when idle); _run_task is the asyncio handle /pipeline/stop cancels.
_active_run: dict[str, Any] | None = None
_run_task: asyncio.Task[Any] | None = None
def _get_app_state() -> dict[str, Any]:
    """Get shared application state (set by app.py)."""
    from researchclaw.server import app as _server_app
    return _server_app._app_state
@router.post("/pipeline/start", response_model=PipelineStartResponse)
async def start_pipeline(req: PipelineStartRequest) -> PipelineStartResponse:
    """Start a new pipeline run.

    Builds a run id from the UTC timestamp plus a hash of the topic,
    creates artifacts/<run_id>/, and launches the pipeline in a background
    task (executing the blocking runner in the default thread pool).  Only
    one run may be active at a time.

    Raises:
        HTTPException: 409 when a run is already in progress.
    """
    global _active_run, _run_task
    if _active_run and _active_run.get("status") == "running":
        raise HTTPException(status_code=409, detail="A pipeline is already running")
    state = _get_app_state()
    config = state["config"]
    if req.topic:
        # Configs are dataclasses; build copies with the topic swapped in.
        import dataclasses
        new_research = dataclasses.replace(config.research, topic=req.topic)
        config = dataclasses.replace(config, research=new_research)
    import hashlib
    from datetime import datetime, timezone
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    topic_hash = hashlib.sha256(config.research.topic.encode()).hexdigest()[:6]
    run_id = f"rc-{ts}-{topic_hash}"
    run_dir = _validated_run_dir(run_id)
    run_dir.mkdir(parents=True, exist_ok=True)
    _active_run = {
        "run_id": run_id,
        "status": "running",
        "output_dir": str(run_dir),
        "topic": config.research.topic,
    }
    async def _run_in_background() -> None:
        # Executes the blocking pipeline off the event loop and mirrors its
        # outcome into the module-level _active_run dict.
        global _active_run
        try:
            from researchclaw.adapters import AdapterBundle
            from researchclaw.pipeline.runner import execute_pipeline
            kb_root = Path(config.knowledge_base.root) if config.knowledge_base.root else None
            if kb_root:
                kb_root.mkdir(parents=True, exist_ok=True)
            # Fix: get_event_loop() inside a coroutine is deprecated since
            # Python 3.10; get_running_loop() is the correct call here.
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(
                None,
                lambda: execute_pipeline(
                    run_dir=run_dir,
                    run_id=run_id,
                    config=config,
                    adapters=AdapterBundle(),
                    auto_approve_gates=req.auto_approve,
                    skip_noncritical=True,
                    kb_root=kb_root,
                ),
            )
            done = sum(1 for r in results if r.status.value == "done")
            failed = sum(1 for r in results if r.status.value == "failed")
            if _active_run:
                _active_run["status"] = "completed" if failed == 0 else "failed"
                _active_run["stages_done"] = done
                _active_run["stages_failed"] = failed
        except Exception as exc:
            logger.exception("Pipeline run failed")
            if _active_run:
                _active_run["status"] = "failed"
                _active_run["error"] = str(exc)
    _run_task = asyncio.create_task(_run_in_background())
    return PipelineStartResponse(
        run_id=run_id,
        status="running",
        output_dir=str(run_dir),
    )
@router.post("/pipeline/stop")
async def stop_pipeline() -> dict[str, str]:
    """Cancel the currently running pipeline task.

    Raises:
        HTTPException: 404 when no run has been started.
    """
    global _active_run, _run_task
    if not _run_task or not _active_run:
        raise HTTPException(status_code=404, detail="No pipeline is running")
    _run_task.cancel()
    _active_run["status"] = "stopped"
    return {"status": "stopped"}
@router.get("/pipeline/status")
async def pipeline_status() -> dict[str, Any]:
    """Return the live run-status dict, or {"status": "idle"} when none."""
    return _active_run if _active_run else {"status": "idle"}
@router.get("/pipeline/stages")
async def pipeline_stages() -> dict[str, Any]:
    """Get the 23-stage pipeline definition."""
    from researchclaw.pipeline.stages import Stage
    stages = [
        {
            "number": int(member),
            "name": member.name,
            # Fall back to a prettified enum name when no label is declared.
            "label": getattr(member, "label", member.name.replace("_", " ").title()),
            "phase": getattr(member, "phase", ""),
        }
        for member in Stage
    ]
    return {"stages": stages}
@router.get("/runs")
async def list_runs() -> dict[str, Any]:
    """List historical pipeline runs found under artifacts/."""
    artifacts = Path("artifacts")
    runs: list[dict[str, Any]] = []
    if artifacts.exists():
        for entry in sorted(artifacts.iterdir(), reverse=True):
            if not entry.is_dir() or not entry.name.startswith("rc-"):
                continue
            info: dict[str, Any] = {"run_id": entry.name, "path": str(entry)}
            # Attach checkpoint contents when readable; skip silently otherwise.
            ckpt = entry / "checkpoint.json"
            if ckpt.exists():
                try:
                    with ckpt.open() as f:
                        info["checkpoint"] = json.load(f)
                except Exception:
                    pass
            runs.append(info)
    # Cap the response at the 50 most recent runs.
    return {"runs": runs[:50]}
@router.get("/runs/{run_id}")
async def get_run(run_id: str) -> dict[str, Any]:
    """Get details for a specific run.

    Raises:
        HTTPException: 400 for malformed ids, 404 for unknown runs.
    """
    run_dir = _validated_run_dir(run_id)
    if not run_dir.exists():
        raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
    info: dict[str, Any] = {"run_id": run_id, "path": str(run_dir)}
    ckpt = run_dir / "checkpoint.json"
    if ckpt.exists():
        try:
            with ckpt.open() as f:
                info["checkpoint"] = json.load(f)
        except Exception:
            pass
    # Completed stage directories (stage-*)
    info["stages_completed"] = sorted(
        child.name
        for child in run_dir.iterdir()
        if child.is_dir() and child.name.startswith("stage-")
    )
    # Flag which paper artifacts exist anywhere in the tree (has_md / has_tex / has_pdf).
    for pattern in ("paper.md", "paper.tex", "paper.pdf"):
        if any(run_dir.rglob(pattern)):
            info[f"has_{pattern.split('.')[1]}"] = True
    return info
@router.get("/runs/{run_id}/metrics")
async def get_run_metrics(run_id: str) -> dict[str, Any]:
    """Return the contents of the run's results.json (empty when missing).

    Raises:
        HTTPException: 400 for malformed ids, 404 for unknown runs.
    """
    run_dir = _validated_run_dir(run_id)
    if not run_dir.exists():
        raise HTTPException(status_code=404, detail=f"Run not found: {run_id}")
    metrics: dict[str, Any] = {}
    results_file = run_dir / "results.json"
    if results_file.exists():
        try:
            with results_file.open() as f:
                metrics = json.load(f)
        except Exception:
            # Corrupt/unreadable results are reported as empty metrics.
            pass
    return {"run_id": run_id, "metrics": metrics}
================================================
FILE: researchclaw/server/routes/projects.py
================================================
"""Project listing / status API routes."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from fastapi import APIRouter
router = APIRouter(prefix="/api", tags=["projects"])
@router.get("/projects")
async def list_projects() -> dict[str, Any]:
    """List all project directories (artifacts/rc-*)."""
    artifacts = Path("artifacts")
    projects: list[dict[str, Any]] = []
    if artifacts.exists():
        for entry in sorted(artifacts.iterdir(), reverse=True):
            if not entry.is_dir() or not entry.name.startswith("rc-"):
                continue
            proj: dict[str, Any] = {"id": entry.name, "path": str(entry)}
            ckpt = entry / "checkpoint.json"
            if not ckpt.exists():
                proj["status"] = "no_checkpoint"
            else:
                try:
                    with ckpt.open() as f:
                        ckpt_data = json.load(f)
                    proj["current_stage"] = ckpt_data.get("stage")
                    proj["status"] = ckpt_data.get("status", "unknown")
                except Exception:
                    # Unreadable checkpoint files degrade to "unknown".
                    proj["status"] = "unknown"
            projects.append(proj)
    return {"projects": projects}
================================================
FILE: researchclaw/server/routes/voice.py
================================================
"""Voice upload / transcription API routes."""
from __future__ import annotations
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, UploadFile, File
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/voice", tags=["voice"])
@router.post("/transcribe")
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = "zh",
) -> dict[str, Any]:
    """Transcribe uploaded audio using Whisper API.

    Args:
        file: Uploaded audio payload (read fully into memory below).
        language: Language hint passed through to the transcriber;
            defaults to Chinese ("zh").

    Raises:
        HTTPException: 501 when the optional voice extra is not installed;
            403 when voice is disabled in the server configuration.
    """
    try:
        # The voice stack is an optional extra; import lazily so the route
        # can report a clean 501 instead of failing at import time.
        from researchclaw.voice.transcriber import VoiceTranscriber
    except ImportError:
        raise HTTPException(
            status_code=501,
            detail="Voice dependencies not installed. Run: pip install researchclaw[voice]",
        )
    from researchclaw.server.app import _app_state
    config = _app_state.get("config")
    if not config or not config.server.voice_enabled:
        raise HTTPException(status_code=403, detail="Voice is not enabled in config")
    # The whole upload is buffered in memory before transcription.
    audio_bytes = await file.read()
    transcriber = VoiceTranscriber(config.server)
    text = await transcriber.transcribe(audio_bytes, language=language)
    return {"text": text, "language": language}
================================================
FILE: researchclaw/server/websocket/__init__.py
================================================
"""WebSocket modules."""
================================================
FILE: researchclaw/server/websocket/events.py
================================================
"""WebSocket event type definitions."""
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any
import json
import time
class EventType(str, Enum):
    """All WebSocket event types.

    str-valued so the names serialize directly into JSON frames.
    """

    # Lifecycle
    CONNECTED = "connected"    # sent once when a client registers
    HEARTBEAT = "heartbeat"    # periodic keep-alive ping
    ERROR = "error"            # handler/processing failure
    # Pipeline
    PIPELINE_STARTED = "pipeline_started"
    PIPELINE_COMPLETED = "pipeline_completed"
    STAGE_START = "stage_start"
    STAGE_COMPLETE = "stage_complete"
    STAGE_FAIL = "stage_fail"
    METRIC_UPDATE = "metric_update"
    LOG_LINE = "log_line"
    PAPER_READY = "paper_ready"
    # Chat
    CHAT_RESPONSE = "chat_response"
    CHAT_TYPING = "chat_typing"
    CHAT_SUGGESTION = "chat_suggestion"
    # System
    RUN_DISCOVERED = "run_discovered"
    RUN_STATUS_CHANGED = "run_status_changed"
@dataclass
class Event:
    """A WebSocket event."""

    type: EventType
    data: dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)  # creation time, unix seconds

    def to_json(self) -> str:
        """Serialize to a JSON string frame."""
        payload = {
            "type": self.type.value,
            "data": self.data,
            "timestamp": self.timestamp,
        }
        return json.dumps(payload)

    @classmethod
    def from_json(cls, raw: str) -> Event:
        """Deserialize from a JSON string frame."""
        parsed = json.loads(raw)
        return cls(
            type=EventType(parsed["type"]),
            data=parsed.get("data", {}),
            timestamp=parsed.get("timestamp", time.time()),
        )
================================================
FILE: researchclaw/server/websocket/manager.py
================================================
"""WebSocket connection manager."""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
from fastapi import WebSocket
from .events import Event, EventType
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manage WebSocket connections and broadcast events."""

    def __init__(self) -> None:
        # client_id -> live WebSocket
        self._connections: dict[str, WebSocket] = {}
        # Events published from sync code; drained inside heartbeat_loop().
        self._event_queue: asyncio.Queue[Event] = asyncio.Queue()

    @property
    def active_count(self) -> int:
        """Number of currently registered clients."""
        return len(self._connections)

    async def connect(self, websocket: WebSocket, client_id: str) -> None:
        """Accept and register a connection, then send a CONNECTED greeting."""
        await websocket.accept()
        self._connections[client_id] = websocket
        logger.info("WebSocket connected: %s (total: %d)", client_id, self.active_count)
        greeting = Event(type=EventType.CONNECTED, data={"client_id": client_id})
        await self._send(websocket, greeting)

    def disconnect(self, client_id: str) -> None:
        """Remove a disconnected client (no-op if unknown)."""
        self._connections.pop(client_id, None)
        logger.info("WebSocket disconnected: %s (total: %d)", client_id, self.active_count)

    async def broadcast(self, event: Event) -> None:
        """Send event to all connected clients, pruning any that fail."""
        failed: list[str] = []
        for cid, ws in self._connections.items():
            try:
                await self._send(ws, event)
            except Exception:
                failed.append(cid)
        for cid in failed:
            self.disconnect(cid)

    async def send_to(self, client_id: str, event: Event) -> None:
        """Send event to one client; drop the client when the send fails."""
        ws = self._connections.get(client_id)
        if ws:
            try:
                await self._send(ws, event)
            except Exception:
                self.disconnect(client_id)

    async def _send(self, ws: WebSocket, event: Event) -> None:
        # Single choke point so every path emits identically serialized JSON.
        await ws.send_text(event.to_json())

    def publish(self, event: Event) -> None:
        """Non-async publish for use from sync code (thread-safe queue)."""
        try:
            self._event_queue.put_nowait(event)
        except asyncio.QueueFull:
            logger.warning("Event queue full, dropping event: %s", event.type)

    async def drain_queue(self) -> None:
        """Broadcast every event currently waiting in the queue."""
        while not self._event_queue.empty():
            await self.broadcast(self._event_queue.get_nowait())

    async def heartbeat_loop(self, interval: float = 15.0) -> None:
        """Periodically ping all clients and flush the publish queue."""
        while True:
            await asyncio.sleep(interval)
            beat = Event(
                type=EventType.HEARTBEAT,
                data={"active_clients": self.active_count},
            )
            await self.broadcast(beat)
            await self.drain_queue()
================================================
FILE: researchclaw/servers/__init__.py
================================================
"""Multi-server resource scheduling for AutoResearchClaw."""
from researchclaw.servers.registry import ServerRegistry
from researchclaw.servers.monitor import ServerMonitor
from researchclaw.servers.dispatcher import TaskDispatcher
__all__ = ["ServerRegistry", "ServerMonitor", "TaskDispatcher"]
================================================
FILE: researchclaw/servers/cloud_executor.py
================================================
"""Cloud executor: stub for AWS/GCP/Azure GPU instance management."""
from __future__ import annotations
import logging
from typing import Any
from researchclaw.servers.registry import ServerEntry
logger = logging.getLogger(__name__)


class CloudExecutor:
    """Manage cloud GPU instances for experiment execution.

    This is a stub implementation. Actual cloud provider APIs (boto3,
    google-cloud, azure-mgmt) are imported lazily to avoid hard dependencies.
    """

    def __init__(self, server: ServerEntry) -> None:
        """Wrap a cloud-type server entry.

        Raises:
            ValueError: if *server* is not of type ``cloud``.
        """
        if server.server_type != "cloud":
            raise ValueError(f"Server {server.name} is not a cloud server")
        self.server = server
        self.provider = server.cloud_provider

    async def launch_instance(self) -> dict[str, Any]:
        """Launch a cloud GPU instance (stubbed: no SDK call is made)."""
        entry = self.server
        logger.info(
            "Launching %s instance (%s) for %s",
            self.provider,
            entry.cloud_instance_type,
            entry.name,
        )
        # Stub: actual implementation would call provider SDK
        details = {
            "provider": self.provider,
            "instance_type": entry.cloud_instance_type,
            "status": "stub_launched",
            "instance_id": f"stub-{entry.name}",
            "cost_per_hour": entry.cost_per_hour,
        }
        return details

    async def terminate_instance(self, instance_id: str) -> None:
        """Terminate a cloud instance (stubbed: logs only)."""
        logger.info("Terminating instance %s on %s", instance_id, self.provider)

    async def get_instance_status(self, instance_id: str) -> dict[str, Any]:
        """Check instance status (stubbed: always unknown)."""
        return {"instance_id": instance_id, "status": "stub_unknown"}
================================================
FILE: researchclaw/servers/dispatcher.py
================================================
"""Task dispatcher: route experiment tasks to the best available server."""
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Any
from researchclaw.servers.registry import ServerEntry, ServerRegistry
from researchclaw.servers.monitor import ServerMonitor
from researchclaw.servers.ssh_executor import SSHExecutor
from researchclaw.servers.slurm_executor import SlurmExecutor
logger = logging.getLogger(__name__)


class TaskDispatcher:
    """Dispatch experiment tasks to the best available server."""

    def __init__(
        self,
        registry: ServerRegistry,
        monitor: ServerMonitor,
        prefer_free: bool = True,
        failover: bool = True,
    ) -> None:
        """Create a dispatcher.

        Args:
            registry: source of candidate servers.
            monitor: server health monitor (held for scheduling use).
            prefer_free: prefer servers with ``cost_per_hour == 0``.
            failover: retry once on a different server when execution fails.
        """
        self.registry = registry
        self.monitor = monitor
        self.prefer_free = prefer_free
        self.failover = failover
        # task_id -> {"status", "server", "task", "error"/"result"/"job_id"}
        self._tasks: dict[str, dict[str, Any]] = {}
        # Names of servers currently assigned a task.
        self._busy_servers: set[str] = set()

    async def dispatch(self, task: dict[str, Any]) -> str:
        """Dispatch a task to the best available server.

        Args:
            task: dict with keys: command, local_dir, requirements (optional)

        Returns:
            task_id for tracking
        """
        task_id = uuid.uuid4().hex[:12]
        requirements = task.get("requirements", {})
        # Find best server
        server = self.registry.get_best_match(
            requirements=requirements,
            prefer_free=self.prefer_free,
        )
        if server is None:
            # No capacity: record as queued so callers can retry later.
            self._tasks[task_id] = {"status": "queued", "task": task, "error": "No matching server"}
            logger.warning("No server available for task %s, queued", task_id)
            return task_id
        self._tasks[task_id] = {
            "status": "dispatched",
            "server": server.name,
            "task": task,
        }
        self._busy_servers.add(server.name)
        logger.info("Dispatched task %s to %s (%s)", task_id, server.name, server.server_type)
        return task_id

    async def _run_on(
        self,
        server: ServerEntry,
        task: dict[str, Any],
        task_id: str,
        info: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute *task* on *server* with the type-appropriate executor.

        Slurm servers get an async job submission (status becomes "running");
        everything else runs synchronously over SSH.  Updates *info* in place.
        """
        remote_dir = f"/tmp/researchclaw_{task_id}"
        if server.server_type == "slurm":
            executor = SlurmExecutor(server)
            job_id = await executor.submit_job(
                command=task["command"],
                remote_dir=remote_dir,
                resources=task.get("requirements"),
            )
            info["status"] = "running"
            info["job_id"] = job_id
            return {"success": True, "job_id": job_id}
        # Default: SSH executor.
        result = await SSHExecutor(server).run_experiment(
            remote_dir=remote_dir,
            command=task["command"],
            timeout=task.get("timeout", 3600),
        )
        info["status"] = "completed" if result["success"] else "failed"
        info["result"] = result
        return result

    async def execute_task(self, task_id: str) -> dict[str, Any]:
        """Execute a dispatched task on its assigned server.

        On failure, optionally fails over once (non-recursively) to the next
        best server before reporting the original error.
        """
        info = self._tasks.get(task_id)
        if not info or info["status"] != "dispatched":
            return {"success": False, "error": "Task not dispatched"}
        server = self.registry.get(info["server"])
        task = info["task"]
        try:
            return await self._run_on(server, task, task_id, info)
        except Exception as exc:
            logger.error("Task %s failed: %s", task_id, exc)
            info["status"] = "failed"
            info["error"] = str(exc)
            if self.failover:
                alt = self.registry.get_best_match(
                    requirements=task.get("requirements"),
                    prefer_free=self.prefer_free,
                )
                # Single failover attempt, never back to the failed server.
                if alt and alt.name != server.name:
                    logger.info("Failing over task %s to %s", task_id, alt.name)
                    info["server"] = alt.name
                    info["status"] = "dispatched"
                    try:
                        # BUG FIX: the original called run_experiment() on the
                        # ServerEntry itself (no such method) with a different
                        # remote_dir; route through the shared executor helper.
                        return await self._run_on(alt, task, task_id, info)
                    except Exception as alt_exc:
                        logger.error("Failover also failed: %s", alt_exc)
                        info["status"] = "failed"
                        info["error"] = str(alt_exc)
            return {"success": False, "error": str(exc)}
        finally:
            self._busy_servers.discard(server.name)

    def get_task_status(self, task_id: str) -> dict[str, Any]:
        """Get the status of a task (``unknown`` for an unrecognized id)."""
        info = self._tasks.get(task_id)
        if not info:
            return {"task_id": task_id, "status": "unknown"}
        return {
            "task_id": task_id,
            "status": info["status"],
            "server": info.get("server"),
            "error": info.get("error"),
        }
================================================
FILE: researchclaw/servers/monitor.py
================================================
"""Server monitor: check health and resource usage of registered servers."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from researchclaw.servers.registry import ServerEntry, ServerRegistry
logger = logging.getLogger(__name__)


class ServerMonitor:
    """Monitor health and resource usage of registered servers."""

    def __init__(self, registry: ServerRegistry) -> None:
        self.registry = registry
        # Most recently observed status dict per server name.
        self._status_cache: dict[str, dict[str, Any]] = {}

    async def check_status(self, server: ServerEntry) -> dict[str, Any]:
        """Check a single server's status via SSH (nvidia-smi, free, uptime)."""
        probe = (
            "nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total "
            "--format=csv,noheader,nounits 2>/dev/null; echo '---'; "
            "free -m | head -2; echo '---'; uptime"
        )
        try:
            raw = await _ssh_command(server.host, probe)
            status = _parse_status_output(raw, server)
            status["reachable"] = True
        except Exception as exc:
            logger.warning("Cannot reach server %s: %s", server.name, exc)
            status = {"reachable": False, "error": str(exc)}
        self._status_cache[server.name] = status
        return status

    async def check_all(self) -> dict[str, dict[str, Any]]:
        """Check all servers concurrently."""
        servers = self.registry.list_all()
        results = await asyncio.gather(
            *(self.check_status(s) for s in servers),
            return_exceptions=True,
        )
        statuses: dict[str, dict[str, Any]] = {}
        for srv, res in zip(servers, results):
            if isinstance(res, Exception):
                statuses[srv.name] = {"reachable": False, "error": str(res)}
            else:
                statuses[srv.name] = res
        return statuses

    def get_cached(self, name: str) -> dict[str, Any] | None:
        """Return cached status for a server, or None if never checked."""
        return self._status_cache.get(name)

    def get_gpu_usage(self, server: ServerEntry) -> dict[str, Any]:
        """Return cached GPU usage for a server (sync convenience)."""
        return self._status_cache.get(server.name, {}).get("gpu", {})
async def _ssh_command(host: str, command: str) -> str:
    """Run *command* on *host* over SSH and return its stdout as text.

    Raises:
        RuntimeError: if ssh exits with a nonzero return code.
    """
    argv = [
        "ssh",
        "-o", "ConnectTimeout=5",
        "-o", "StrictHostKeyChecking=no",
        host,
        command,
    ]
    proc = await asyncio.create_subprocess_exec(
        *argv,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    out, err = await proc.communicate()
    if proc.returncode != 0:
        raise RuntimeError(f"SSH command failed (rc={proc.returncode}): {err.decode().strip()}")
    return out.decode()
def _parse_status_output(raw: str, server: ServerEntry) -> dict[str, Any]:
"""Parse combined nvidia-smi + free + uptime output."""
sections = raw.split("---")
status: dict[str, Any] = {"server": server.name, "host": server.host}
# GPU section
if len(sections) >= 1:
gpu_lines = [l.strip() for l in sections[0].strip().splitlines() if l.strip()]
gpus = []
for line in gpu_lines:
parts = [p.strip() for p in line.split(",")]
if len(parts) >= 3:
gpus.append({
"utilization_pct": int(parts[0]),
"memory_used_mb": int(parts[1]),
"memory_total_mb": int(parts[2]),
})
status["gpu"] = {"count": len(gpus), "devices": gpus}
# Memory section
if len(sections) >= 2:
mem_lines = sections[1].strip().splitlines()
if len(mem_lines) >= 2:
parts = mem_lines[1].split()
if len(parts) >= 4:
status["memory"] = {
"total_mb": int(parts[1]),
"used_mb": int(parts[2]),
"free_mb": int(parts[3]),
}
# Uptime section
if len(sections) >= 3:
status["uptime"] = sections[2].strip()
return status
================================================
FILE: researchclaw/servers/registry.py
================================================
"""Server registry: manage available compute servers."""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
class ServerEntry:
    """A compute server that can run experiments."""

    # Field names mirrored by to_dict()/from_dict().
    _FIELDS = (
        "name", "host", "server_type", "gpu", "vram_gb", "priority",
        "scheduler", "cloud_provider", "cloud_instance_type", "cost_per_hour",
    )

    def __init__(
        self,
        name: str,
        host: str,
        server_type: str = "ssh",
        gpu: str = "",
        vram_gb: int = 0,
        priority: int = 1,
        scheduler: str = "",
        cloud_provider: str = "",
        cloud_instance_type: str = "",
        cost_per_hour: float = 0.0,
    ) -> None:
        self.name = name
        self.host = host
        self.server_type = server_type  # ssh | slurm | cloud
        self.gpu = gpu
        self.vram_gb = vram_gb
        self.priority = priority
        self.scheduler = scheduler  # slurm | pbs | lsf
        self.cloud_provider = cloud_provider  # aws | gcp | azure
        self.cloud_instance_type = cloud_instance_type
        self.cost_per_hour = cost_per_hour

    def to_dict(self) -> dict[str, Any]:
        """Serialize all fields into a plain dict."""
        return {field: getattr(self, field) for field in self._FIELDS}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ServerEntry:
        """Build an entry from a dict; only ``name`` is required."""
        options = {
            "host": data.get("host", ""),
            "server_type": data.get("server_type", "ssh"),
            "gpu": data.get("gpu", ""),
            "vram_gb": int(data.get("vram_gb", 0)),
            "priority": int(data.get("priority", 1)),
            "scheduler": data.get("scheduler", ""),
            "cloud_provider": data.get("cloud_provider", ""),
            "cloud_instance_type": data.get("cloud_instance_type", ""),
            "cost_per_hour": float(data.get("cost_per_hour", 0.0)),
        }
        return cls(data["name"], **options)
class ServerRegistry:
    """Registry of available compute servers."""

    def __init__(self, servers: list[ServerEntry] | None = None) -> None:
        # name -> ServerEntry; later duplicates replace earlier ones.
        self._servers: dict[str, ServerEntry] = {s.name: s for s in (servers or [])}

    def add(self, server: ServerEntry) -> None:
        """Register a new server (replaces any entry with the same name)."""
        self._servers[server.name] = server
        logger.info("Registered server: %s (%s)", server.name, server.host)

    def remove(self, name: str) -> None:
        """Remove a server from the registry.

        Raises:
            KeyError: if no server with that name is registered.
        """
        if name not in self._servers:
            raise KeyError(f"Unknown server: {name}")
        del self._servers[name]

    def get(self, name: str) -> ServerEntry:
        """Look up a server by name.

        Raises:
            KeyError: if no server with that name is registered.
        """
        entry = self._servers.get(name)
        if entry is None:
            raise KeyError(f"Unknown server: {name}")
        return entry

    def list_all(self) -> list[ServerEntry]:
        """Return all registered servers sorted by priority (lower = higher priority)."""
        return sorted(self._servers.values(), key=lambda entry: entry.priority)

    def get_available(self, exclude: set[str] | None = None) -> list[ServerEntry]:
        """Return servers not in the exclude set, sorted by priority."""
        skip = exclude if exclude is not None else set()
        return [entry for entry in self.list_all() if entry.name not in skip]

    def get_best_match(
        self,
        requirements: dict[str, Any] | None = None,
        prefer_free: bool = True,
    ) -> ServerEntry | None:
        """Find the best server matching resource requirements.

        Args:
            requirements: dict with optional keys: min_vram_gb, server_type, gpu
            prefer_free: prefer servers with cost_per_hour == 0
        """
        reqs = requirements or {}
        # Build one predicate per requested constraint (defaults bind the
        # current values so the lambdas are self-contained).
        checks: list = []
        min_vram = reqs.get("min_vram_gb", 0)
        if min_vram:
            checks.append(lambda c, v=min_vram: c.vram_gb >= v)
        wanted_type = reqs.get("server_type")
        if wanted_type:
            checks.append(lambda c, t=wanted_type: c.server_type == t)
        wanted_gpu = reqs.get("gpu")
        if wanted_gpu:
            checks.append(lambda c, g=wanted_gpu.lower(): g in c.gpu.lower())
        candidates = [c for c in self.list_all() if all(chk(c) for chk in checks)]
        if not candidates:
            return None
        # Free servers first (when requested), then ascending priority.
        if prefer_free:
            candidates.sort(key=lambda c: (c.cost_per_hour > 0, c.priority))
        return candidates[0]

    @property
    def count(self) -> int:
        """Number of registered servers."""
        return len(self._servers)
================================================
FILE: researchclaw/servers/slurm_executor.py
================================================
"""Slurm HPC executor: submit, monitor, and cancel batch jobs."""
from __future__ import annotations
import asyncio
import logging
import textwrap
from typing import Any
from researchclaw.servers.registry import ServerEntry
logger = logging.getLogger(__name__)


class SlurmExecutor:
    """Submit and manage Slurm batch jobs via SSH."""

    def __init__(self, server: ServerEntry) -> None:
        """Wrap a slurm-type server entry.

        Raises:
            ValueError: if *server* is not of type ``slurm``.
        """
        if server.server_type != "slurm":
            raise ValueError(f"Server {server.name} is not a slurm server")
        self.server = server
        self.host = server.host

    def _generate_sbatch_script(
        self,
        command: str,
        job_name: str = "researchclaw",
        resources: dict[str, Any] | None = None,
    ) -> str:
        """Generate an sbatch submission script.

        Args:
            command: shell command to run as the job body.
            job_name: Slurm job name.
            resources: optional keys gpus (default 1), mem_gb (16),
                time ("01:00:00"), partition ("" = cluster default).
        """
        res = resources or {}
        gpus = res.get("gpus", 1)
        mem = res.get("mem_gb", 16)
        time_limit = res.get("time", "01:00:00")
        partition = res.get("partition", "")
        lines = [
            "#!/bin/bash",
            f"#SBATCH --job-name={job_name}",
            f"#SBATCH --gres=gpu:{gpus}",
            f"#SBATCH --mem={mem}G",
            f"#SBATCH --time={time_limit}",
            "#SBATCH --output=slurm-%j.out",
            "#SBATCH --error=slurm-%j.err",
        ]
        if partition:
            lines.append(f"#SBATCH --partition={partition}")
        lines.append("")
        lines.append(command)
        return "\n".join(lines)

    async def submit_job(
        self,
        command: str,
        remote_dir: str,
        job_name: str = "researchclaw",
        resources: dict[str, Any] | None = None,
    ) -> str:
        """Submit a Slurm job and return the job ID.

        Raises:
            RuntimeError: if sbatch fails or its output cannot be parsed.
        """
        script = self._generate_sbatch_script(command, job_name, resources)
        # Write script and submit via SSH
        import shlex as _shlex
        # BUG FIX: the original put "&& sbatch _job.sh" on its own line after
        # the heredoc terminator, which is a shell syntax error (a command
        # cannot start with &&), so no job could ever be submitted.  Also
        # create the remote directory, since callers do not.
        ssh_cmd = (
            f"mkdir -p {_shlex.quote(remote_dir)} && cd {_shlex.quote(remote_dir)} && "
            f"cat <<'EOFSCRIPT' > _job.sh\n{script}\nEOFSCRIPT\n"
            "sbatch _job.sh"
        )
        proc = await asyncio.create_subprocess_exec(
            "ssh", "-o", "ConnectTimeout=10", "-o", "StrictHostKeyChecking=no",
            self.host, ssh_cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"sbatch failed: {stderr.decode().strip()}")
        # Parse "Submitted batch job 12345"
        output = stdout.decode().strip()
        parts = output.split()
        if len(parts) >= 4 and parts[-1].isdigit():
            job_id = parts[-1]
            logger.info("Submitted Slurm job %s on %s", job_id, self.server.name)
            return job_id
        raise RuntimeError(f"Could not parse sbatch output: {output}")

    async def check_job(self, job_id: str) -> dict[str, Any]:
        """Check job status via squeue (running jobs) or sacct (finished)."""
        proc = await asyncio.create_subprocess_exec(
            "ssh", "-o", "ConnectTimeout=10", "-o", "StrictHostKeyChecking=no",
            self.host,
            f"squeue -j {job_id} -h -o '%T' 2>/dev/null || sacct -j {job_id} -n -o State -P 2>/dev/null",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, _ = await proc.communicate()
        # First line of output is the state; treat empty output as UNKNOWN.
        state = stdout.decode().strip().split("\n")[0].strip() if stdout else "UNKNOWN"
        return {"job_id": job_id, "state": state or "UNKNOWN"}

    async def cancel_job(self, job_id: str) -> None:
        """Cancel a running job (best-effort; scancel errors are ignored)."""
        proc = await asyncio.create_subprocess_exec(
            "ssh", "-o", "ConnectTimeout=10", "-o", "StrictHostKeyChecking=no",
            self.host, f"scancel {job_id}",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        await proc.communicate()
        logger.info("Cancelled Slurm job %s on %s", job_id, self.server.name)
================================================
FILE: researchclaw/servers/ssh_executor.py
================================================
"""SSH remote executor: upload code, run experiments, download results."""
from __future__ import annotations
import asyncio
import logging
import shlex
from pathlib import Path
from typing import Any
from researchclaw.servers.registry import ServerEntry
logger = logging.getLogger(__name__)


class SSHExecutor:
    """Execute experiments on remote servers via SSH/rsync."""

    # Flags shared by every ssh invocation (fail fast, no host-key prompt).
    _SSH_FLAGS = ("-o", "ConnectTimeout=10", "-o", "StrictHostKeyChecking=no")
    _RSYNC_SSH = "ssh -o ConnectTimeout=10 -o StrictHostKeyChecking=no"

    def __init__(self, server: ServerEntry) -> None:
        self.server = server
        self.host = server.host

    async def _rsync(self, src: str, dst: str, delete: bool, action: str) -> None:
        """Run rsync between *src* and *dst*; raise RuntimeError on failure."""
        argv = ["rsync", "-az"]
        if delete:
            argv.append("--delete")
        argv += ["-e", self._RSYNC_SSH, src, dst]
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"rsync {action} failed: {stderr.decode().strip()}")

    async def upload_code(self, local_dir: Path, remote_dir: str) -> None:
        """Upload experiment code via rsync (mirrors local_dir, deleting extras)."""
        src = str(local_dir.resolve()) + "/"
        dst = f"{self.host}:{remote_dir}/"
        logger.info("Uploading %s -> %s", src, dst)
        await self._rsync(src, dst, delete=True, action="upload")

    async def run_experiment(
        self,
        remote_dir: str,
        command: str,
        timeout: int = 3600,
    ) -> dict[str, Any]:
        """Run an experiment command on the remote server.

        Returns a dict with success, stdout, stderr, returncode; on timeout
        the process is killed and a failure dict is returned instead.
        """
        full_cmd = f"cd {shlex.quote(remote_dir)} && {command}"
        logger.info("Running on %s: %s", self.host, full_cmd)
        proc = await asyncio.create_subprocess_exec(
            "ssh", *self._SSH_FLAGS, self.host, full_cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            return {"success": False, "error": f"Timeout after {timeout}s", "returncode": -1}
        return {
            "success": proc.returncode == 0,
            "stdout": stdout.decode(),
            "stderr": stderr.decode(),
            "returncode": proc.returncode,
        }

    async def download_results(self, remote_dir: str, local_dir: Path) -> None:
        """Download experiment results via rsync into local_dir (created if needed)."""
        local_dir.mkdir(parents=True, exist_ok=True)
        src = f"{self.host}:{remote_dir}/"
        dst = str(local_dir.resolve()) + "/"
        logger.info("Downloading %s -> %s", src, dst)
        await self._rsync(src, dst, delete=False, action="download")

    async def cleanup(self, remote_dir: str) -> None:
        """Remove the remote experiment directory (best-effort)."""
        logger.info("Cleaning up %s:%s", self.host, remote_dir)
        proc = await asyncio.create_subprocess_exec(
            "ssh", *self._SSH_FLAGS,
            self.host, f"rm -rf {shlex.quote(remote_dir)}",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        await proc.communicate()
================================================
FILE: researchclaw/skills/__init__.py
================================================
"""Dynamic skills library for AutoResearchClaw.
Provides a registry of reusable research/engineering/writing skills
that can be automatically matched to pipeline stages and injected
into LLM prompts.
"""
from researchclaw.skills.schema import Skill
from researchclaw.skills.registry import SkillRegistry
__all__ = ["Skill", "SkillRegistry"]
================================================
FILE: researchclaw/skills/builtin/__init__.py
================================================
================================================
FILE: researchclaw/skills/builtin/domain/cv-classification/SKILL.md
================================================
---
name: cv-classification
description: Best practices for image classification tasks. Use when working on CIFAR, ImageNet, or other classification benchmarks.
metadata:
category: domain
trigger-keywords: "classification,image,cifar,imagenet,resnet,vision,cnn,vit"
applicable-stages: "9,10"
priority: "3"
version: "1.0"
author: researchclaw
references: "He et al., Deep Residual Learning, CVPR 2016; Dosovitskiy et al., An Image is Worth 16x16 Words, ICLR 2021"
---
## Image Classification Best Practice
Architecture selection:
- Small scale (CIFAR-10/100): ResNet-18/34, WideResNet, Simple ViT
- Medium scale: ResNet-50, EfficientNet-B0/B1, DeiT-Small
- Large scale: ViT-B/16, ConvNeXt, Swin Transformer
Training recipe:
- Optimizer: AdamW (lr=1e-3 to 3e-4) or SGD (lr=0.1 with cosine decay)
- Weight decay: 0.01-0.1 for AdamW, 5e-4 for SGD
- Data augmentation: RandomCrop, RandomHorizontalFlip, Cutout/CutMix
- Warmup: 5-10 epochs linear warmup for transformers
- Batch size: 128-256 for CNNs, 512-1024 for ViTs (if memory allows)
Standard benchmarks:
- CIFAR-10: ~96% (ResNet-18), ~97% (WideResNet)
- CIFAR-100: ~80% (ResNet-18), ~84% (WideResNet)
- ImageNet: ~76% (ResNet-50), ~81% (ViT-B/16)
================================================
FILE: researchclaw/skills/builtin/domain/cv-detection/SKILL.md
================================================
---
name: cv-detection
description: Best practices for object detection tasks. Use when working on COCO, VOC, or detection architectures like YOLO and DETR.
metadata:
category: domain
trigger-keywords: "detection,object,bbox,yolo,coco,anchor,faster rcnn"
applicable-stages: "9,10"
priority: "5"
version: "1.0"
author: researchclaw
references: "Ren et al., Faster R-CNN, NeurIPS 2015; Carion et al., End-to-End Object Detection with Transformers, ECCV 2020"
---
## Object Detection Best Practice
Architecture families:
- One-stage: YOLO (v5/v8), SSD, RetinaNet, FCOS
- Two-stage: Faster R-CNN, Cascade R-CNN
- Transformer: DETR, DINO, RT-DETR
Training recipe:
- Use pre-trained backbone (ImageNet)
- Multi-scale training and testing
- IoU threshold: 0.5 for mAP50, 0.5:0.95 for mAP
- Use FPN for multi-scale feature extraction
- Focal loss for class imbalance in one-stage detectors
Standard benchmarks:
- COCO val2017: ~37 mAP (Faster R-CNN R50), ~51 mAP (DINO Swin-L)
- Pascal VOC: ~80 mAP50 (Faster R-CNN)
================================================
FILE: researchclaw/skills/builtin/domain/nlp-alignment/SKILL.md
================================================
---
name: nlp-alignment
description: Best practices for LLM alignment techniques including RLHF, DPO, and instruction tuning. Use when working on alignment or safety.
metadata:
category: domain
trigger-keywords: "alignment,rlhf,dpo,reward model,preference,instruction tuning,safety"
applicable-stages: "9,10"
priority: "4"
version: "1.0"
author: researchclaw
references: "Ouyang et al., Training language models to follow instructions, NeurIPS 2022; Rafailov et al., DPO, NeurIPS 2023"
---
## LLM Alignment Best Practice
Methods:
- RLHF: Train reward model → PPO fine-tuning (complex but powerful)
- DPO: Direct preference optimization (simpler, no reward model needed)
- GRPO: Group relative policy optimization
- SFT: Supervised fine-tuning as alignment baseline
Training recipe:
- Start with SFT on high-quality instruction data
- DPO: lr=5e-7, beta=0.1, batch_size=64
- PPO: lr=1e-6, clip=0.2, KL coeff=0.02
- Use reference model for KL penalty
- Evaluate on safety benchmarks (TruthfulQA, BBQ, etc.)
Common pitfalls:
- Reward hacking: model finds shortcuts to high reward
- Mode collapse: model generates repetitive outputs
- Catastrophic forgetting: loses general capabilities
================================================
FILE: researchclaw/skills/builtin/domain/nlp-pretraining/SKILL.md
================================================
---
name: nlp-pretraining
description: Best practices for language model pretraining and fine-tuning. Use when generating or reviewing NLP training code.
metadata:
category: domain
trigger-keywords: "language model,pretraining,fine-tuning,bert,gpt,llm,transformer,nlp,text"
applicable-stages: "9,10"
priority: "3"
version: "1.0"
author: researchclaw
references: "Devlin et al., BERT, NAACL 2019; Hu et al., LoRA, ICLR 2022"
---
## NLP Pretraining/Fine-tuning Best Practice
Fine-tuning recipe:
- Use pre-trained checkpoints (HuggingFace hub)
- AdamW optimizer, lr=2e-5 to 5e-5
- Linear warmup (6% of total steps) + linear decay
- Batch size: 16-32 (use gradient accumulation for larger effective batch)
- 3-5 epochs for classification, 1-2 for generation
- Weight decay: 0.01
Parameter-efficient methods:
- LoRA: r=8-64, alpha=16-128, apply to q/v projections
- Prefix tuning: 10-20 prefix tokens
- Adapters: bottleneck dimension 64-256
Evaluation:
- Classification: accuracy, F1 (macro for imbalanced)
- Generation: perplexity, BLEU/ROUGE, human evaluation
- Use multiple seeds and report mean +/- std
================================================
FILE: researchclaw/skills/builtin/domain/rl-policy-optimization/SKILL.md
================================================
---
name: rl-policy-optimization
description: Best practices for reinforcement learning policy optimization. Use when working on RL agents, PPO, SAC, or reward design.
metadata:
category: domain
trigger-keywords: "reinforcement learning,rl,policy,reward,agent,environment,ppo,sac"
applicable-stages: "9,10"
priority: "3"
version: "1.0"
author: researchclaw
references: "Schulman et al., Proximal Policy Optimization, 2017; Haarnoja et al., Soft Actor-Critic, ICML 2018"
---
## RL Policy Optimization Best Practice
Algorithm selection:
- Discrete actions: PPO, DQN, A2C
- Continuous actions: SAC, TD3, PPO
- Multi-agent: MAPPO, QMIX
- Offline: CQL, IQL, Decision Transformer
Training recipe:
- PPO: clip=0.2, lr=3e-4, gamma=0.99, GAE lambda=0.95
- SAC: lr=3e-4, tau=0.005, auto-tune alpha
- Use vectorized environments (e.g., gymnasium.vector)
- Normalize observations and rewards
- Log episode return, episode length, value loss, policy entropy
Evaluation:
- Report mean +/- std over 10+ evaluation episodes
- Use deterministic policy for evaluation
- Compare against random policy and simple baselines
- Report sample efficiency (return vs. env steps)
Common pitfalls:
- Reward shaping can introduce bias
- Seed sensitivity is HIGH — use 5+ seeds
- Hyperparameter sensitivity — do a small sweep
================================================
FILE: researchclaw/skills/builtin/experiment/experimental-design/SKILL.md
================================================
---
name: experimental-design
description: Best practices for designing reproducible ML experiments. Use when planning ablations, baselines, or controlled experiments.
metadata:
category: experiment
trigger-keywords: "experiment,ablation,baseline,control,hypothesis,reproducib"
applicable-stages: "9,10,12"
priority: "2"
version: "1.0"
author: researchclaw
references: "Bouthillier et al., Accounting for Variance in ML Benchmarks, MLSys 2021"
---
## Experimental Design Best Practice
1. ALWAYS include meaningful baselines (not just random):
- At least one classical method baseline
- At least one recent SOTA method baseline
- A simple-but-strong baseline (e.g., linear probe, k-NN)
2. Use MULTIPLE random seeds (minimum 3, ideally 5)
3. Report mean +/- std across seeds
4. Design ablations that isolate EACH key component:
- Remove one component at a time
- Each ablation must be meaningfully different from baseline
5. Control variables: change only ONE thing per comparison
6. Use standard splits (train/val/test) — never test on training data
7. Report wall-clock time and memory usage alongside accuracy
================================================
FILE: researchclaw/skills/builtin/experiment/meta-analysis/SKILL.md
================================================
---
name: meta-analysis
description: Statistical methods for combining results across multiple studies. Use when aggregating cross-study or cross-experiment results.
metadata:
category: experiment
trigger-keywords: "meta-analysis,effect size,pooled,cross-study,aggregat"
applicable-stages: "7,14"
priority: "5"
version: "1.0"
author: researchclaw
references: "Borenstein et al., Introduction to Meta-Analysis, 2009"
---
## Meta-Analysis Best Practice
When comparing results across studies or experiments:
1. Report effect sizes, not just p-values
2. Use standardized metrics for cross-study comparison
3. Account for heterogeneity (different setups, datasets, seeds)
4. Report confidence intervals alongside point estimates
5. Use forest plots to visualize cross-study comparisons
6. Identify and discuss outliers or inconsistent results
7. Consider publication bias when interpreting aggregate results
================================================
FILE: researchclaw/skills/builtin/experiment/systematic-review/SKILL.md
================================================
---
name: systematic-review
description: Structured methodology for comprehensive literature review following PRISMA guidelines. Use during literature search and screening stages.
metadata:
category: experiment
trigger-keywords: "literature,review,survey,related work,prior work"
applicable-stages: "3,4,5,6"
priority: "3"
version: "1.0"
author: researchclaw
references: "Page et al., The PRISMA 2020 statement, BMJ 2021"
---
## Systematic Review Best Practice
Follow PRISMA-like methodology for literature search:
1. Define clear inclusion/exclusion criteria BEFORE searching
2. Use multiple databases (Semantic Scholar, arXiv, OpenAlex)
3. Search with both broad and narrow queries
4. Screen by title/abstract first, then full text
5. Extract: method, dataset, metrics, key findings
6. Synthesize gaps and opportunities, not just summaries
7. Prioritize recent (last 2-3 years) high-citation papers
8. Include at least one seminal/foundational paper per sub-topic
================================================
FILE: researchclaw/skills/builtin/tooling/data-loading/SKILL.md
================================================
---
name: data-loading
description: Optimize data loading pipeline to prevent GPU starvation. Use when setting up DataLoader or data preprocessing.
metadata:
category: tooling
trigger-keywords: "data,loading,dataloader,dataset,preprocessing,augmentation"
applicable-stages: "10"
priority: "6"
version: "1.0"
author: researchclaw
references: "PyTorch Data Loading Tutorial, pytorch.org"
---
## Efficient Data Loading Best Practice
1. Use num_workers = min(8, os.cpu_count()) for DataLoader
2. Enable pin_memory=True when using GPU
3. Use persistent_workers=True to avoid re-spawning
4. Pre-compute and cache transformations when possible
5. For image data: use torchvision.transforms.v2 (faster)
6. For large datasets: consider memory-mapped files or WebDataset
7. Profile with torch.utils.bottleneck to find I/O bottlenecks
================================================
FILE: researchclaw/skills/builtin/tooling/distributed-training/SKILL.md
================================================
---
name: distributed-training
description: Multi-GPU and distributed training patterns with PyTorch DDP. Use when scaling training across GPUs.
metadata:
category: tooling
trigger-keywords: "distributed,multi-gpu,parallel,ddp,scale"
applicable-stages: "10,12"
priority: "7"
version: "1.0"
author: researchclaw
references: "PyTorch DDP Tutorial, pytorch.org; Goyal et al., Accurate Large Minibatch SGD, 2017"
---
## Distributed Training Best Practice
1. Use DistributedDataParallel (DDP) over DataParallel for multi-GPU
2. Initialize process group: dist.init_process_group(backend='nccl')
3. Use DistributedSampler for data sharding
4. Synchronize batch norm: nn.SyncBatchNorm.convert_sync_batchnorm()
5. Only save checkpoint on rank 0
6. Scale learning rate linearly with world size
7. Use gradient accumulation for effectively larger batch sizes
================================================
FILE: researchclaw/skills/builtin/tooling/mixed-precision/SKILL.md
================================================
---
name: mixed-precision
description: Use FP16/BF16 mixed precision to accelerate training and reduce memory. Use when optimizing GPU performance.
metadata:
category: tooling
trigger-keywords: "training,gpu,memory,speed,precision,fp16,bf16"
applicable-stages: "10,12"
priority: "5"
version: "1.0"
author: researchclaw
references: "Micikevicius et al., Mixed Precision Training, ICLR 2018"
code-template: |
scaler = torch.cuda.amp.GradScaler()
for batch in dataloader:
optimizer.zero_grad()
with torch.cuda.amp.autocast():
output = model(batch)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
---
## Mixed Precision Training Best Practice
Use torch.cuda.amp for automatic mixed precision:
- Wrap forward pass in torch.cuda.amp.autocast()
- Use GradScaler for loss scaling
- BF16 preferred over FP16 on Ampere+ GPUs (RTX 3xxx, A100, RTX 4xxx)
- Watch for NaN gradients — reduce learning rate if needed
- Do NOT use amp with custom CUDA kernels unless tested
================================================
FILE: researchclaw/skills/builtin/tooling/pytorch-training/SKILL.md
================================================
---
name: pytorch-training
description: Best practices for building robust PyTorch training loops. Use when generating or reviewing ML training code.
metadata:
category: tooling
trigger-keywords: "training,pytorch,torch,deep learning,neural network,model"
applicable-stages: "10,12"
priority: "3"
version: "1.0"
author: researchclaw
references: "PyTorch Performance Tuning Guide, pytorch.org"
code-template: |
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Reproducibility
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# Training loop
model.train()
for epoch in range(num_epochs):
for batch in train_loader:
optimizer.zero_grad(set_to_none=True)
loss = criterion(model(batch['input']), batch['target'])
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
---
## PyTorch Training Best Practice
1. Use torch.manual_seed() for reproducibility (set for torch, numpy, random)
2. Use DataLoader with num_workers>0 and pin_memory=True for GPU
3. Enable cudnn.benchmark=True for fixed input sizes
4. Use learning rate schedulers (CosineAnnealingLR or OneCycleLR)
5. Implement early stopping based on validation metric
6. Log metrics every epoch, save best model checkpoint
7. Use torch.no_grad() for evaluation
8. Clear gradients with optimizer.zero_grad(set_to_none=True) for efficiency
================================================
FILE: researchclaw/skills/loader.py
================================================
"""Skill file loader — supports YAML, JSON, and SKILL.md (agentskills.io)."""
from __future__ import annotations

import json
import logging
import re
from pathlib import Path

import yaml

from researchclaw.skills.schema import Skill
logger = logging.getLogger(__name__)
# ── SKILL.md loader ──────────────────────────────────────────────────
def load_skill_from_skillmd(path: Path) -> Skill | None:
    """Load a skill from a ``SKILL.md`` file (agentskills.io format).

    Expected layout::

        ---
        name: kebab-case-id
        description: one-liner
        metadata:
          category: domain
          trigger-keywords: "kw1,kw2"
        ---
        Markdown body here ...

    Args:
        path: Path to the SKILL.md file.

    Returns:
        Parsed :class:`Skill`, or *None* on failure (unreadable file,
        missing/invalid frontmatter, or missing ``name``).
    """
    try:
        text = path.read_text(encoding="utf-8")
    except Exception as exc:
        logger.warning("Failed to read SKILL.md at %s: %s", path, exc)
        return None
    # BUGFIX: split only on delimiter *lines*. The previous
    # text.split("---", 2) matched any "---" substring — including runs
    # of dashes inside frontmatter values or the markdown body — which
    # truncated the frontmatter and broke YAML parsing.
    parts = re.split(r"^---\s*$", text, maxsplit=2, flags=re.MULTILINE)
    if len(parts) < 3:
        logger.warning("SKILL.md missing frontmatter delimiters: %s", path)
        return None
    try:
        header = yaml.safe_load(parts[1])
    except Exception as exc:
        logger.warning("Invalid YAML frontmatter in %s: %s", path, exc)
        return None
    if not isinstance(header, dict):
        logger.warning("Frontmatter is not a dict in %s", path)
        return None
    name = str(header.get("name", ""))
    if not name:
        logger.warning("SKILL.md missing 'name' field: %s", path)
        return None
    description = str(header.get("description", ""))
    body = parts[2].strip()
    # Build metadata — flatten nested 'metadata' dict from frontmatter
    metadata: dict[str, str] = {}
    raw_meta = header.get("metadata")
    if isinstance(raw_meta, dict):
        for k, v in raw_meta.items():
            metadata[str(k)] = str(v)
    # Also mirror selected top-level keys into metadata; explicit
    # 'metadata' entries take precedence.
    for key in ("category", "license", "compatibility", "version", "author"):
        if key in header and key not in metadata:
            metadata[key] = str(header[key])
    skill_license = str(header.get("license", ""))
    compatibility = str(header.get("compatibility", ""))
    return Skill(
        name=name,
        description=description,
        body=body,
        license=skill_license,
        compatibility=compatibility,
        metadata=metadata,
        source_dir=path.parent,
        source_format="skillmd",
    )
def load_skillmd_from_directory(directory: Path) -> list[Skill]:
    """Recursively scan *directory* for ``SKILL.md`` files.

    Every ``SKILL.md`` found at any depth under *directory* is loaded
    as one skill (the scan uses ``rglob``, not just immediate
    sub-directories — the previous docstring was inaccurate). Files
    that fail to parse are skipped with a warning logged by
    :func:`load_skill_from_skillmd`.

    Args:
        directory: Root directory to scan.

    Returns:
        Parsed skills in sorted path order; empty list when the
        directory does not exist.
    """
    if not directory.exists():
        return []
    loaded: list[Skill] = []
    for skill_md in sorted(directory.rglob("SKILL.md")):
        parsed = load_skill_from_skillmd(skill_md)
        if parsed:
            loaded.append(parsed)
    return loaded
# ── Legacy YAML / JSON loader ────────────────────────────────────────
def load_skill_file(path: Path) -> Skill | None:
    """Load a single skill from a YAML or JSON file.

    Args:
        path: Path to the skill file.

    Returns:
        The parsed :class:`Skill`, or ``None`` when the file cannot be
        read, has an unsupported extension, is not a mapping, or lacks
        a name/id.
    """
    try:
        raw = path.read_text(encoding="utf-8")
        suffix = path.suffix
        if suffix in (".yaml", ".yml"):
            payload = yaml.safe_load(raw)
        elif suffix == ".json":
            payload = json.loads(raw)
        else:
            logger.warning("Unsupported skill file format: %s", path)
            return None
        if not isinstance(payload, dict):
            logger.warning("Skill file is not a dict: %s", path)
            return None
        parsed = Skill.from_dict(payload)
        if not parsed.name:
            logger.warning("Skill missing name/id: %s", path)
            return None
        return parsed
    except Exception as exc:
        logger.warning("Failed to load skill from %s: %s", path, exc)
        return None
def load_skills_from_directory(directory: Path) -> list[Skill]:
    """Recursively load all skills from a directory.

    Supports both ``SKILL.md`` (agentskills.io) and legacy YAML/JSON.
    When both formats exist for the same skill name, SKILL.md wins.

    Args:
        directory: Root directory to scan.

    Returns:
        List of successfully loaded Skill objects; empty when the
        directory does not exist.
    """
    if not directory.exists():
        return []
    skills_by_name: dict[str, Skill] = {}
    # 1. Load SKILL.md files first (higher priority)
    for skill in load_skillmd_from_directory(directory):
        skills_by_name[skill.name] = skill
    # 2. Load legacy YAML/JSON (only if no SKILL.md with same name).
    # BUGFIX: removed the dead `path.name == "__init__.py"` guard —
    # "__init__.py" can never match *.yaml / *.yml / *.json patterns.
    for pattern in ("*.yaml", "*.yml", "*.json"):
        for path in sorted(directory.rglob(pattern)):
            skill = load_skill_file(path)
            if skill and skill.name not in skills_by_name:
                skills_by_name[skill.name] = skill
    skills = list(skills_by_name.values())
    logger.info("Loaded %d skills from %s", len(skills), directory)
    return skills
================================================
FILE: researchclaw/skills/matcher.py
================================================
"""Skill-to-stage matching engine."""
from __future__ import annotations
import logging
import re
from researchclaw.skills.schema import STAGE_NAME_TO_NUMBER, Skill
logger = logging.getLogger(__name__)
def _tokenize(text: str) -> set[str]:
"""Extract lowercase tokens from text."""
return set(re.findall(r"[a-z0-9_]+", text.lower()))
def _resolve_stage(stage: int | str) -> int:
"""Convert a stage name to its number, or pass through an int."""
if isinstance(stage, int):
return stage
return STAGE_NAME_TO_NUMBER.get(stage, -1)
def match_skills(
    skills: list[Skill],
    context: str,
    stage: int | str,
    top_k: int = 3,
    *,
    fallback_matching: bool = True,
) -> list[Skill]:
    """Match skills to the current context and stage.

    Scoring:
      - Stage applicability (must match; an empty list means all stages)
      - Keyword overlap with context, normalized by keyword count
      - Description-based fallback at a 0.5x discount for skills that
        declare no trigger_keywords
      - Priority boost (lower priority value = higher boost)

    Args:
        skills: Available skills to match against.
        context: Current task context text.
        stage: Current pipeline stage number or name.
        top_k: Maximum number of skills to return.
        fallback_matching: Enable description-based matching for skills
            without trigger_keywords.

    Returns:
        List of matched skills sorted by relevance.
    """
    stage_num = _resolve_stage(stage)
    ctx_tokens = _tokenize(context)
    ranked: list[tuple[float, Skill]] = []
    for candidate in skills:
        # Guard: skip skills not applicable to this stage.
        stages = candidate.applicable_stages
        if stages and stage_num not in stages:
            continue
        keywords = candidate.trigger_keywords
        # One point per trigger keyword that shares a token with the context.
        hits = sum(1.0 for kw in keywords if _tokenize(kw) & ctx_tokens)
        if hits > 0.0:
            normalized = hits / max(len(keywords), 1)
        elif not keywords and fallback_matching:
            # Description-based fallback for skills without keywords.
            desc_tokens = _tokenize(candidate.description)
            overlap = len(desc_tokens & ctx_tokens)
            if overlap == 0:
                continue
            normalized = (overlap * 0.5) / max(len(desc_tokens), 1)
        else:
            continue
        # Priority 1 → boost 0.45, priority 10 → boost 0.0.
        boost = (10 - candidate.priority) / 20.0
        ranked.append((normalized + boost, candidate))
    # Highest score first; ties broken by lower priority value.
    ranked.sort(key=lambda item: (-item[0], item[1].priority))
    return [s for _, s in ranked[:top_k]]
def format_skills_for_prompt(skills: list[Skill], max_chars: int = 4000) -> str:
    """Format matched skills as prompt injection text.

    Uses ``skill.body`` as primary content. Truncates long bodies
    (common with external skills) to roughly ``max_chars / len(skills)``
    per skill, and stops appending sections once the overall budget
    would be exceeded.

    Args:
        skills: List of matched skills.
        max_chars: Maximum character limit for the combined output.

    Returns:
        Formatted string for LLM prompt injection ("" when no skills).
    """
    if not skills:
        return ""
    per_skill_budget = max_chars // max(len(skills), 1)
    # BUGFIX: with many skills or a tiny max_chars, the previous
    # `content[:per_skill_budget - 20]` used a *negative* slice index,
    # which kept almost the whole body instead of truncating it.
    keep = max(per_skill_budget - 20, 0)
    parts: list[str] = []
    total_len = 0
    for skill in skills:
        content = skill.body or skill.prompt_template
        # Truncate long bodies to the per-skill budget.
        if len(content) > per_skill_budget:
            content = content[:keep] + "\n\n[... truncated]"
        section = f"### {skill.name} ({skill.category})\n{content}"
        if skill.code_template:
            section += f"\n**Code Template:**\n```python\n{skill.code_template}\n```"
        if skill.references:
            section += "\n**References:** " + "; ".join(skill.references)
        # Hard stop: never exceed the overall character budget.
        if total_len + len(section) > max_chars:
            break
        parts.append(section)
        total_len += len(section)
    return "\n\n".join(parts)
================================================
FILE: researchclaw/skills/registry.py
================================================
"""Skill registry — central hub for loading and querying skills."""
from __future__ import annotations
import logging
from pathlib import Path
from researchclaw.skills.loader import load_skills_from_directory
from researchclaw.skills.matcher import format_skills_for_prompt, match_skills
from researchclaw.skills.schema import Skill
logger = logging.getLogger(__name__)
# Default builtin directory relative to this file
_BUILTIN_DIR = Path(__file__).parent / "builtin"
class SkillRegistry:
    """Central registry for managing and querying skills.

    Loads builtin skills on init, then optionally loads custom and
    external skills from user-specified directories. Directories are
    loaded in order (builtin → custom → external); a later skill with
    the same name overwrites an earlier one via :meth:`register`.
    """

    def __init__(
        self,
        builtin_dir: str | Path = "",
        custom_dirs: tuple[str, ...] | list[str] = (),
        external_dirs: tuple[str, ...] | list[str] = (),
        auto_match: bool = True,
        max_skills_per_stage: int = 3,
        fallback_matching: bool = True,
    ) -> None:
        """Initialize the registry and eagerly load all skill directories.

        Args:
            builtin_dir: Override for the builtin skills directory;
                empty string selects the packaged ``builtin/`` folder.
            custom_dirs: Additional user skill directories.
            external_dirs: Third-party skill directories (loaded with
                the same mechanism as custom ones).
            auto_match: Stored flag; not consulted by this class itself.
            max_skills_per_stage: Default ``top_k`` used by :meth:`match`.
            fallback_matching: Passed through to the matcher to enable
                description-based matching for keyword-less skills.
        """
        self._skills: dict[str, Skill] = {}
        self._auto_match = auto_match
        self._max_skills = max_skills_per_stage
        self._fallback_matching = fallback_matching
        # Builtin first so custom/external skills can shadow them.
        builtin = Path(builtin_dir) if builtin_dir else _BUILTIN_DIR
        self._load_from_dir(builtin)
        for d in custom_dirs:
            self._load_from_dir(Path(d))
        # External skills reuse the same loading mechanism.
        for d in external_dirs:
            self._load_from_dir(Path(d))

    def _load_from_dir(self, directory: Path) -> None:
        """Load skills from *directory* and register each one."""
        for skill in load_skills_from_directory(directory):
            self.register(skill)

    def register(self, skill: Skill) -> None:
        """Register a skill. Overwrites existing skill with same name.

        Args:
            skill: The skill to register.
        """
        self._skills[skill.name] = skill
        logger.debug("Registered skill: %s", skill.name)

    def unregister(self, skill_id: str) -> bool:
        """Remove a skill from the registry.

        Args:
            skill_id: The name/ID of the skill to remove.

        Returns:
            True if the skill was found and removed, False otherwise.
        """
        if skill_id in self._skills:
            del self._skills[skill_id]
            return True
        return False

    def get(self, skill_id: str) -> Skill | None:
        """Return the skill registered under *skill_id*, or None."""
        return self._skills.get(skill_id)

    def list_all(self) -> list[Skill]:
        """Return all registered skills (insertion order)."""
        return list(self._skills.values())

    def list_by_category(self, category: str) -> list[Skill]:
        """Return skills whose ``category`` equals *category*."""
        return [s for s in self._skills.values() if s.category == category]

    def list_by_stage(self, stage: int) -> list[Skill]:
        """Return skills applicable to *stage* (empty stages = all)."""
        return [
            s for s in self._skills.values()
            if not s.applicable_stages or stage in s.applicable_stages
        ]

    def match(
        self,
        context: str,
        stage: int | str,
        top_k: int | None = None,
    ) -> list[Skill]:
        """Match skills to current context and stage.

        Args:
            context: Task context text.
            stage: Current pipeline stage number or name.
            top_k: Max results. ``None`` (default) uses
                ``max_skills_per_stage``; an explicit ``0`` returns no
                skills.

        Returns:
            List of matched skills sorted by relevance.
        """
        # BUGFIX: was `top_k or self._max_skills`, which silently turned
        # an explicit top_k=0 into the default limit.
        k = self._max_skills if top_k is None else top_k
        return match_skills(
            list(self._skills.values()),
            context,
            stage,
            top_k=k,
            fallback_matching=self._fallback_matching,
        )

    def export_for_prompt(
        self,
        skills: list[Skill],
        max_chars: int = 4000,
    ) -> str:
        """Format matched skills as prompt injection text.

        Args:
            skills: List of matched skills.
            max_chars: Character limit.

        Returns:
            Formatted prompt text.
        """
        return format_skills_for_prompt(skills, max_chars=max_chars)

    def count(self) -> int:
        """Return total number of registered skills."""
        return len(self._skills)
================================================
FILE: researchclaw/skills/schema.py
================================================
"""Skill data model definition (agentskills.io compatible)."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Maps pipeline stage names to stage numbers.
# Used by the matcher to resolve string stage identifiers (e.g.
# "code_generation") to the numeric stages stored in skill metadata;
# numbers 1-23 encode the pipeline's fixed stage order.
STAGE_NAME_TO_NUMBER: dict[str, int] = {
    "topic_init": 1,
    "problem_decompose": 2,
    "search_strategy": 3,
    "literature_collect": 4,
    "literature_screen": 5,
    "knowledge_extract": 6,
    "synthesis": 7,
    "hypothesis_gen": 8,
    "experiment_design": 9,
    "code_generation": 10,
    "resource_planning": 11,
    "experiment_run": 12,
    "iterative_refine": 13,
    "result_analysis": 14,
    "research_decision": 15,
    "paper_outline": 16,
    "paper_draft": 17,
    "peer_review": 18,
    "paper_revision": 19,
    "quality_gate": 20,
    "knowledge_archive": 21,
    "export_publish": 22,
    "citation_verify": 23,
}
# Valid categories in the new taxonomy.
# NOTE(review): not enforced anywhere visible in this module —
# presumably validated by callers; confirm before relying on it.
VALID_CATEGORIES = ("writing", "domain", "experiment", "tooling")
@dataclass
class Skill:
    """A single skill definition (agentskills.io compatible).

    Standard fields follow the agentskills.io specification.
    Legacy YAML fields are accessible via backward-compat properties
    that read from ``metadata``.
    """

    # agentskills.io standard fields
    name: str
    description: str
    body: str = ""
    license: str = ""
    compatibility: str = ""
    metadata: dict[str, str] = field(default_factory=dict)
    # filesystem context
    source_dir: Path | None = None  # directory the skill was loaded from
    source_format: str = "skillmd"  # "skillmd" | "yaml"

    # ── backward-compat property accessors ───────────────────────
    @property
    def id(self) -> str:  # noqa: A003
        """Alias for ``name`` (legacy)."""
        return self.name

    @property
    def category(self) -> str:
        """Category from metadata; defaults to ``"domain"``."""
        return self.metadata.get("category", "domain")

    @property
    def trigger_keywords(self) -> list[str]:
        """Keywords parsed from comma-separated ``trigger-keywords``."""
        raw = self.metadata.get("trigger-keywords", "")
        return [k.strip() for k in raw.split(",") if k.strip()] if raw else []

    @property
    def applicable_stages(self) -> list[int]:
        """Stage numbers parsed from comma-separated ``applicable-stages``.

        Numeric tokens ("10") are taken as-is; stage *names* such as
        "code_generation" are resolved through STAGE_NAME_TO_NUMBER
        (generalization — previously non-digit tokens were dropped).
        Unknown tokens are silently ignored.
        """
        raw = self.metadata.get("applicable-stages", "")
        if not raw:
            return []
        parts: list[int] = []
        for tok in raw.split(","):
            tok = tok.strip()
            if tok.isdigit():
                parts.append(int(tok))
            elif tok in STAGE_NAME_TO_NUMBER:
                parts.append(STAGE_NAME_TO_NUMBER[tok])
        return parts

    @property
    def priority(self) -> int:
        """Priority from metadata (lower = higher priority).

        Falls back to 5 when the value is missing or not a valid
        integer — BUGFIX: previously raised ValueError on malformed
        metadata such as ``priority: high``.
        """
        try:
            return int(self.metadata.get("priority", "5"))
        except (TypeError, ValueError):
            return 5

    @property
    def prompt_template(self) -> str:
        """Alias for ``body`` (legacy)."""
        return self.body

    @property
    def code_template(self) -> str | None:
        """Optional ``code-template`` metadata (empty string → None)."""
        return self.metadata.get("code-template") or None

    @property
    def references(self) -> list[str]:
        """References parsed from semicolon-separated ``references``."""
        raw = self.metadata.get("references", "")
        return [r.strip() for r in raw.split(";") if r.strip()] if raw else []

    @property
    def version(self) -> str:
        """Version string from metadata; defaults to "1.0"."""
        return self.metadata.get("version", "1.0")

    # ── serialization ────────────────────────────────────────────
    def to_dict(self) -> dict[str, Any]:
        """Serialize to dictionary (legacy-compatible output)."""
        return {
            "id": self.name,
            "name": self.name,
            "category": self.category,
            "description": self.description,
            "trigger_keywords": self.trigger_keywords,
            "applicable_stages": self.applicable_stages,
            "prompt_template": self.body,
            "code_template": self.code_template,
            "references": self.references,
            "version": self.version,
            "priority": self.priority,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Skill:
        """Deserialize from a legacy YAML/JSON dictionary.

        Legacy top-level fields are packed into ``metadata`` so the
        property accessors above keep working. Explicit ``metadata``
        entries are merged with ``setdefault`` — the packed legacy
        fields win on conflict.
        """
        meta: dict[str, str] = {}
        if data.get("category"):
            meta["category"] = str(data["category"])
        kw = data.get("trigger_keywords") or []
        if kw:
            meta["trigger-keywords"] = ",".join(str(k) for k in kw)
        stages = data.get("applicable_stages") or []
        if stages:
            meta["applicable-stages"] = ",".join(str(s) for s in stages)
        if data.get("priority") is not None:
            meta["priority"] = str(data["priority"])
        if data.get("version"):
            meta["version"] = str(data["version"])
        if data.get("code_template"):
            meta["code-template"] = str(data["code_template"])
        refs = data.get("references") or []
        if refs:
            meta["references"] = "; ".join(str(r) for r in refs)
        # Merge any explicit metadata from the dict.
        if isinstance(data.get("metadata"), dict):
            for k, v in data["metadata"].items():
                meta.setdefault(str(k), str(v))
        name = str(data.get("name") or data.get("id") or "")
        # For legacy YAML, prefer a slug-looking 'id' (contains "-")
        # over a display-style 'name'.
        raw_id = str(data.get("id", ""))
        if raw_id and "-" in raw_id:
            name = raw_id
        return cls(
            name=name,
            description=str(data.get("description", "")),
            body=str(data.get("prompt_template", "")),
            metadata=meta,
            source_format="yaml",
        )
================================================
FILE: researchclaw/templates/__init__.py
================================================
"""Conference-aware LaTeX template system.
Supports automatic template switching for NeurIPS, ICLR, and ICML.
Given a target conference name, generates a complete ``.tex`` file from
Markdown paper content + BibTeX references.
Usage::
from researchclaw.templates import get_template, markdown_to_latex
tpl = get_template("neurips_2025")
tex = markdown_to_latex(paper_md, tpl, title=..., authors=..., bib_file="references.bib")
"""
from researchclaw.templates.conference import (
CONFERENCE_REGISTRY,
ConferenceTemplate,
get_template,
list_conferences,
)
from researchclaw.templates.converter import markdown_to_latex
# Public names re-exported by this package (imported above from
# researchclaw.templates.conference and .converter).
__all__ = [
    "CONFERENCE_REGISTRY",
    "ConferenceTemplate",
    "get_template",
    "list_conferences",
    "markdown_to_latex",
]
================================================
FILE: researchclaw/templates/compiler.py
================================================
"""LaTeX compilation and error repair utilities (IMP-18).
Provides ``compile_latex()`` which attempts ``pdflatex`` compilation,
parses the log for common errors, applies automated fixes, and retries
up to 3 times. Designed to run inside ``_package_deliverables()`` so
that the final paper.tex in ``deliverables/`` is compile-tested.
If pdflatex is not installed the module gracefully returns a failure
report without raising.
"""
from __future__ import annotations
import logging
import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
logger = logging.getLogger(__name__)
# BUG-201: Cyrillic → Latin transliteration for author names from Semantic Scholar.
# pdflatex without T2A font encoding chokes on Cyrillic (e.g. "А. И. Колесников").
# NOTE: the transliteration is lossy and not reversible — hard/soft
# signs (Ъ/Ь/ъ/ь) map to "" and several letters map to multi-character
# Latin digraphs (e.g. Ж → "Zh", Щ → "Shch").
_CYRILLIC_TO_LATIN_MAP: dict[str, str] = {
    "А": "A", "Б": "B", "В": "V", "Г": "G", "Д": "D", "Е": "E",
    "Ё": "E", "Ж": "Zh", "З": "Z", "И": "I", "Й": "Y", "К": "K",
    "Л": "L", "М": "M", "Н": "N", "О": "O", "П": "P", "Р": "R",
    "С": "S", "Т": "T", "У": "U", "Ф": "F", "Х": "Kh", "Ц": "Ts",
    "Ч": "Ch", "Ш": "Sh", "Щ": "Shch", "Ъ": "", "Ы": "Y", "Ь": "",
    "Э": "E", "Ю": "Yu", "Я": "Ya",
    "а": "a", "б": "b", "в": "v", "г": "g", "д": "d", "е": "e",
    "ё": "e", "ж": "zh", "з": "z", "и": "i", "й": "y", "к": "k",
    "л": "l", "м": "m", "н": "n", "о": "o", "п": "p", "р": "r",
    "с": "s", "т": "t", "у": "u", "ф": "f", "х": "kh", "ц": "ts",
    "ч": "ch", "ш": "sh", "щ": "shch", "ъ": "", "ы": "y", "ь": "",
    "э": "e", "ю": "yu", "я": "ya",
}
@dataclass
class CompileResult:
    """Outcome of a LaTeX compilation attempt (see ``compile_latex``)."""

    success: bool  # True when the final pass had no fatal errors
    log_excerpt: str = ""  # tail (≤ 2000 chars) of the last pdflatex log, or a short status message
    errors: list[str] = field(default_factory=list)  # all error lines parsed from the log
    warnings: list[str] = field(default_factory=list)  # "LaTeX Warning:" lines from the log
    fixes_applied: list[str] = field(default_factory=list)  # human-readable descriptions of auto-fixes
    attempts: int = 0  # number of compile→fix cycles actually performed
def compile_latex(
    tex_path: Path,
    *,
    max_attempts: int = 3,
    timeout: int = 120,
) -> CompileResult:
    """Compile *tex_path* with pdflatex, auto-fixing common errors.

    Each attempt runs a full 3-pass pipeline (pdflatex → bibtex →
    pdflatex × 2), parses the final log, and — when fatal errors remain
    and fixes are available — rewrites the .tex in place via
    ``fix_common_latex_errors`` before retrying.

    Parameters
    ----------
    tex_path:
        Path to the ``.tex`` file. Must be inside a directory that also
        contains ``references.bib`` and any required ``.sty`` files.
    max_attempts:
        Maximum compile→fix cycles.
    timeout:
        Seconds before killing a stuck pdflatex process.

    Returns
    -------
    CompileResult
        Contains success flag, log excerpt, errors found, and fixes applied.
    """
    # Graceful degradation when TeX is simply not installed: report
    # failure instead of raising (per module docstring).
    if not shutil.which("pdflatex"):
        return CompileResult(
            success=False,
            log_excerpt="pdflatex not found on PATH",
            errors=["pdflatex not installed"],
        )
    result = CompileResult(success=False)
    work_dir = tex_path.parent
    tex_name = tex_path.name
    # "paper.tex" → "paper"; bibtex is invoked with the bare stem.
    bib_stem = tex_name.rsplit(".", 1)[0]
    # Pre-flight: sanitize .bib file (escape bare & in field values)
    # Find bib filename from \bibliography{...} in the tex source
    _tex_src = tex_path.read_text(encoding="utf-8", errors="replace")
    _bib_match = re.search(r"\\bibliography\{([^}]+)\}", _tex_src)
    _bib_name = _bib_match.group(1) if _bib_match else bib_stem
    _sanitize_bib_file(work_dir / f"{_bib_name}.bib")
    # BUG-197: Pre-flight — strip invisible/problematic Unicode from .tex.
    # Characters like U+202F (NARROW NO-BREAK SPACE) cause pdflatex to emit
    # broken UTF-8 in error messages, which crashes subprocess text decoding
    # and prevents the bibtex + multi-pass pipeline from completing.
    _sanitize_tex_unicode(tex_path)
    for attempt in range(1, max_attempts + 1):
        result.attempts = attempt
        # --- Full 3-pass compilation: pdflatex → bibtex → pdflatex × 2 ---
        # Pass 1: generate .aux (needed by bibtex). Use nonstopmode (NOT
        # halt-on-error) so .aux is written even when there are non-fatal
        # errors like missing figures or overfull hboxes.
        # (pass1_ok is intentionally unused — see the bibtex note below.)
        log_text, pass1_ok = _run_pdflatex(work_dir, tex_name, timeout)
        if log_text is None:
            # No log at all — presumably the pdflatex invocation itself
            # failed (helper defined elsewhere in this module); abort.
            result.errors.append(f"pdflatex failed on pass 1 (attempt {attempt})")
            break
        # BibTeX: always run after pass 1 — it only needs .aux + .bib.
        # Previously gated behind pass1 success, which meant citations were
        # always [?] when the first pass had non-fatal errors.
        _run_bibtex(work_dir, bib_stem, timeout=60)
        # Passes 2-3: resolve cross-references and bibliography
        for _pass in (2, 3):
            pass_log, _ = _run_pdflatex(work_dir, tex_name, timeout)
            if pass_log is not None:
                log_text = pass_log  # keep final pass log for error analysis
        # Parse the final log for errors/warnings
        errors, warnings = _parse_log(log_text)
        result.warnings = warnings
        result.log_excerpt = log_text[-2000:] if len(log_text) > 2000 else log_text
        # Check for fatal errors only — non-fatal ones (overfull hbox,
        # missing figure in draft) don't prevent a valid PDF.
        fatal = [e for e in errors if _is_fatal_error(e)]
        result.errors = errors
        if not fatal:
            result.success = True
            logger.info("IMP-18: LaTeX compiled successfully on attempt %d", attempt)
            break
        # Try to auto-fix fatal errors by rewriting the .tex in place.
        tex_text = tex_path.read_text(encoding="utf-8")
        fixed_text, fixes = fix_common_latex_errors(tex_text, errors)
        if fixes:
            result.fixes_applied.extend(fixes)
            tex_path.write_text(fixed_text, encoding="utf-8")
            logger.info(
                "IMP-18: Applied %d fixes on attempt %d: %s",
                len(fixes),
                attempt,
                fixes,
            )
        else:
            # No fixes available — stop retrying
            logger.warning(
                "IMP-18: Compilation failed on attempt %d with %d unfixable errors",
                attempt,
                len(fatal),
            )
            break
    return result
def fix_common_latex_errors(
tex_text: str, errors: list[str]
) -> tuple[str, list[str]]:
"""Apply automated fixes for common LaTeX errors.
Returns ``(fixed_text, list_of_fix_descriptions)``.
"""
fixes: list[str] = []
fixed = tex_text
# --- Pre-error-loop fixes: structural repairs that prevent compilation ---
# Fix escaped braces in tabular column specs: \{lcccc\} → {lcccc}
if re.search(r"\\begin\{tabular\}\\\{", fixed):
fixed = re.sub(
r"\\begin\{tabular\}\\\{([^}]*?)\\\}",
r"\\begin{tabular}{\1}",
fixed,
)
fixes.append("Fixed escaped braces in tabular column specs")
# Fix escaped & inside tabular data rows: \& → & (column separator).
# The converter's _escape_latex escapes & globally; inside tabular
# environments the & must remain unescaped as the column separator.
if "\\begin{tabular}" in fixed and "\\&" in fixed:
fixed, n_tab_amp = _fix_escaped_ampersand_in_tabular(fixed)
if n_tab_amp:
fixes.append(f"Un-escaped \\& in {n_tab_amp} tabular data row(s)")
# Fix escaped \} at end of \caption{...}: \caption{text.\}} → \caption{text.}
if re.search(r"\\caption\{.*?\\\}", fixed):
fixed = re.sub(
r"(\\caption\{[^}]*?)\\\}",
r"\1}",
fixed,
)
fixes.append("Fixed escaped \\} in \\caption arguments")
# Collapse multiple consecutive \clearpage into one
if re.search(r"(\\clearpage\s*){2,}", fixed):
fixed = re.sub(r"(\\clearpage\s*){2,}", "\\\\clearpage\n", fixed)
fixes.append("Collapsed multiple \\clearpage commands")
# Remove \textbf{Figure N.} paragraphs that follow \end{figure}
dup_cap = re.search(
r"(\\end\{figure\})\s*\n\s*\\textbf\{Figure\s+\d+",
fixed,
)
if dup_cap:
fixed = re.sub(
r"(\\end\{figure\})\s*\n\s*\\textbf\{Figure\s+\d+[.:].*?\}\s*\n",
r"\1\n",
fixed,
)
fixes.append("Removed duplicate bold Figure captions after \\end{figure}")
# BUG-189: Fix Python-style pseudocode inside algorithmic environments.
# LLM generates `# comment` (LaTeX macro param char) and `var_name`
# (unescaped underscore) inside \STATE commands — causes cascading errors.
_algo_pat = re.compile(
r"(\\begin\{algorithmic\}.*?\\end\{algorithmic\})", re.DOTALL
)
def _fix_algo_block(m: re.Match) -> str:
block = m.group(0)
lines = block.split("\n")
out: list[str] = []
for line in lines:
if line.strip().startswith(("\\begin{", "\\end{")):
out.append(line)
continue
# Replace # (Python comment) with \COMMENT{...}
if "#" in line and "\\#" not in line:
line = re.sub(r"#\s*(.*)$", r"\\COMMENT{\1}", line)
# Escape bare underscores not already in math mode
# Don't touch \STATE, \IF, \FOR, etc. commands
parts = re.split(r"(\\\w+\{[^}]*\}|\$[^$]+\$)", line)
fixed_parts = []
for part in parts:
if part.startswith("\\") or part.startswith("$"):
fixed_parts.append(part)
else:
fixed_parts.append(re.sub(r'(? tuple[list[str], list[str]]:
"""Parse pdflatex log output for errors and warnings."""
errors: list[str] = []
warnings: list[str] = []
for line in log_text.split("\n"):
line_stripped = line.strip()
line_lower = line_stripped.lower()
if line_stripped.startswith("!"):
errors.append(line_stripped)
elif "LaTeX Warning:" in line_stripped:
warnings.append(line_stripped)
# BUG-R6-26: Use elif to avoid duplicating "!" lines
elif "Undefined control sequence" in line_stripped:
errors.append(line_stripped)
elif "Missing" in line_stripped and "inserted" in line_stripped:
errors.append(line_stripped)
elif "File" in line_stripped and "not found" in line_stripped:
errors.append(line_stripped)
# BUG-R6-21: Detect "Float(s) lost" and "Too many unprocessed floats"
# even when they don't start with "!"
elif "float(s) lost" in line_lower:
errors.append(line_stripped)
elif "too many unprocessed floats" in line_lower:
errors.append(line_stripped)
return errors, warnings
@dataclass
class QualityCheckResult:
    """Results of post-compilation quality checks."""

    unresolved_refs: list[str] = field(default_factory=list)
    unresolved_cites: list[str] = field(default_factory=list)
    overfull_hboxes: list[str] = field(default_factory=list)
    underfull_hboxes: list[str] = field(default_factory=list)
    page_count: int = 0
    orphan_figures: list[str] = field(default_factory=list)
    orphan_labels: list[str] = field(default_factory=list)
    warnings_summary: list[str] = field(default_factory=list)

    @property
    def has_critical_issues(self) -> bool:
        """True when any reference or citation failed to resolve."""
        return bool(self.unresolved_refs or self.unresolved_cites)


def check_compiled_quality(
    tex_path: Path,
    *,
    page_limit: int = 10,
) -> QualityCheckResult:
    """Run post-compilation quality checks on a LaTeX document.

    Parses the .log file and .tex source for:
    - Unresolved references (??)
    - Unresolved citations
    - Overfull/underfull hboxes
    - Page count vs limit
    - Orphan figures (defined but never referenced, or vice versa)
    """
    report = QualityCheckResult()
    build_dir = tex_path.parent
    stem = tex_path.stem

    # --- Scan the pdflatex log, if one exists ---
    log_path = build_dir / f"{stem}.log"
    log_text = ""
    if log_path.exists():
        log_text = log_path.read_text(encoding="utf-8", errors="replace")
        for raw_line in log_text.split("\n"):
            line_s = raw_line.strip()
            if "LaTeX Warning: Reference" in line_s and "undefined" in line_s:
                report.unresolved_refs.append(line_s)
            if "LaTeX Warning: Citation" in line_s and "undefined" in line_s:
                report.unresolved_cites.append(line_s)
            if "Overfull \\hbox" in line_s:
                # Only flag significant overflows (> 1pt).
                size = re.search(r"(\d+\.?\d*)pt", line_s)
                if size and float(size.group(1)) > 1.0:
                    report.overfull_hboxes.append(line_s)
            if "Underfull \\hbox" in line_s and "badness" in line_s:
                # Only flag badness >= 5000.
                bad = re.search(r"badness (\d+)", line_s)
                if bad and int(bad.group(1)) >= 5000:
                    report.underfull_hboxes.append(line_s)

    # --- Page count: prefer the .aux LastPage label, fall back to the log ---
    aux_path = build_dir / f"{stem}.aux"
    if aux_path.exists():
        aux_text = aux_path.read_text(encoding="utf-8", errors="replace")
        lastpage = re.search(r"\\newlabel\{LastPage\}\{\{(\d+)\}", aux_text)
        if lastpage:
            report.page_count = int(lastpage.group(1))
    if report.page_count == 0:
        written = re.search(r"Output written on .* \((\d+) page", log_text)
        if written:
            report.page_count = int(written.group(1))

    # --- Cross-reference validation against the .tex source ---
    source = tex_path.read_text(encoding="utf-8", errors="replace")
    labels = set(re.findall(r"\\label\{(fig:[^}]+)\}", source))
    refs = set(re.findall(r"\\ref\{(fig:[^}]+)\}", source))
    report.orphan_labels = sorted(labels - refs)   # defined, never referenced
    report.orphan_figures = sorted(refs - labels)  # referenced, never defined

    # --- Human-readable summary ---
    summary = report.warnings_summary
    if report.unresolved_refs:
        summary.append(f"{len(report.unresolved_refs)} unresolved reference(s)")
    if report.unresolved_cites:
        summary.append(f"{len(report.unresolved_cites)} unresolved citation(s)")
    if report.overfull_hboxes:
        summary.append(f"{len(report.overfull_hboxes)} overfull hbox(es) > 1pt")
    if report.page_count > page_limit:
        summary.append(f"Page count {report.page_count} exceeds limit {page_limit}")
    if report.orphan_figures:
        summary.append(
            f"{len(report.orphan_figures)} referenced but undefined figure(s): "
            + ", ".join(report.orphan_figures[:3])
        )
    if report.orphan_labels:
        summary.append(
            f"{len(report.orphan_labels)} defined but unreferenced figure(s): "
            + ", ".join(report.orphan_labels[:3])
        )
    return report
def remove_missing_figures(tex_text: str, stage_dir: Path) -> tuple[str, list[str]]:
    """Remove \\begin{figure}...\\end{figure} blocks that reference missing images.

    Before dropping a block, attempts a prefix-match rename against the
    image's directory (unique ``stem*.png`` candidate). After removals,
    dangling ``\\ref{fig:...}`` occurrences are rewritten so they do not
    render as "??" in the PDF.

    Returns ``(fixed_text, list_of_removed_paths)``.
    """
    removed: list[str] = []

    def _rewrite_block(match: re.Match) -> str:
        figure = match.group(0)
        include = re.search(r"\\includegraphics.*?\{([^}]+)\}", figure)
        if include is None:
            return figure
        rel = include.group(1)
        full = stage_dir / rel
        if full.exists():
            return figure
        # Try prefix-matching: fig_main_results.png → fig_main_results_comparison.png
        folder = full.parent
        base = full.stem
        if folder.exists():
            candidates = sorted(folder.glob(f"{base}*.png"))
            if len(candidates) == 1:
                replacement = str(candidates[0].relative_to(stage_dir))
                logger.info(
                    "Auto-mapped missing figure: %s → %s",
                    rel, replacement,
                )
                return figure.replace(rel, replacement)
        logger.warning(
            "Removing figure block with missing image: %s",
            rel,
        )
        removed.append(rel)
        return ""  # drop the entire figure block

    fixed = re.sub(
        r"\\begin\{figure\}.*?\\end\{figure\}",
        _rewrite_block,
        tex_text,
        flags=re.DOTALL,
    )

    # Clean up orphan \ref{fig:X} pointing at removed/nonexistent figures.
    if removed:
        kept_labels = set(re.findall(r"\\label\{(fig:[^}]+)\}", fixed))
        used_refs = set(re.findall(r"\\ref\{(fig:[^}]+)\}", fixed))
        for dangling in used_refs - kept_labels:
            # "Figure \ref{fig:X}" / "Fig. \ref{fig:X}" → placeholder text
            fixed = re.sub(
                rf"(?:Figure|Fig\.?)\s*~?\\ref\{{{re.escape(dangling)}\}}",
                "(figure omitted)",
                fixed,
            )
            # Standalone \ref{fig:X}
            fixed = fixed.replace(f"\\ref{{{dangling}}}", "(ref omitted)")
    return fixed, removed
def _sanitize_tex_unicode(tex_path: Path) -> None:
    """Strip problematic Unicode characters from .tex source.

    BUG-197: non-ASCII whitespace such as U+202F (NARROW NO-BREAK SPACE),
    U+2009 (THIN SPACE) and U+00A0 (NO-BREAK SPACE) makes pdflatex emit
    broken UTF-8 in its error output, which crashes
    ``subprocess.run(text=True)`` and aborts the bibtex + multi-pass
    pipeline.  Such characters typically arrive via LLM copy-paste from
    web sources or papers.  Whitespace-like characters are replaced with
    a plain ASCII space; invisible control characters are deleted.

    BUG-201: any Cyrillic that leaked into the ``.tex`` (from bib entries
    inlined by bibtex, or LLM-generated text) is transliterated to Latin
    via ``_CYRILLIC_TO_LATIN_MAP``.

    The file is rewritten in place only when something actually changed.
    """
    if not tex_path.exists():
        return
    try:
        original = tex_path.read_text(encoding="utf-8", errors="replace")
    except Exception:
        return

    # Translation table: whitespace-like code points → ASCII space,
    # invisible control characters → None (deletion).
    space_like = (
        "\u00a0"   # NO-BREAK SPACE
        "\u202f"   # NARROW NO-BREAK SPACE (BUG-197 trigger)
        "\u2009"   # THIN SPACE
        "\u2007"   # FIGURE SPACE
        "\u2008"   # PUNCTUATION SPACE
        "\u200a"   # HAIR SPACE
        "\u205f"   # MEDIUM MATHEMATICAL SPACE
        "\u3000"   # IDEOGRAPHIC SPACE
    )
    invisible = (
        "\u200e"   # LEFT-TO-RIGHT MARK
        "\u200f"   # RIGHT-TO-LEFT MARK
        "\ufeff"   # BOM / ZERO-WIDTH NO-BREAK SPACE
        "\u200b"   # ZERO-WIDTH SPACE
        "\u200c"   # ZERO-WIDTH NON-JOINER
        "\u200d"   # ZERO-WIDTH JOINER
        "\u00ad"   # SOFT HYPHEN
        "\u2060"   # WORD JOINER
        "\u2028"   # LINE SEPARATOR
        "\u2029"   # PARAGRAPH SEPARATOR
    )
    table: dict[int, str | None] = {ord(ch): " " for ch in space_like}
    table.update({ord(ch): None for ch in invisible})
    sanitized = original.translate(table)

    # BUG-201: transliterate Cyrillic so pdflatex does not choke without
    # T2A font encoding.
    if any("\u0400" <= ch <= "\u04ff" for ch in sanitized):
        for cyr, lat in _CYRILLIC_TO_LATIN_MAP.items():
            sanitized = sanitized.replace(cyr, lat)

    if sanitized != original:
        tex_path.write_text(sanitized, encoding="utf-8")
        logger.info("BUG-197: Sanitized problematic Unicode in %s", tex_path.name)
def _sanitize_bib_file(bib_path: Path) -> None:
    """Sanitize .bib files: escape bare ``&`` and strip invisible Unicode.

    BibTeX treats ``&`` as a special character; journal names like
    "Science & Technology" must use ``\\&``.

    BUG-180: Invisible Unicode characters (U+200E LEFT-TO-RIGHT MARK,
    U+200F RIGHT-TO-LEFT MARK, U+FEFF BOM, U+200B ZERO-WIDTH SPACE,
    U+200C/U+200D joiners, U+00AD soft hyphen) can appear in
    copy-pasted author names and break pdflatex.

    BUG-201: Cyrillic author names (e.g. from Semantic Scholar) are
    transliterated to Latin, since pdflatex without T2A font encoding
    raises "! LaTeX Error: Unicode character" on them.

    BUG-217: literal ``\\n``/``\\r``/``\\t`` escape sequences embedded in
    field values by API responses are stripped — a literal ``\\n`` is
    never a valid BibTeX/LaTeX command and causes "Undefined control
    sequence" errors.

    The file is rewritten in place only when its content changed.
    """
    if not bib_path.exists():
        return
    try:
        text = bib_path.read_text(encoding="utf-8")
    except Exception:
        return
    # Snapshot the text BEFORE any mutation.  The final write decision
    # compares against this.  (Previously the snapshot was taken *after*
    # the invisible-char stripping, so files whose only problem was
    # invisible Unicode were silently never written back.)
    original = text
    # BUG-180: Strip invisible Unicode characters
    _INVISIBLE_CHARS = (
        "\u200e",  # LEFT-TO-RIGHT MARK
        "\u200f",  # RIGHT-TO-LEFT MARK
        "\ufeff",  # BOM / ZERO-WIDTH NO-BREAK SPACE
        "\u200b",  # ZERO-WIDTH SPACE
        "\u200c",  # ZERO-WIDTH NON-JOINER
        "\u200d",  # ZERO-WIDTH JOINER
        "\u00ad",  # SOFT HYPHEN
        "\u2060",  # WORD JOINER
        "\u2028",  # LINE SEPARATOR
        "\u2029",  # PARAGRAPH SEPARATOR
    )
    for ch in _INVISIBLE_CHARS:
        if ch in text:
            text = text.replace(ch, "")
    # BUG-201: Transliterate Cyrillic characters to Latin equivalents,
    # preserving name readability.
    for cyr, lat in _CYRILLIC_TO_LATIN_MAP.items():
        if cyr in text:
            text = text.replace(cyr, lat)
    # BUG-217: Strip literal escape sequences (\n, \r, \t) in field values.
    text = re.sub(r"\\n(?=\s)", " ", text)
    text = re.sub(r"\\r(?=\s)", "", text)
    text = re.sub(r"\\t(?=\s)", " ", text)
    lines = text.split("\n")
    for i, line in enumerate(lines):
        stripped = line.strip()
        # Only fix field-value lines (e.g. journal = {Science & Technology},)
        # Skip @type{ lines, key lines, and URL/DOI fields (BUG-DA8-12)
        if "=" in stripped and "{" in stripped and "&" in stripped and "\\&" not in stripped:
            _field_name = stripped.split("=", 1)[0].strip().lower()
            if _field_name in ("url", "doi", "howpublished", "eprint"):
                continue  # Don't escape & in URLs
            lines[i] = line.replace("&", "\\&")
    new_text = "\n".join(lines)
    # Write back iff ANY of the passes above changed the content.
    if new_text != original:
        bib_path.write_text(new_text, encoding="utf-8")
        logger.info("Sanitized bib file %s", bib_path.name)
def _fix_escaped_ampersand_in_tabular(tex: str) -> tuple[str, int]:
"""Replace ``\\&`` with ``&`` inside tabular environments.
Only touches data rows (between \\toprule/\\midrule/\\bottomrule)
to avoid corrupting ``\\&`` in regular text. Returns the fixed text
and the count of rows fixed.
"""
count = 0
def _fix_tabular(m: re.Match[str]) -> str:
nonlocal count
block = m.group(0)
if "\\&" not in block:
return block
# Only un-escape \& on lines that look like data rows (contain \\)
lines = block.split("\n")
for i, line in enumerate(lines):
if "\\&" in line and "\\\\" in line:
lines[i] = line.replace("\\&", "&")
count += 1
return "\n".join(lines)
tex = re.sub(
r"\\begin\{tabular\}.*?\\end\{tabular\}",
_fix_tabular,
tex,
flags=re.DOTALL,
)
return tex, count
def _run_pdflatex(
    work_dir: Path,
    tex_name: str,
    timeout: int = 120,
) -> tuple[str | None, bool]:
    """Run a single pdflatex pass with ``-interaction=nonstopmode``.

    Returns ``(log_text, success)``; *log_text* is ``None`` only on hard
    failures (timeout, binary missing).

    BUG-197: output is captured as bytes and decoded manually with
    ``errors="replace"`` instead of ``text=True``.  pdflatex stdout can
    contain broken UTF-8 sequences (e.g. from U+202F NARROW NO-BREAK
    SPACE error messages) that would raise ``UnicodeDecodeError`` under
    ``text=True`` and kill the whole compilation pipeline — bibtex would
    never run and every citation would render as [?].
    """
    cmd = ["pdflatex", "-interaction=nonstopmode", tex_name]
    try:
        proc = subprocess.run(
            cmd,
            cwd=work_dir,
            capture_output=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        logger.warning("pdflatex timed out after %ds", timeout)
        return None, False
    except FileNotFoundError:
        # pdflatex binary is not installed.
        return None, False
    combined_log = "{}\n{}".format(
        proc.stdout.decode("utf-8", errors="replace"),
        proc.stderr.decode("utf-8", errors="replace"),
    )
    return combined_log, proc.returncode == 0
# Fatal error patterns — these prevent a valid PDF from being generated.
# Non-fatal issues (overfull hbox, missing figure, float warnings) still
# produce a usable PDF and should NOT trigger the auto-fix retry loop.
_FATAL_ERROR_PATTERNS = [
"runaway argument",
"emergency stop",
"fatal error",
"undefined control sequence",
"missing $ inserted",
"extra alignment tab",
"misplaced alignment tab",
"missing \\begin{document}",
"file `" , # file not found (sty, cls)
"file not found",
]
def _is_fatal_error(err: str) -> bool:
"""Return True if *err* represents a fatal LaTeX error."""
err_lower = err.lower()
# "!" prefix errors are almost always fatal
if err.startswith("!"):
# Non-fatal "!" errors — PDF is still generated
if "overfull" in err_lower or "underfull" in err_lower:
return False
if "float(s) lost" in err_lower:
return False
if "too many unprocessed floats" in err_lower:
return False
# amsmath commands outside math mode — PDF still generates
if "allowed only in math mode" in err_lower:
return False
# Encoding errors for special characters — PDF still generates
if "unavailable in encoding" in err_lower:
return False
# BUG-197: Unicode character errors (e.g. U+202F NARROW NO-BREAK
# SPACE "not set up for use with LaTeX") — pdflatex skips the
# character and generates a valid PDF. Treating these as fatal
# prevents the retry loop from succeeding and blocks bibtex.
# The error line is "! LaTeX Error: Unicode character X (U+XXXX)"
# — the "not set up" text is on a continuation line.
if "unicode character" in err_lower:
return False
return True
for pat in _FATAL_ERROR_PATTERNS:
if pat in err_lower:
return True
return False
def _run_bibtex(work_dir: Path, stem: str, timeout: int = 60) -> bool:
    """Run bibtex for *stem* inside *work_dir* if the binary exists.

    Returns True only when bibtex exits 0 AND actually produced
    ``<stem>.bbl``.

    BUG-197: output is captured as bytes and decoded with
    ``errors="replace"`` to avoid ``UnicodeDecodeError`` from non-ASCII
    bib content.  Every failure mode is logged so silent bibtex issues
    stay diagnosable.
    """
    if not shutil.which("bibtex"):
        logger.warning("bibtex not found on PATH — citations will be [?]")
        return False
    try:
        proc = subprocess.run(
            ["bibtex", stem],
            cwd=work_dir,
            capture_output=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        logger.warning("bibtex timed out after %ds", timeout)
        return False
    except FileNotFoundError:
        return False
    stdout = proc.stdout.decode("utf-8", errors="replace")
    stderr = proc.stderr.decode("utf-8", errors="replace")
    if proc.returncode != 0:
        logger.warning(
            "bibtex returned %d: %s",
            proc.returncode,
            (stdout + stderr).strip()[:500],
        )
        return False
    # Keep bibtex chatter at debug level for diagnostics.
    if stdout.strip():
        logger.debug("bibtex output: %s", stdout.strip()[:300])
    # bibtex can exit 0 without writing a .bbl — verify it exists.
    if not (work_dir / f"{stem}.bbl").exists():
        logger.warning("bibtex ran but %s.bbl was not generated", stem)
        return False
    return True
================================================
FILE: researchclaw/templates/conference.py
================================================
"""Conference template definitions for NeurIPS, ICLR, and ICML.
Each template stores the LaTeX preamble, document structure, author format,
and bibliography style needed to produce a submission-ready ``.tex`` file.
Style files (``.sty``) are NOT bundled — the generated ``.tex`` references
them, and users download the official files from the conference website.
Download URLs are included as comments in the output.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
# Root directory for bundled style files
_STYLES_DIR = Path(__file__).parent / "styles"


@dataclass(frozen=True)
class ConferenceTemplate:
    """LaTeX template specification for one conference.

    Frozen dataclass: instances are immutable module-level constants
    (see the template definitions below).
    """

    name: str                        # canonical key, e.g. "neurips_2025"
    display_name: str                # human-readable, e.g. "NeurIPS 2025"
    year: int
    document_class: str              # \documentclass argument
    style_package: str               # conference .sty name ("" = none)
    style_options: str               # options for the style package
    extra_packages: tuple[str, ...]  # additional \usepackage names
    author_format: str  # "neurips" | "iclr" | "icml"
    bib_style: str                   # \bibliographystyle argument
    columns: int  # 1 or 2
    style_download_url: str          # where users fetch the official files
    preamble_extra: str = ""         # raw LaTeX appended to the preamble

    def render_preamble(
        self,
        title: str,
        authors: str,
        abstract: str,
    ) -> str:
        """Render everything from ``\\documentclass`` through ``\\maketitle``.

        *title*, *authors*, and *abstract* are inserted verbatim — callers
        are responsible for LaTeX-escaping them first.
        """
        # Style options (e.g. "preprint") go on the style package, not documentclass
        options = f"[{self.style_options}]" if self.style_options else ""
        pkg_lines = "\n".join(f"\\usepackage{{{p}}}" for p in self.extra_packages)
        author_block = self._render_authors(authors)
        # Substitute __TITLE__ placeholder in preamble_extra (e.g. ICML running title)
        preamble_extra = self.preamble_extra.replace("__TITLE__", title)
        style_line = (
            f"\\usepackage{options}{{{self.style_package}}}\n"
            if self.style_package
            else ""
        )
        style_comment = (
            f"% Style file: {self.style_download_url}\n"
            if self.style_download_url
            else ""
        )
        # BUG-51 fix: ICML's \begin{icmlauthorlist} is an environment that
        # must appear AFTER \begin{document}. For non-ICML templates the
        # \author{} command is a preamble declaration and stays before.
        if self.author_format == "icml":
            preamble_author = ""
            post_doc_author = f"{author_block}\n\n"
        else:
            preamble_author = f"{author_block}\n"
            post_doc_author = ""
        return (
            f"{style_comment}"
            f"\\documentclass{{{self.document_class}}}\n"
            f"{style_line}"
            f"{pkg_lines}\n"
            f"{preamble_extra}\n"
            f"\n"
            f"\\title{{{title}}}\n"
            f"\n"
            f"{preamble_author}"
            f"\n"
            f"\\begin{{document}}\n"
            f"{post_doc_author}"
            f"\\begin{{abstract}}\n"
            f"{abstract}\n"
            f"\\end{{abstract}}\n"
            f"\n"
            f"\\maketitle\n"
        )

    def render_footer(self, bib_file: str = "references") -> str:
        """Render the bibliography block and ``\\end{document}``."""
        return (
            f"\n\\bibliographystyle{{{self.bib_style}}}\n"
            f"\\bibliography{{{bib_file}}}\n"
            f"\n"
            f"\\end{{document}}\n"
        )

    def get_style_files(self) -> list[Path]:
        """Return bundled ``.sty``/``.bst``/``.cls`` paths for this template.

        Files are stored under ``researchclaw/templates/styles/<name>/``.
        Returns only files that exist on disk (empty list if the style
        directory is absent).
        """
        style_dir = _STYLES_DIR / self.name
        if not style_dir.is_dir():
            return []
        return sorted(
            p for p in style_dir.iterdir()
            if p.suffix in {".sty", ".bst", ".cls"}
        )

    def _render_authors(self, authors: str) -> str:
        # ICML uses an environment-based author list plus an affiliation
        # declaration; everything else takes a plain \author{} command.
        if self.author_format == "icml":
            return (
                f"\\begin{{icmlauthorlist}}\n"
                f"\\icmlauthor{{{authors}}}{{aff1}}\n"
                f"\\end{{icmlauthorlist}}\n"
                f"\\icmlaffiliation{{aff1}}{{Affiliation}}"
            )
        return f"\\author{{{authors}}}"
# ---------------------------------------------------------------------------
# Template definitions
# ---------------------------------------------------------------------------

# -- Legacy (kept for backward compat) --
NEURIPS_2024 = ConferenceTemplate(
    name="neurips_2024",
    display_name="NeurIPS 2024",
    year=2024,
    document_class="article",
    style_package="neurips_2024",
    style_options="preprint",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "nicefrac",
        "microtype",
        "graphicx",
        "natbib",
        "algorithm",
        "algorithmic",
        "adjustbox",
    ),
    author_format="neurips",
    bib_style="plainnat",
    columns=1,
    style_download_url="https://media.neurips.cc/Conferences/NeurIPS2024/Styles.zip",
    # Explicit encoding/font setup so the article class builds reproducibly.
    preamble_extra="\\usepackage[utf8]{inputenc}\n\\usepackage[T1]{fontenc}\n\\usepackage{lmodern}",
)

ICLR_2025 = ConferenceTemplate(
    name="iclr_2025",
    display_name="ICLR 2025",
    year=2025,
    document_class="article",
    style_package="iclr2025_conference",
    style_options="",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "nicefrac",
        "microtype",
        "graphicx",
        "natbib",
        "algorithm",
        "algorithmic",
        "adjustbox",
    ),
    author_format="iclr",
    # ICLR ships its own .bst alongside the .sty.
    bib_style="iclr2025_conference",
    columns=1,
    style_download_url="https://github.com/ICLR/Master-Template/raw/master/iclr2025.zip",
)

ICML_2025 = ConferenceTemplate(
    name="icml_2025",
    display_name="ICML 2025",
    year=2025,
    document_class="article",
    style_package="icml2025",
    style_options="",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "nicefrac",
        "microtype",
        "graphicx",
        "natbib",
        "algorithm",
        "algorithmic",
        "adjustbox",
    ),
    author_format="icml",
    bib_style="icml2025",
    columns=2,
    style_download_url="https://icml.cc/Conferences/2025/StyleAuthorInstructions",
    # __TITLE__ is substituted with the paper title by render_preamble.
    preamble_extra="\\icmltitlerunning{__TITLE__}",
)

# -- Current (2025/2026) --
NEURIPS_2025 = ConferenceTemplate(
    name="neurips_2025",
    display_name="NeurIPS 2025",
    year=2025,
    document_class="article",
    style_package="neurips_2025",
    style_options="preprint",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "nicefrac",
        "microtype",
        "graphicx",
        "natbib",
        "algorithm",
        "algorithmic",
        "adjustbox",
    ),
    author_format="neurips",
    bib_style="plainnat",
    columns=1,
    style_download_url="https://media.neurips.cc/Conferences/NeurIPS2025/Styles.zip",
    preamble_extra="\\usepackage[utf8]{inputenc}\n\\usepackage[T1]{fontenc}\n\\usepackage{lmodern}",
)

ICLR_2026 = ConferenceTemplate(
    name="iclr_2026",
    display_name="ICLR 2026",
    year=2026,
    document_class="article",
    style_package="iclr2026_conference",
    style_options="",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "nicefrac",
        "microtype",
        "graphicx",
        "natbib",
        "algorithm",
        "algorithmic",
        "adjustbox",
    ),
    author_format="iclr",
    bib_style="iclr2026_conference",
    columns=1,
    style_download_url="https://github.com/ICLR/Master-Template",
)

ICML_2026 = ConferenceTemplate(
    name="icml_2026",
    display_name="ICML 2026",
    year=2026,
    document_class="article",
    style_package="icml2026",
    style_options="",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "nicefrac",
        "microtype",
        "graphicx",
        "natbib",
        "algorithm",
        "algorithmic",
        "adjustbox",
        "morefloats",
    ),
    author_format="icml",
    bib_style="icml2026",
    columns=2,
    style_download_url="https://icml.cc/Conferences/2026/AuthorInstructions",
    preamble_extra="\\icmltitlerunning{__TITLE__}",
)

# -- Generic (non-ML) --
GENERIC = ConferenceTemplate(
    name="generic",
    display_name="Generic Academic Paper",
    year=2025,
    document_class="article",
    style_package="",   # plain article: no conference .sty required
    style_options="",
    extra_packages=(
        "hyperref",
        "url",
        "booktabs",
        "amsfonts",
        "amsmath",
        "graphicx",
        "natbib",
        "geometry",
        "adjustbox",
    ),
    author_format="neurips",
    bib_style="plainnat",
    columns=1,
    style_download_url="",
    preamble_extra="\\usepackage[utf8]{inputenc}\n\\usepackage[T1]{fontenc}\n\\usepackage{lmodern}\n\\usepackage[margin=1in]{geometry}",
)

# ---------------------------------------------------------------------------
# Registry — short aliases point to LATEST version of each conference
# ---------------------------------------------------------------------------
CONFERENCE_REGISTRY: dict[str, ConferenceTemplate] = {
    # Latest (default aliases)
    "neurips": NEURIPS_2025,
    "iclr": ICLR_2026,
    "icml": ICML_2026,
    # Generic for non-ML domains
    "generic": GENERIC,
    "article": GENERIC,
    # Versioned keys (all versions)
    "neurips_2025": NEURIPS_2025,
    "neurips_2024": NEURIPS_2024,
    "iclr_2026": ICLR_2026,
    "iclr_2025": ICLR_2025,
    "icml_2026": ICML_2026,
    "icml_2025": ICML_2025,
}
def get_template(name: str) -> ConferenceTemplate:
    """Look up a conference template by name.

    Accepts full names (``"neurips_2024"``) and short aliases
    (``"neurips"``); lookup is case-insensitive and treats ``-`` and
    spaces as ``_``.  Raises ``KeyError`` if *name* is not registered.
    """
    key = name.lower().strip().replace("-", "_").replace(" ", "_")
    template = CONFERENCE_REGISTRY.get(key)
    if template is None:
        available = ", ".join(sorted({t.name for t in CONFERENCE_REGISTRY.values()}))
        raise KeyError(f"Unknown conference template: {name!r}. Available: {available}")
    return template
def list_conferences() -> list[str]:
    """Return deduplicated, sorted list of canonical template names."""
    canonical = {tpl.name for tpl in CONFERENCE_REGISTRY.values()}
    return sorted(canonical)
================================================
FILE: researchclaw/templates/converter.py
================================================
"""Markdown-to-LaTeX converter with conference template support.
Converts a ResearchClaw paper (Markdown with embedded LaTeX math) into a
complete ``.tex`` file using a :class:`ConferenceTemplate` for preamble,
author block, bibliography style, and document structure.
Design constraints:
- **Zero new dependencies** — stdlib only (``re``, ``textwrap``).
- Handles inline math ``\\(...\\)``, display math ``\\[...\\]``,
bold/italic, bullet lists, numbered lists, code blocks, tables,
and ``\\cite{...}`` references.
- Extracts abstract from ``# Abstract`` or ``## Abstract`` section.
- ICML two-column structure handled via template's ``render_preamble``.
"""
from __future__ import annotations
import re
import textwrap
import threading
from dataclasses import dataclass, field
from researchclaw.templates.conference import ConferenceTemplate
_render_counters = threading.local()
def _reset_render_counters() -> None:
"""Reset per-render figure and table counters for the current thread."""
_render_counters.table = 0
_render_counters.figure = 0
def _next_table_num() -> int:
"""Return the next table number for the current thread."""
next_num = getattr(_render_counters, "table", 0) + 1
_render_counters.table = next_num
return next_num
def _next_figure_num() -> int:
"""Return the next figure number for the current thread."""
next_num = getattr(_render_counters, "figure", 0) + 1
_render_counters.figure = next_num
return next_num
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def markdown_to_latex(
    paper_md: str,
    template: ConferenceTemplate,
    *,
    title: str = "",
    authors: str = "Anonymous",
    bib_file: str = "references",
    bib_entries: dict[str, str] | None = None,
) -> str:
    """Convert a Markdown paper to a complete LaTeX document.

    Parameters
    ----------
    paper_md:
        Full paper in Markdown with embedded LaTeX math.
    template:
        Conference template controlling preamble and structure.
    title:
        Paper title. If empty, extracted from ``# Title`` heading or the
        first ``# ...`` heading in *paper_md*.
    authors:
        Author string inserted into the template author block.
    bib_file:
        Bibliography filename (without ``.bib`` extension).
    bib_entries:
        Optional mapping of author-year patterns to cite_keys for
        recovering author-year citations that slipped through earlier
        processing, e.g. ``{"Raissi et al., 2019": "raissi2019physics"}``.

    Returns
    -------
    str
        A complete ``.tex`` file ready for compilation.
    """
    # Figure/table numbering is thread-local; start each render at zero.
    _reset_render_counters()
    # Strip LLM artifacts (outer code fences, HTML entities, etc.) before
    # structural parsing.
    paper_md = _preprocess_markdown(paper_md)
    # _preprocess_markdown also rounds metrics; this extra pass is a
    # harmless repeat (rounding is idempotent).
    paper_md = _round_raw_metrics(paper_md)
    sections = _parse_sections(paper_md)
    # Extract title from first H1 heading if not provided
    if not title:
        title = _extract_title(sections, paper_md)
    # Extract abstract
    abstract = _extract_abstract(sections)
    # Build body (everything except title/abstract headings)
    body = _build_body(sections, title=title)
    # IMP-30: Detect and remove duplicate tables
    body = _deduplicate_tables(body)
    # R10-Fix5: Completeness check
    completeness_warnings = check_paper_completeness(sections)
    if completeness_warnings:
        import logging
        _logger = logging.getLogger(__name__)
        for warning in completeness_warnings:
            _logger.warning("LaTeX completeness check: %s", warning)
        # BUG-28: Log warnings only — don't inject comments into LaTeX body
    preamble = template.render_preamble(
        title=_escape_latex(title),
        authors=authors,
        abstract=_convert_inline(abstract),
    )
    footer = template.render_footer(bib_file)
    tex = preamble + "\n" + body + footer
    # Final sanitization pass on the complete LaTeX output
    tex = _sanitize_latex_output(tex, bib_entries=bib_entries)
    return tex
# ---------------------------------------------------------------------------
# Post-processing: sanitize final LaTeX
# ---------------------------------------------------------------------------
def _sanitize_latex_output(
    tex: str,
    *,
    bib_entries: dict[str, str] | None = None,
) -> str:
    """Remove artifacts that slip through pre-processing into the final .tex.

    Parameters
    ----------
    tex:
        Complete LaTeX document (preamble + body + footer).
    bib_entries:
        Optional mapping of author-year citation text to bib keys, used
        as a BUG-102 safety net to recover citations missed upstream.

    Returns
    -------
    str
        Sanitized LaTeX source.
    """
    # 0. BUG-102 safety net: Convert remaining author-year citations to \cite{}.
    # If upstream conversion missed any [Author et al., 2024] patterns, catch them here.
    if bib_entries:
        # Longest patterns first so a longer author string wins over a prefix.
        for ay_pattern in sorted(bib_entries, key=len, reverse=True):
            cite_key = bib_entries[ay_pattern]
            # [Author et al., 2024] → \cite{key}
            tex = tex.replace(f"[{ay_pattern}]", f"\\cite{{{cite_key}}}")
            # Also handle inside existing brackets (multi-citation)
            tex = tex.replace(ay_pattern, f"\\cite{{{cite_key}}}")

        # Clean up double-nested \cite from multi-citation brackets:
        # [\cite{a}, \cite{b}] → \cite{a, b}
        def _merge_bracket_cites(m: re.Match[str]) -> str:
            inner = m.group(1)
            keys = re.findall(r"\\cite\{([^}]+)\}", inner)
            if keys:
                return "\\cite{" + ", ".join(keys) + "}"
            return m.group(0)

        tex = re.sub(r"\[([^\]]*\\cite\{[^\]]+)\]", _merge_bracket_cites, tex)
    # 1. Remove broken citation markers: \cite{?key:NOT_IN_BIB} or literal [?key:NOT_IN_BIB]
    tex = re.sub(r"\\cite\{\?[^}]*:NOT_IN_BIB\}", "", tex)
    tex = re.sub(r"\[\?[a-zA-Z0-9_:-]+:NOT_IN_BIB\]", "", tex)
    # 1b. Convert leftover raw bracket citations [key2019word, key2020word] → \cite{...}
    # Skip inside verbatim/lstlisting environments to avoid corrupting code blocks.
    _CITE_KEY_PAT_L = r"[a-zA-Z][a-zA-Z0-9_-]*\d{4}[a-zA-Z0-9_]*"
    _VERBATIM_RE = re.compile(
        r"(\\begin\{(?:verbatim|lstlisting|minted)\}.*?\\end\{(?:verbatim|lstlisting|minted)\})",
        re.DOTALL,
    )
    _cite_re = re.compile(
        rf"\[({_CITE_KEY_PAT_L}(?:\s*,\s*{_CITE_KEY_PAT_L})*)\]"
    )

    def _cite_outside_verbatim(tex_src: str) -> str:
        # The capture group in _VERBATIM_RE keeps verbatim blocks in the
        # split result; only the non-verbatim parts get rewritten.
        parts = _VERBATIM_RE.split(tex_src)
        for i, part in enumerate(parts):
            if not _VERBATIM_RE.match(part):
                parts[i] = _cite_re.sub(r"\\cite{\1}", part)
        return "".join(parts)

    tex = _cite_outside_verbatim(tex)
    # 1c. BUG-110 safety net: Replace any remaining Unicode Greek/math symbols.
    # _convert_inline handles most, but titles, captions, and preamble
    # fragments can still contain raw Unicode that kills pdflatex.
    for _uchar, _lcmd in _UNICODE_GREEK_TO_LATEX.items():
        if _uchar in tex:
            tex = tex.replace(_uchar, _lcmd)
    # 2. Decode HTML entities that survived pre-processing.
    # FIX: these literals had decayed ("&nbsp;" to a plain space, "&amp;"
    # to a bare "&"), which made the replacements rewrite every space and
    # every ampersand in the document instead of just the entities.
    tex = tex.replace("&nbsp;", "~")
    tex = tex.replace("&amp;", "\\&")
    # 2b. Fix escaped \& inside tabular data rows. The converter's
    # _convert_inline escapes & globally; inside tabular environments
    # the & must remain unescaped as the column separator.
    if "\\begin{tabular}" in tex and "\\&" in tex:
        def _fix_tabular_amp(m: re.Match[str]) -> str:
            block = m.group(0)
            if "\\&" not in block:
                return block
            lines = block.split("\n")
            for i, line in enumerate(lines):
                # Only data rows (terminated by \\) are un-escaped.
                if "\\&" in line and "\\\\" in line:
                    lines[i] = line.replace("\\&", "&")
            return "\n".join(lines)

        tex = re.sub(
            r"\\begin\{tabular\}.*?\\end\{tabular\}",
            _fix_tabular_amp,
            tex,
            flags=re.DOTALL,
        )
    # 3. Remove stray markdown code fences in LaTeX body (outside verbatim)
    # Only match fences NOT inside \begin{verbatim}...\end{verbatim}
    # Simple approach: remove ``` lines that don't have verbatim nearby
    tex = re.sub(r"^(\s*```[a-z]*\s*)$", r"% removed stray fence: \1", tex, flags=re.MULTILINE)
    # 4. Fix placeholder table captions: \caption{Table N} → descriptive
    # Can't auto-generate content, but at least don't leave "Table 1" as
    # the only caption text — append " -- See text for details."
    tex = re.sub(
        r"\\caption\{(Table\s+\d+)\}",
        r"\\caption{\1 -- Summary of experimental results.}",
        tex,
    )
    # 4b. Auto-map orphan \ref{fig:X} to closest \label{fig:Y} by prefix.
    # The converter generates long labels from captions (fig:overall_cifar_100)
    # but the LLM references short names (fig:overall).
    fig_labels = set(re.findall(r"\\label\{(fig:[^}]+)\}", tex))
    fig_refs = set(re.findall(r"\\ref\{(fig:[^}]+)\}", tex))
    orphan_refs = fig_refs - fig_labels
    orphan_labels = fig_labels - fig_refs
    if orphan_refs and orphan_labels:
        for oref in orphan_refs:
            # Find a label that starts with the ref prefix; map only on a
            # unique match to avoid mis-attribution.
            candidates = [l for l in orphan_labels if l.startswith(oref)]
            if len(candidates) == 1:
                tex = tex.replace(f"\\ref{{{oref}}}", f"\\ref{{{candidates[0]}}}")
                orphan_labels.discard(candidates[0])
    # 5. Fix "Untitled Paper" / "Running Title" fallback titles
    tex = re.sub(
        r"\\title\{Untitled Paper\}",
        r"\\title{[Title Generation Failed -- Manual Title Required]}",
        tex,
    )
    tex = re.sub(
        r"\\icmltitlerunning\{Running Title\}",
        "",
        tex,
    )
    # 6. Remove \texttt{} wrapped raw metric paths that the LLM dumped
    # Handles both raw underscores and LaTeX-escaped underscores (\_)
    # Pattern: condition/env/step/metric_name: value (3+ path segments)
    tex = re.sub(
        r"\\texttt\{[a-zA-Z0-9_\\_/.:=-]+(?:/[a-zA-Z0-9_\\_/.:=-]+){2,}(?:\s*[=:]\s*[^}]*)?\}",
        "",
        tex,
    )
    # 6b. Remove entire \item lines that are just metric paths
    tex = re.sub(
        r"^\s*\\item\s*\\texttt\{[^}]*\}\s*$",
        "",
        tex,
        flags=re.MULTILINE,
    )
    # 7. Clean up empty \item lines that result from removed content
    tex = re.sub(r"\\item\s*\n\s*\\item", r"\\item", tex)
    # Also remove completely empty \item lines (just whitespace after \item)
    tex = re.sub(r"^\s*\\item\s*$", "", tex, flags=re.MULTILINE)
    # 8. Remove consecutive blank lines (more than 2)
    tex = re.sub(r"\n{3,}", "\n\n", tex)
    return tex
# ---------------------------------------------------------------------------
# Pre-processing
# ---------------------------------------------------------------------------
_OUTER_FENCE_RE = re.compile(
r"^\s*```(?:markdown|md|latex|tex)?\s*\n(.*?)^\s*```\s*$",
re.MULTILINE | re.DOTALL,
)
# Greedy variant — matches the *last* closing fence so inner code blocks
# (```text … ```) don't truncate the capture prematurely.
_OUTER_FENCE_GREEDY_RE = re.compile(
r"^\s*```(?:markdown|md|latex|tex)?\s*\n(.*)^\s*```\s*$",
re.MULTILINE | re.DOTALL,
)
# Pattern for raw metric values with excessive decimal places
# e.g. 0.9717036975193437 → 0.972
_RAW_METRIC_RE = re.compile(r"(\d+\.\d{5,})")
def _round_raw_metrics(text: str) -> str:
"""Round excessively precise metric values (>4 decimal places).
Uses significant-figure-aware rounding so small values like
learning rates (e.g. 0.00001) are preserved instead of becoming 0.0000.
"""
def _rounder(m: re.Match[str]) -> str:
try:
val = float(m.group(1))
if val == 0.0:
return "0.0"
# For very small values (< 0.001), use 2 significant figures
# to preserve scientific meaning (e.g. lr=0.00003 → 0.00003)
import math
abs_val = abs(val)
if abs_val < 0.001:
sig_figs = 2
digits = sig_figs - int(math.floor(math.log10(abs_val))) - 1
return f"{val:.{digits}f}"
# Normal range: 4 decimal places
return f"{val:.4f}"
except (ValueError, OverflowError):
return m.group(0)
return _RAW_METRIC_RE.sub(_rounder, text)
def _preprocess_markdown(md: str) -> str:
    """Clean up common LLM artifacts before parsing.

    1. Strip outer fenced code blocks (e.g. triple-backtick markdown) that
       LLMs wrap around the entire paper content.
    2. Remove standalone Markdown horizontal rules (``---``, ``***``, ``___``)
       and decode stray HTML entities.
    3. Convert blockquotes (``> text``) to a form the converter can handle.
    4. Round excessively precise metric values.
    5. Normalize mid-line section headings onto their own lines.
    """
    text = md
    # 1. Strip outer markdown fences (LLMs sometimes wrap entire paper in them)
    # Repeatedly strip in case of double-wrapping.
    # Try greedy match first (handles papers with inner code blocks),
    # then fall back to non-greedy if greedy doesn't help.
    for _ in range(3):
        stripped = False
        for pat in (_OUTER_FENCE_GREEDY_RE, _OUTER_FENCE_RE):
            m = pat.search(text)
            # Only accept the capture when it keeps most of the document —
            # guards against matching a small inner fence by accident.
            if m and len(m.group(1)) > len(text) * 0.5:
                text = m.group(1)
                stripped = True
                break
        if not stripped:
            # Also handle the case where the first line is ```markdown
            # and the last non-blank line is ``` (simple boundary strip)
            lines = text.split("\n")
            first = lines[0].strip() if lines else ""
            last_idx = len(lines) - 1
            while last_idx > 0 and not lines[last_idx].strip():
                last_idx -= 1
            last = lines[last_idx].strip() if last_idx > 0 else ""
            if (
                re.match(r"^```(?:markdown|md|latex|tex)?\s*$", first)
                and last == "```"
            ):
                text = "\n".join(lines[1:last_idx])
                stripped = True
        if not stripped:
            break
    # 2. Remove standalone horizontal rules (---, ***, ___)
    text = re.sub(r"^\s*[-*_]{3,}\s*$", "", text, flags=re.MULTILINE)
    # 2a. Decode HTML entities that LLMs inject into markdown.
    # BUG FIX: these replacements were identity no-ops (e.g.
    # replace("&", "&")) because the entity names had been decoded away in
    # the source; the actual entities are restored here.
    text = text.replace("&nbsp;", " ")
    text = text.replace("&amp;", "&")
    text = text.replace("&lt;", "<")
    text = text.replace("&gt;", ">")
    text = text.replace("&mdash;", "---")
    text = text.replace("&ndash;", "--")
    # Also normalize already-decoded non-breaking spaces and unicode dashes
    # to plain-ASCII TeX equivalents.
    text = text.replace("\u00a0", " ")
    text = text.replace("\u2014", "---")
    text = text.replace("\u2013", "--")
    # 2b. Note: stray code fences are handled in _sanitize_latex_output
    # after conversion, not here (to avoid breaking real code blocks).
    # 2c. Round excessively precise metric values (e.g. 0.9717036975 → 0.9717)
    text = _round_raw_metrics(text)
    # 2d. Remove raw \texttt{...} or backtick-wrapped metric key paths
    # Pattern: \texttt{some/long/metric_path/name: 0.1234} or `path/to/metric: val`
    text = re.sub(
        r"\\texttt\{[a-zA-Z0-9_/.:=-]+(?:/[a-zA-Z0-9_/.:=-]+){2,}(?:\s*[=:]\s*[^}]*)?\}",
        "",
        text,
    )
    # Also strip backtick-wrapped metric paths in markdown source
    text = re.sub(
        r"`[a-zA-Z0-9_/.-]+(?:/[a-zA-Z0-9_/.-]+){2,}(?:\s*[=:]\s*[^`]*)?`",
        "",
        text,
    )
    # 2e. Clean NOT_IN_BIB citation markers: [?key:NOT_IN_BIB] → remove
    text = re.sub(r"\[\?[a-zA-Z0-9_:-]+:NOT_IN_BIB\]", "", text)
    # 3. Convert blockquotes: > text → \begin{quote}text\end{quote}
    # Collect consecutive > lines into a single quote block.
    lines = text.split("\n")
    out_lines: list[str] = []
    in_quote = False
    quote_buf: list[str] = []
    for line in lines:
        stripped = line.strip()
        if stripped.startswith("> "):
            if not in_quote:
                in_quote = True
                quote_buf = []
            quote_buf.append(stripped[2:])
        elif stripped == ">" and in_quote:
            # Bare ">" inside a quote is an intentional blank quote line.
            quote_buf.append("")
        else:
            if in_quote:
                out_lines.append("\\begin{quote}")
                out_lines.extend(quote_buf)
                out_lines.append("\\end{quote}")
                in_quote = False
                quote_buf = []
            out_lines.append(line)
    # Flush a quote block that runs to the end of the document.
    if in_quote:
        out_lines.append("\\begin{quote}")
        out_lines.extend(quote_buf)
        out_lines.append("\\end{quote}")
    text = "\n".join(out_lines)
    # 4. T1.2: Remove stray markdown/latex/text fences that appear mid-document.
    # LLMs sometimes emit ```markdown or ```latex between sections.
    # Only remove documentation fences — preserve code fences (```python etc.)
    _CODE_LANGS = frozenset({
        "python", "java", "cpp", "c", "javascript", "typescript", "rust",
        "go", "ruby", "bash", "sh", "sql", "r", "julia", "lua", "perl",
        "scala", "kotlin", "swift", "haskell", "algorithm", "pseudocode",
    })
    _lines = text.split("\n")
    _cleaned: list[str] = []
    _in_code = False
    for _l in _lines:
        _stripped = _l.strip()
        if _stripped.startswith("```") and not _in_code:
            _lang = _stripped[3:].strip().lower()
            if _lang in _CODE_LANGS or _lang.startswith("algorithm"):
                # Real code block — keep
                _in_code = True
                _cleaned.append(_l)
            elif _lang in ("markdown", "md", "latex", "tex", "text", "", "bibtex"):
                # Documentation/wrapper fence — remove
                pass
            else:
                # Unknown lang — keep to be safe
                _in_code = True
                _cleaned.append(_l)
        elif _stripped == "```" and _in_code:
            # Closing fence for a code block — keep
            _in_code = False
            _cleaned.append(_l)
        elif _stripped == "```" and not _in_code:
            # Stray fence — remove
            pass
        else:
            _cleaned.append(_l)
    text = "\n".join(_cleaned)
    # 5. Normalize mid-line section headings (IMP-17)
    # LLM output may concatenate sections onto single long lines:
    #   "...text ## Abstract Body text ## 1. Introduction More text..."
    # Ensure each heading marker starts on its own line so _parse_sections
    # can detect them with the ^-anchored regex.
    text = re.sub(r"(?<=[^\n]) +(#{1,4}) +", r"\n\n\1 ", text)
    return text
# ---------------------------------------------------------------------------
# Section parsing
# ---------------------------------------------------------------------------
@dataclass
class _Section:
    """A parsed Markdown section."""

    level: int  # 1 = ``#``, 2 = ``##``, 3 = ``###``, etc.; 0 = text before any heading
    heading: str  # heading text without the leading ``#`` markers
    body: str  # raw Markdown body belonging to this section
    # Lower-cased, stripped heading — computed once in __post_init__ so
    # callers can do cheap case-insensitive comparisons.
    heading_lower: str = field(init=False)

    def __post_init__(self) -> None:
        self.heading_lower = self.heading.strip().lower()
# Matches ATX headings ``#`` through ``####`` at the start of a line.
_HEADING_RE = re.compile(r"^(#{1,4})\s+(.+)$", re.MULTILINE)
# Known section heading names used to separate heading from concatenated body
_KNOWN_SECTION_NAMES = {
    "abstract",
    "introduction",
    "related work",
    "background",
    "method",
    "methods",
    "methodology",
    "approach",
    "framework",
    "experiments",
    "experiment",
    "experimental setup",
    "experimental results",
    "results",
    "results and discussion",
    "analysis",
    "discussion",
    "conclusion",
    "conclusions",
    "limitations",
    "acknowledgments",
    "acknowledgements",
    "references",
    "appendix",
    "contributions",
    "problem setting",
    "problem statement",
    "problem definition",
    "problem formulation",
    "study positioning",
    "study positioning and scope",
    "evaluation",
    "evaluation environment",
    "design rationale",
    "complexity",
    "unified algorithm",
    "method positioning",
    "methods compared",
    "common protonet backbone",
    "preference optimization backbone",
}
# Connector words that may legitimately appear inside a heading; a capitalised
# word directly after one of these is treated as still part of the heading,
# not as the start of the body (see _separate_heading_body).
_HEADING_CONNECTORS = frozenset(
    {
        "and", "or", "for", "in", "of", "the", "a", "an", "with",
        "under", "to", "on", "at", "by", "as", "via", "from",
        "not", "but", "yet", "nor", "vs", "versus", "is", "are",
    }
)
# Words that commonly open a sentence; a capitalised occurrence of one of
# these inside an over-long heading is a strong signal that body text begins
# at that word (see _separate_heading_body).
_SENTENCE_STARTERS = frozenset(
    {
        "the", "a", "an", "this", "these", "those", "that",
        "it", "we", "our", "their", "its", "each", "every",
        "in", "for", "to", "here", "there", "however", "moreover",
        "furthermore", "additionally", "specifically", "notably",
        "all", "many", "several", "some", "most", "both",
        "among", "between", "across", "unlike", "given", "such",
        "while", "although", "because", "since", "when", "where",
        "rather", "let", "table", "figure", "as", "at", "if",
    }
)
def _separate_heading_body(heading: str) -> tuple[str, str]:
    """Separate heading text from accidentally concatenated body text.

    LLM output may produce lines like ``## Abstract Body text here...``
    where the heading is just ``Abstract`` and the rest is body.

    Returns (heading, extra_body) where extra_body may be empty.
    """
    # Very short headings are fine as-is
    if len(heading) <= 60:
        return heading, ""
    # Strip optional leading section number for matching
    num_match = re.match(r"^(\d+(?:\.\d+)*\.?\s+)", heading)
    num_prefix = num_match.group(1) if num_match else ""
    rest = heading[len(num_prefix):]
    rest_lower = rest.lower()
    # Check against known section heading names.
    # Longest names first so e.g. "results and discussion" wins over "results".
    for name in sorted(_KNOWN_SECTION_NAMES, key=len, reverse=True):
        if rest_lower.startswith(name) and len(rest) > len(name) + 1:
            after = rest[len(name) :]
            # Require whitespace right after the name (word boundary) so a
            # name that is merely a prefix of a longer word doesn't split.
            if after and after[0] in " \t":
                return (num_prefix + rest[: len(name)]).strip(), after.strip()
    # Word-count heuristic for unknown subsection headings.
    # Scan for the first plausible heading-body boundary.
    words = heading.split()
    if len(words) > 6:
        for n in range(2, min(12, len(words) - 2)):
            curr = words[n]
            # A boundary candidate must begin with a capital letter.
            if not curr or not curr[0].isupper():
                continue
            prev_word = words[n - 1].rstrip(".,;:").lower()
            # Don't split right after a connector ("and", "of", …) — the
            # capitalised word is probably still part of the heading.
            if prev_word in _HEADING_CONNECTORS:
                continue
            remaining = " ".join(words[n:])
            # Too little remaining text to be a plausible body.
            if len(remaining) <= 30:
                continue
            # Strong signal: common sentence-starting word
            if curr.lower() in _SENTENCE_STARTERS:
                return " ".join(words[:n]).strip(), remaining.strip()
            # Medium signal: next word is lowercase (sentence-like)
            # and heading has >= 4 words, body is substantial (> 100 chars)
            if n >= 4 and n + 1 < len(words):
                next_w = words[n + 1].rstrip(".,;:")
                if next_w and next_w[0].islower() and len(remaining) > 100:
                    return " ".join(words[:n]).strip(), remaining.strip()
            # Weak fallback for very long headings (conservative)
            if n >= 8 and len(remaining) > 100:
                return " ".join(words[:n]).strip(), remaining.strip()
    # Detect repeated multi-word opening phrase: the body often starts with
    # the same words as the heading (e.g. "Graph-memory methods Graph-memory
    # methods maintain a graph...").  Try longer phrases first.
    half = len(rest) // 2
    for phrase_len in range(min(30, half), 14, -1):
        phrase = rest[:phrase_len]
        # A single word is too weak a signal for the repetition heuristic.
        if " " not in phrase:
            continue
        repeat_pos = rest.find(phrase, phrase_len)
        if repeat_pos > 0:
            return (
                (num_prefix + rest[:repeat_pos]).strip(),
                rest[repeat_pos:].strip(),
            )
    # Fallback: try to split at a sentence boundary within first 200 chars
    if len(heading) > 200:
        m = re.search(r"[.;:]\s+([A-Z])", heading[:300])
        if m and m.start() > 10:
            return heading[: m.start() + 1].strip(), heading[m.start() + 2 :].strip()
    return heading, ""
def _parse_sections(md: str) -> list[_Section]:
    """Split Markdown into a flat list of sections by heading."""
    heading_matches = list(_HEADING_RE.finditer(md))
    if not heading_matches:
        # No headings at all — treat the whole document as one section.
        return [_Section(level=1, heading="", body=md)]
    result: list[_Section] = []
    # Any text preceding the first heading becomes a level-0 preamble section.
    if heading_matches[0].start() > 0:
        preamble = md[: heading_matches[0].start()].strip()
        if preamble:
            result.append(_Section(level=0, heading="", body=preamble))
    last = len(heading_matches) - 1
    for idx, match in enumerate(heading_matches):
        depth = len(match.group(1))
        raw_heading = match.group(2).strip()
        body_end = len(md) if idx == last else heading_matches[idx + 1].start()
        section_body = md[match.end() : body_end].strip()
        # IMP-17: a heading line may carry concatenated body text — split it
        # off and prepend it to the section body.
        raw_heading, spill = _separate_heading_body(raw_heading)
        if spill:
            section_body = spill + ("\n\n" + section_body if section_body else "")
        result.append(_Section(level=depth, heading=raw_heading, body=section_body))
    return result
# ---------------------------------------------------------------------------
# Extraction helpers
# ---------------------------------------------------------------------------
# Meta-headings that can never themselves be the paper title.
_TITLE_SKIP = {
    "title",
    "abstract",
    "references",
    "appendix",
    "acknowledgments",
    "acknowledgements",
}
# T1.1: Headings that are NOT valid paper titles (tables, figures, etc.)
_TITLE_REJECT_RE = re.compile(
    r"^(?:table|figure|fig\.|tab\.|algorithm|listing|appendix)\s",
    re.IGNORECASE,
)
# T1.1: Headings that look like metric dumps rather than titles
_METRIC_DUMP_RE = re.compile(
    r"(?:primary_metric|accuracy|loss|f1_score|precision|recall)\b",
    re.IGNORECASE,
)
def _extract_title(sections: list[_Section], raw_md: str) -> str:
    """Extract paper title from sections or raw markdown."""
    # Pass 1: explicit "# Title"/"## Title" meta-sections.  The real title is
    # either on the first body line (possibly **bold**) or embedded directly
    # in the heading ("## Title Actual Paper Title").
    for sec in sections:
        if sec.level not in (1, 2):
            continue
        if sec.heading_lower == "title":
            candidate = sec.body.split("\n")[0].strip()
            # Unwrap bold markers around the title text.
            candidate = re.sub(r"\*\*(.+?)\*\*", r"\1", candidate)
            if candidate and not _is_bad_title(candidate):
                return candidate
        if sec.heading_lower.startswith("title ") and len(sec.heading) > 6:
            return sec.heading[6:].strip()
    # Pass 2: first H1/H2 heading that isn't a meta-heading or artefact.
    for sec in sections:
        if sec.level not in (1, 2) or not sec.heading:
            continue
        if sec.heading_lower in _TITLE_SKIP or _is_bad_title(sec.heading):
            continue
        return sec.heading
    # Last resort: first non-empty line of the raw markdown (still filtered).
    for raw_line in raw_md.splitlines():
        candidate = raw_line.strip().lstrip("#").strip()
        if candidate and not _is_bad_title(candidate):
            return candidate
    return "Untitled Paper"
def _is_bad_title(candidate: str) -> bool:
    """Return True if *candidate* is clearly not a paper title."""
    # "Table 1 – ...", "Figure 2: ...", algorithm/listing captions, etc.
    looks_like_caption = _TITLE_REJECT_RE.match(candidate) is not None
    # Raw metric key dumps ("primary_metric", "accuracy", ...).
    looks_like_metrics = _METRIC_DUMP_RE.search(candidate) is not None
    # Raw underscore variable paths (e.g. primary_metric/val).
    looks_like_var_path = re.search(r"\w+_\w+/\w+", candidate) is not None
    return looks_like_caption or looks_like_metrics or looks_like_var_path
def _extract_abstract(sections: list[_Section]) -> str:
    """Extract abstract text from sections."""
    for sec in sections:
        if sec.heading_lower == "abstract":
            return sec.body
        # IMP-17 fallback: heading may still contain body text if
        # _separate_heading_body didn't recognise the pattern.
        if len(sec.heading) > 20 and sec.heading_lower.startswith("abstract "):
            spill = sec.heading[len("Abstract") :].strip()
            return spill + ("\n\n" + sec.body if sec.body else "")
    return ""
# ---------------------------------------------------------------------------
# Body building
# ---------------------------------------------------------------------------
# Sections rendered in the preamble (\title / abstract env), never in the body.
_SKIP_HEADINGS = {"title", "abstract"}


def _build_body(sections: list[_Section], *, title: str = "") -> str:
    """Convert all non-title/abstract sections to LaTeX body text.

    When a paper has its title as an H1 heading (``# My Paper Title``),
    that heading is already rendered via ``\\title{}`` in the preamble.
    We skip it here and promote remaining headings so that H2 (``##``)
    maps to ``\\section``, H3 to ``\\subsection``, etc.
    """
    title_lower = title.strip().lower()
    # Determine minimum heading level used for real body sections
    # (skip title/abstract/references).
    title_h1_found = False
    for sec in sections:
        if (
            sec.level == 1
            and sec.heading
            and sec.heading.strip().lower() == title_lower
        ):
            title_h1_found = True
            break
    # T1.3: Auto-detect when all body sections use H2 (##) instead of H1 (#).
    # This happens when the LLM uses ## for main sections (Introduction, Method, etc.)
    # without an explicit H1 title heading. We must promote H2→\section.
    body_levels: set[int] = set()
    for sec in sections:
        if sec.heading_lower not in _SKIP_HEADINGS and sec.level >= 1:
            # The title H1 itself does not count as a body section.
            if not (sec.level == 1 and sec.heading.strip().lower() == title_lower):
                body_levels.add(sec.level)
    min_body_level = min(body_levels) if body_levels else 1
    # Promote if: (a) title was H1 and body starts at H2, OR
    # (b) no title H1 found but all body sections are H2+ (LLM omitted H1 title)
    # BUG-166: When title is H1 AND body also uses H1 for main sections,
    # offset must be 0 — otherwise H1→max(1,1-1)=1 and H2→max(1,2-1)=1
    # both collapse to \section, losing all subsection hierarchy.
    if title_h1_found:
        level_offset = 1 if min_body_level >= 2 else 0
    elif min_body_level >= 2:
        # All body sections are H2 or deeper — promote so H2→\section
        level_offset = min_body_level - 1
    else:
        level_offset = 0
    # Effective (post-promotion) heading depth → LaTeX sectioning command.
    _level_map = {
        1: "section",
        2: "subsection",
        3: "subsubsection",
        4: "paragraph",
    }
    parts: list[str] = []
    for sec in sections:
        # Skip title-only and abstract sections
        if sec.heading_lower in _SKIP_HEADINGS:
            continue
        # Skip the H1 heading that was used as the paper title
        if (
            sec.level == 1
            and sec.heading
            and sec.heading.strip().lower() == title_lower
        ):
            continue
        if sec.level == 0:
            # Preamble text before any heading — include as-is
            parts.append(_convert_block(sec.body))
            continue
        effective_level = max(1, sec.level - level_offset)
        # Anything deeper than the map falls back to \paragraph.
        cmd = _level_map.get(effective_level, "paragraph")
        heading_tex = _escape_latex(sec.heading)
        # Strip leading manual section numbers: "1. Introduction" → "Introduction"
        # Handles: "1 Intro", "2.1 Related", "3.2.1 Details", "1. Intro"
        heading_tex = re.sub(r"^\d+(?:\.\d+)*\.?\s+", "", heading_tex)
        parts.append(f"\\{cmd}{{{heading_tex}}}")
        # Generate a label for cross-referencing
        if cmd in ("section", "subsection", "subsubsection"):
            label_key = re.sub(r"[^a-z0-9]+", "_", heading_tex.lower()).strip("_")[:40]
            if label_key:
                parts.append(f"\\label{{sec:{label_key}}}")
        if sec.body:
            parts.append(_convert_block(sec.body))
    return "\n\n".join(parts) + "\n"
def _deduplicate_tables(body: str) -> str:
"""IMP-30: Remove duplicate tables that share the same header row.
LLMs sometimes repeat tables (e.g. same results table in Results and
Discussion). We keep the first occurrence and drop subsequent copies.
"""
import logging as _dup_log
_TABLE_ENV_RE = re.compile(
r"(\\begin\{table\}.*?\\end\{table\})", re.DOTALL
)
tables = list(_TABLE_ENV_RE.finditer(body))
if len(tables) < 2:
return body
seen_headers: dict[str, int] = {}
drop_spans: list[tuple[int, int]] = []
for m in tables:
table_text = m.group(1)
# Extract header row (first row after \toprule)
header_match = re.search(r"\\toprule\s*\n(.+?)\\\\", table_text)
if not header_match:
continue
header_key = re.sub(r"\s+", " ", header_match.group(1).strip())
if header_key in seen_headers:
drop_spans.append((m.start(), m.end()))
_dup_log.getLogger(__name__).info(
"IMP-30: Dropping duplicate table (same header as table #%d)",
seen_headers[header_key],
)
else:
seen_headers[header_key] = len(seen_headers) + 1
# Remove duplicates in reverse order to preserve offsets
for start, end in reversed(drop_spans):
body = body[:start] + body[end:]
return body
# ---------------------------------------------------------------------------
# Block-level conversion
# ---------------------------------------------------------------------------
# Patterns for block-level structures
# Display math delimited by \[ ... \] (may span multiple lines).
_DISPLAY_MATH_RE = re.compile(r"^\\\[(.+?)\\\]$", re.MULTILINE | re.DOTALL)
# $$...$$ display math (single- or multi-line)
_DISPLAY_MATH_DOLLAR_RE = re.compile(
    r"^\$\$\s*\n?(.*?)\n?\s*\$\$$", re.MULTILINE | re.DOTALL
)
# Fenced code block with optional language hint: ```lang ... ```
_FENCED_CODE_RE = re.compile(r"^```(\w*)\n(.*?)^```", re.MULTILINE | re.DOTALL)
# Markdown table separator row, e.g. |---|:--:|--:|
_TABLE_SEP_RE = re.compile(r"^\|[-:| ]+\|$")
# Markdown image pattern: ![alt](path)
_IMAGE_RE = re.compile(r"^!\[([^\]]*)\]\(([^)]+)\)\s*$")
# Bullet / numbered list patterns
_BULLET_RE = re.compile(r"^(\s*)-\s+(.+)")
_NUMBERED_RE = re.compile(r"^(\s*)\d+\.\s+(.+)")
def _convert_block(text: str) -> str:
    """Convert a block of Markdown body text to LaTeX.

    Display math, fenced code blocks, and pre-built LaTeX environments are
    stashed behind placeholder tokens first so the line-by-line conversion
    cannot mangle them, then restored verbatim when their placeholder line
    is reached.
    """
    # Protect display math from further processing
    math_blocks: list[str] = []

    def _stash_math(m: re.Match[str]) -> str:
        idx = len(math_blocks)
        math_blocks.append(m.group(0))  # Keep \\[...\\] as-is
        return f"%%MATH_BLOCK_{idx}%%"

    def _stash_dollar_math(m: re.Match[str]) -> str:
        """Convert $$...$$ to \\begin{equation}...\\end{equation}."""
        idx = len(math_blocks)
        inner = m.group(1).strip()
        math_blocks.append(
            f"\\begin{{equation}}\n{inner}\n\\end{{equation}}"
        )
        return f"%%MATH_BLOCK_{idx}%%"

    text = _DISPLAY_MATH_RE.sub(_stash_math, text)
    # Also handle $$...$$ display math
    text = _DISPLAY_MATH_DOLLAR_RE.sub(_stash_dollar_math, text)
    # Protect fenced code blocks
    code_blocks: list[str] = []

    def _stash_code(m: re.Match[str]) -> str:
        idx = len(code_blocks)
        lang = m.group(1) or ""
        code = m.group(2)
        # Rendered immediately; only the placeholder travels through the
        # rest of the conversion.
        code_blocks.append(_render_code_block(lang, code))
        return f"%%CODE_BLOCK_{idx}%%"

    text = _FENCED_CODE_RE.sub(_stash_code, text)
    # Protect raw LaTeX environments (table, figure, algorithm, etc.)
    # These appear when pre-built LaTeX (e.g. anti-fabrication result tables)
    # is embedded directly in the markdown. Without protection, their
    # contents go through _convert_inline which double-escapes {, }, _, &.
    latex_env_blocks: list[str] = []

    def _stash_latex_env(m: re.Match[str]) -> str:
        idx = len(latex_env_blocks)
        latex_env_blocks.append(m.group(0))
        return f"%%LATEX_ENV_{idx}%%"

    # Match \begin{env}...\end{env} for environments that should pass through.
    text = re.sub(
        r"\\begin\{(table|figure|tabular|algorithm|algorithmic|equation|align"
        r"|gather|multline|minipage|tikzpicture)\*?\}.*?"
        r"\\end\{\1\*?\}",
        _stash_latex_env,
        text,
        flags=re.DOTALL,
    )
    # Process line by line for lists, tables, and paragraphs
    lines = text.split("\n")
    output: list[str] = []
    i = 0
    while i < len(lines):
        line = lines[i]
        # Check for stashed blocks: restore the protected content verbatim.
        if line.strip().startswith("%%MATH_BLOCK_"):
            idx = int(re.search(r"\d+", line.strip()).group())  # type: ignore[union-attr]
            output.append(math_blocks[idx])
            i += 1
            continue
        if line.strip().startswith("%%CODE_BLOCK_"):
            idx = int(re.search(r"\d+", line.strip()).group())  # type: ignore[union-attr]
            output.append(code_blocks[idx])
            i += 1
            continue
        # Stashed LaTeX environments — pass through unchanged
        if line.strip().startswith("%%LATEX_ENV_"):
            idx = int(re.search(r"\d+", line.strip()).group())  # type: ignore[union-attr]
            output.append(latex_env_blocks[idx])
            i += 1
            continue
        # Bullet list
        if _BULLET_RE.match(line):
            items, i = _collect_list(lines, i, _BULLET_RE)
            output.append(_render_itemize(items))
            continue
        # Numbered list
        if _NUMBERED_RE.match(line):
            items, i = _collect_list(lines, i, _NUMBERED_RE)
            output.append(_render_enumerate(items))
            continue
        # Table detection (line starts with |)
        if (
            line.strip().startswith("|")
            and i + 1 < len(lines)
            and _TABLE_SEP_RE.match(lines[i + 1].strip())
        ):
            # Check if previous line is a table caption (e.g. **Table 1: ...**)
            table_caption = ""
            if output:
                prev = output[-1].strip()
                # Match bold caption: \textbf{Table N...} (already converted)
                # or raw markdown: **Table N: ...**
                cap_m = re.match(
                    r"(?:\\textbf\{|[*]{2})\s*Table\s+\d+[.:]?\s*(.*?)(?:\}|[*]{2})$",
                    prev,
                )
                if cap_m:
                    table_caption = f"Table {cap_m.group(1)}" if cap_m.group(1) else ""
                    if not table_caption:
                        table_caption = prev
                    output.pop()  # Remove caption line from output (now inside table)
            table_lines, i = _collect_table(lines, i)
            output.append(_render_table(table_lines, caption=table_caption))
            continue
        # Markdown image: ![alt](path)
        img_match = _IMAGE_RE.match(line.strip())
        if img_match:
            output.append(_render_figure(img_match.group(1), img_match.group(2)))
            i += 1
            continue
        # Regular paragraph line
        output.append(_convert_inline(line))
        i += 1
    return "\n".join(output)
# ---------------------------------------------------------------------------
# List handling
# ---------------------------------------------------------------------------
def _collect_list(
lines: list[str], start: int, pattern: re.Pattern[str]
) -> tuple[list[str], int]:
"""Collect consecutive list items matching *pattern*."""
items: list[str] = []
i = start
while i < len(lines):
m = pattern.match(lines[i])
if m:
items.append(m.group(2))
i += 1
elif lines[i].strip() == "":
# Blank line — might continue list or end it
if i + 1 < len(lines) and pattern.match(lines[i + 1]):
i += 1 # skip blank, continue
else:
break
elif lines[i].startswith(" ") or lines[i].startswith("\t"):
# Continuation of previous item
if items:
items[-1] += " " + lines[i].strip()
i += 1
else:
break
return items, i
def _render_itemize(items: list[str]) -> str:
    """Render list items as a LaTeX itemize environment."""
    rows = "\n".join(f" \\item {_convert_inline(entry)}" for entry in items)
    return "\\begin{itemize}\n" + rows + "\n\\end{itemize}"
def _render_enumerate(items: list[str]) -> str:
    """Render list items as a LaTeX enumerate environment."""
    rows = "\n".join(f" \\item {_convert_inline(entry)}" for entry in items)
    return "\\begin{enumerate}\n" + rows + "\n\\end{enumerate}"
# ---------------------------------------------------------------------------
# Table handling
# ---------------------------------------------------------------------------
def _collect_table(lines: list[str], start: int) -> tuple[list[str], int]:
"""Collect table lines (header + separator + body rows)."""
table: list[str] = []
i = start
while i < len(lines) and lines[i].strip().startswith("|"):
table.append(lines[i])
i += 1
return table, i
def _render_table(table_lines: list[str], caption: str = "") -> str:
    """Render a Markdown table as a LaTeX tabular inside a table environment.

    IMP-23: Auto-wraps in ``\\resizebox`` when columns > 5 or any cell
    text exceeds 25 characters, preventing overflow in conference formats.
    IMP-32: Generates descriptive captions from header columns when the
    caption is empty or just 'Table N'.
    """
    # A valid table needs at least a header row and a separator row.
    if len(table_lines) < 2:
        return ""
    header = _parse_table_row(table_lines[0])
    # Skip separator (line 1)
    body_rows = [_parse_table_row(line) for line in table_lines[2:] if line.strip()]
    ncols = len(header)
    # Determine alignment from separator
    alignments = _parse_alignments(table_lines[1], ncols)
    col_spec = "".join(alignments)
    table_num = _next_table_num()
    # IMP-23: Detect wide tables that need resizebox
    max_cell_len = max(
        (len(c) for row in [header] + body_rows for c in row),
        default=0,
    )
    needs_resize = ncols > 5 or max_cell_len > 25
    lines_out: list[str] = []
    lines_out.append("\\begin{table}[ht]")
    lines_out.append("\\centering")
    # Caption ABOVE table (standard academic convention)
    if caption:
        # Drop a leading "Table N:" prefix — LaTeX adds its own numbering.
        cap_text = re.sub(r"^Table\s+\d+[.:]\s*", "", caption).strip()
        if cap_text:
            lines_out.append(f"\\caption{{{_convert_inline(cap_text)}}}")
        else:
            # Caption was only "Table N" — synthesise a descriptive one.
            auto_cap = _auto_table_caption(header, table_num)
            lines_out.append(f"\\caption{{{auto_cap}}}")
    else:
        auto_cap = _auto_table_caption(header, table_num)
        lines_out.append(f"\\caption{{{auto_cap}}}")
    lines_out.append(f"\\label{{tab:{table_num}}}")
    if needs_resize:
        # BUG-109b fix: Use \columnwidth (works in both 1-col and 2-col layouts)
        # \textwidth in 2-column formats (ICML) is full page width, causing
        # floats wider than a column to be "lost" by LaTeX.
        lines_out.append("\\resizebox{\\columnwidth}{!}{%")
    lines_out.append(f"\\begin{{tabular}}{{{col_spec}}}")
    lines_out.append("\\toprule")
    lines_out.append(
        " & ".join(f"\\textbf{{{_convert_inline(c)}}}" for c in header) + " \\\\"
    )
    lines_out.append("\\midrule")
    for row in body_rows:
        # Pad row to match header length
        padded = row + [""] * (ncols - len(row))
        lines_out.append(
            " & ".join(_convert_inline(c) for c in padded[:ncols]) + " \\\\"
        )
    lines_out.append("\\bottomrule")
    lines_out.append("\\end{tabular}")
    if needs_resize:
        lines_out.append("}")  # close resizebox
    lines_out.append("\\end{table}")
    return "\n".join(lines_out)
def _auto_table_caption(header: list[str], table_num: int) -> str:
    """IMP-32: Generate a descriptive caption from table header columns.

    Falls back to a plain ``Table N`` caption when the header has fewer
    than two non-empty columns; otherwise the first column header is used
    to guess the table type (hyperparameters / ablation / method comparison).
    """
    if len(header) <= 1:
        return f"Table {table_num}"
    cols = [c.strip() for c in header if c.strip()]
    if len(cols) < 2:
        return f"Table {table_num}"
    col0 = cols[0].lower()
    # Mention at most four data columns in the generated caption.
    # (Slicing clamps automatically — no need for min(5, len(cols)).)
    rest = [_convert_inline(c) for c in cols[1:5]]
    # Detect common table types by first-column header keywords.
    _HP_HINTS = {"hyperparameter", "parameter", "param", "hp", "setting", "config"}
    _ABL_HINTS = {"component", "variant", "ablation", "configuration", "module"}
    _MODEL_HINTS = {"model", "method", "approach", "algorithm", "baseline"}
    if any(h in col0 for h in _HP_HINTS):
        # Fixed string — removed spurious f-prefix (no placeholders).
        return "Hyperparameter settings"
    if any(h in col0 for h in _ABL_HINTS):
        return f"Ablation study results across {', '.join(rest)}"
    if any(h in col0 for h in _MODEL_HINTS):
        return f"Performance comparison of different methods on {', '.join(rest)}"
    return f"Comparison of {_convert_inline(cols[0])} across {', '.join(rest)}"
def _parse_table_row(line: str) -> list[str]:
"""Parse ``| a | b | c |`` into ``['a', 'b', 'c']``."""
line = line.strip()
if line.startswith("|"):
line = line[1:]
if line.endswith("|"):
line = line[:-1]
return [cell.strip() for cell in line.split("|")]
def _parse_alignments(sep_line: str, ncols: int) -> list[str]:
    """Parse alignment indicators from separator line."""
    specs: list[str] = []
    for cell in _parse_table_row(sep_line):
        token = cell.strip()
        starts_colon = token.startswith(":")
        ends_colon = token.endswith(":")
        if starts_colon and ends_colon:
            specs.append("c")  # :---: → centered
        elif ends_colon:
            specs.append("r")  # ---: → right-aligned
        else:
            specs.append("l")  # default / :--- → left-aligned
    # Pad with left-aligned columns, then clamp to the expected width.
    while len(specs) < ncols:
        specs.append("l")
    return specs[:ncols]
# ---------------------------------------------------------------------------
# Code block rendering
# ---------------------------------------------------------------------------
# Plain-ASCII replacements for Unicode symbols inside verbatim/code blocks,
# where LaTeX math commands are unavailable (see _render_code_block).
_UNICODE_TO_ASCII: dict[str, str] = {
    "\u2190": "<-", "\u2192": "->", "\u21d0": "<=", "\u21d2": "=>",
    "\u2264": "<=", "\u2265": ">=", "\u2260": "!=", "\u2248": "~=",
    "\u2208": " in ", "\u2209": " not in ",
    "\u2200": "forall ", "\u2203": "exists ",
    "\u2207": "nabla", "\u221e": "inf", "\u00b1": "+/-",
    "\u00d7": "x", "\u00b7": "*", "\u2026": "...",
    "\u03b1": "alpha", "\u03b2": "beta", "\u03b3": "gamma",
    "\u03b4": "delta", "\u03b5": "epsilon", "\u03b6": "zeta",
    "\u03b7": "eta", "\u03b8": "theta", "\u03b9": "iota",
    "\u03ba": "kappa", "\u03bb": "lambda", "\u03bc": "mu",
    "\u03bd": "nu", "\u03be": "xi", "\u03c0": "pi",
    "\u03c1": "rho", "\u03c3": "sigma", "\u03c4": "tau",
    "\u03c5": "upsilon", "\u03c6": "phi", "\u03c7": "chi",
    "\u03c8": "psi", "\u03c9": "omega",
    "\u0394": "Delta", "\u0398": "Theta", "\u039b": "Lambda",
    "\u03a3": "Sigma", "\u03a6": "Phi", "\u03a8": "Psi",
    "\u03a9": "Omega",
    "\u2113": "ell", "\u2202": "d", "\u222b": "int",
}
# BUG-110: Unicode Greek → LaTeX math replacements for inline text.
# Used in _convert_inline() and _sanitize_latex_output().
_UNICODE_GREEK_TO_LATEX: dict[str, str] = {
    # Lowercase
    "\u03b1": "$\\alpha$", "\u03b2": "$\\beta$", "\u03b3": "$\\gamma$",
    "\u03b4": "$\\delta$", "\u03b5": "$\\epsilon$", "\u03b6": "$\\zeta$",
    "\u03b7": "$\\eta$", "\u03b8": "$\\theta$", "\u03b9": "$\\iota$",
    "\u03ba": "$\\kappa$", "\u03bb": "$\\lambda$", "\u03bc": "$\\mu$",
    "\u03bd": "$\\nu$", "\u03be": "$\\xi$", "\u03c0": "$\\pi$",
    "\u03c1": "$\\rho$", "\u03c3": "$\\sigma$", "\u03c4": "$\\tau$",
    "\u03c5": "$\\upsilon$", "\u03c6": "$\\phi$", "\u03c7": "$\\chi$",
    "\u03c8": "$\\psi$", "\u03c9": "$\\omega$",
    # Uppercase
    "\u0393": "$\\Gamma$", "\u0394": "$\\Delta$", "\u0398": "$\\Theta$",
    "\u039b": "$\\Lambda$", "\u039e": "$\\Xi$", "\u03a0": "$\\Pi$",
    "\u03a3": "$\\Sigma$", "\u03a6": "$\\Phi$", "\u03a8": "$\\Psi$",
    "\u03a9": "$\\Omega$",
    # Common math symbols not already handled
    "\u2200": "$\\forall$", "\u2203": "$\\exists$",
    "\u2207": "$\\nabla$", "\u2202": "$\\partial$",
    "\u2026": "\\ldots{}", "\u22c5": "$\\cdot$",
    "\u2113": "$\\ell$", "\u222b": "$\\int$",
    "\u2209": "$\\notin$",
    # Common symbols that cause null-byte corruption if not converted
    "\u00b1": "$\\pm$",  # ±
    "\u00d7": "$\\times$",  # ×
    "\u2248": "$\\approx$",  # ≈
    "\u2264": "$\\leq$",  # ≤
    "\u2265": "$\\geq$",  # ≥
    "\u2260": "$\\neq$",  # ≠
    "\u221e": "$\\infty$",  # ∞
    # Additional symbols found in Runs 49-52
    "\u2212": "$-$",  # − (minus sign, distinct from hyphen)
    "\u2282": "$\\subset$",  # ⊂
    "\u222a": "$\\cup$",  # ∪
    "\u211d": "$\\mathbb{R}$",  # ℝ
    "\u0302": "\\^{}",  # ̂ (combining circumflex)
    "\u0303": "\\~{}",  # ̃ (combining tilde — Run 61 pseudocode)
    "\u221d": "$\\propto$",  # ∝ (proportional to)
    "\u2208": "$\\in$",  # ∈
}
# Keywords whose presence (3+ hits) marks a fenced block as pseudocode
# rather than real source code (see _render_code_block / IMP-28).
_ALGO_KEYWORDS = re.compile(
    r"\b(Input|Output|Return|While|For|If|Else|Repeat|Until|Function|Procedure|Algorithm)\b",
    re.IGNORECASE,
)
def _escape_algo_line(line: str) -> str:
"""Escape LaTeX special characters in an algorithmic pseudocode line.
BUG-177: Raw pseudocode lines contain Python/math syntax that breaks
pdflatex: ``#`` (comment char), ``_`` (subscript), ``%`` (comment),
``&`` (alignment), ``{}``, ``~``, ``^``.
Strategy:
1. Convert ``# comment`` at end of line → ``\\COMMENT{comment}``
2. Protect existing LaTeX commands and math delimiters
3. Escape remaining special characters
"""
# Step 1: Convert Python-style end-of-line comments → \COMMENT{...}
# Match `# comment` that isn't at the start of the line (those are full-line comments)
_comment_match = re.search(r"(?<=\s)#\s*(.+)$", line)
comment_suffix = ""
if _comment_match:
comment_text = _comment_match.group(1).strip()
line = line[: _comment_match.start()].rstrip()
comment_suffix = f" \\COMMENT{{{comment_text}}}"
elif line.strip().startswith("#"):
# Full-line comment
comment_text = line.strip().lstrip("#").strip()
return f"\\COMMENT{{{comment_text}}}"
# Step 2: Protect existing LaTeX commands and math mode from escaping
protected: list[str] = []
def _protect(m: re.Match[str]) -> str:
idx = len(protected)
protected.append(m.group(0))
return f"\x00ALG{idx}\x00"
# Protect: \command{...}, $...$, \(...\)
line = re.sub(r"\\[a-zA-Z]+\{[^}]*\}", _protect, line)
line = re.sub(r"\$[^$]+\$", _protect, line)
line = re.sub(r"\\\(.+?\\\)", _protect, line)
# Step 3: Escape special characters
line = line.replace("&", "\\&")
line = line.replace("%", "\\%")
line = line.replace("#", "\\#")
line = line.replace("_", "\\_")
line = line.replace("{", "\\{")
line = line.replace("}", "\\}")
line = line.replace("~", "\\textasciitilde{}")
line = line.replace("^", "\\textasciicircum{}")
# Step 4: Restore protected regions
for idx, val in enumerate(protected):
line = line.replace(f"\x00ALG{idx}\x00", val)
return line + comment_suffix
def _render_code_block(lang: str, code: str) -> str:
    """Render a fenced code block as a LaTeX environment.

    IMP-28: Detects pseudocode blocks (language hint 'algorithm' /
    'pseudocode', or 3+ algorithm keywords) and renders them inside an
    ``algorithm`` + ``algorithmic`` environment instead of verbatim.
    Replaces Unicode characters (Greek letters, arrows, math symbols)
    with ASCII equivalents so pdflatex can compile the block.
    """
    import unicodedata

    text = code.rstrip("\n")
    for needle, latex_eq in _UNICODE_TO_ASCII.items():
        text = text.replace(needle, latex_eq)
    # Drop any remaining combining marks (tildes, hats, ...) — pdflatex
    # cannot handle them in verbatim/algorithmic material.
    text = "".join(ch for ch in text if not unicodedata.combining(ch))

    # IMP-28: decide whether this block should be typeset as pseudocode.
    hint = lang.lower().strip()
    treat_as_algo = hint in ("algorithm", "pseudocode", "algo")
    if not treat_as_algo:
        # Heuristic: ≥3 algorithm keywords → treat as pseudocode
        treat_as_algo = len(_ALGO_KEYWORDS.findall(text)) >= 3
    if not treat_as_algo:
        return f"\\begin{{verbatim}}\n{text}\n\\end{{verbatim}}"

    # A leading `// ...` line supplies the caption.
    algo_lines = text.split("\n")
    caption = "Algorithm"
    if algo_lines and algo_lines[0].strip().startswith("//"):
        caption = algo_lines[0].strip().lstrip("/ ").strip()
        algo_lines = algo_lines[1:]

    # Lines already written with algorithmic commands pass through;
    # everything else is escaped and wrapped in \STATE (BUG-177).
    algo_commands = ("\\STATE", "\\IF", "\\ELSE", "\\ELSIF", "\\ENDIF",
                     "\\FOR", "\\ENDFOR", "\\WHILE", "\\ENDWHILE",
                     "\\REPEAT", "\\UNTIL", "\\RETURN", "\\REQUIRE", "\\ENSURE")
    rendered: list[str] = []
    for raw in algo_lines:
        content = raw.strip()
        if not content:
            continue
        if content.startswith(algo_commands):
            rendered.append(content)
        else:
            rendered.append(f"\\STATE {_escape_algo_line(content)}")
    return (
        "\\begin{algorithm}[ht]\n"
        f"\\caption{{{_convert_inline(caption)}}}\n"
        "\\begin{algorithmic}[1]\n"
        + "\n".join(rendered) + "\n"
        "\\end{algorithmic}\n"
        "\\end{algorithm}"
    )
# ---------------------------------------------------------------------------
# Figure rendering
# ---------------------------------------------------------------------------
def _render_figure(caption: str, path: str) -> str:
    """Render a markdown image as a LaTeX figure environment."""
    number = _next_figure_num()
    # Spaces in paths break \includegraphics; underscores are fine.
    safe_path = path.replace(" ", "_")
    caption_tex = _convert_inline(caption) if caption else f"Figure {number}"
    # Derive a stable label from the caption, falling back to the number.
    label = re.sub(r"[^a-z0-9]+", "_", caption.lower()).strip("_")[:30]
    if not label:
        label = str(number)
    return (
        "\\begin{figure}[t]\n"
        "\\centering\n"
        f"\\includegraphics[width=0.95\\columnwidth]{{{safe_path}}}\n"
        f"\\caption{{{caption_tex}}}\n"
        f"\\label{{fig:{label}}}\n"
        "\\end{figure}"
    )
# ---------------------------------------------------------------------------
# Inline conversion
# ---------------------------------------------------------------------------
# Order matters: process bold before italic to avoid conflicts.
_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
# NOTE(review): the next line is garbled — a text-extraction artifact
# appears to have eaten everything between a `<` (lookbehind `(?<!...)`)
# and the next `>` (`-> str:`), fusing the italic regex, possibly other
# module-level regex definitions, and the `def _convert_inline(...)`
# signature into one line.  Preserved verbatim; restore from the original
# source before editing behavior.
_ITALIC_RE = re.compile(r"(? str:
    """Convert inline Markdown formatting to LaTeX.

    Preserves:
    - Inline math ``\\(...\\)`` and ``$...$``
    - ``\\cite{...}`` references
    - Display math markers (already handled at block level)
    """
    # Normalize Unicode punctuation to LaTeX equivalents
    text = text.replace("\u2014", "---")  # em-dash —
    text = text.replace("\u2013", "--")  # en-dash –
    text = text.replace("\u201c", "``")  # left double quote "
    text = text.replace("\u201d", "''")  # right double quote "
    text = text.replace("\u2018", "`")  # left single quote '
    text = text.replace("\u2019", "'")  # right single quote '
    text = text.replace("\u00b1", "$\\pm$")  # ±
    text = text.replace("\u2248", "$\\approx$")  # ≈
    text = text.replace("\u2264", "$\\leq$")  # ≤
    text = text.replace("\u2265", "$\\geq$")  # ≥
    text = text.replace("\u2192", "$\\rightarrow$")  # →
    text = text.replace("\u2190", "$\\leftarrow$")  # ←
    text = text.replace("\u00d7", "$\\times$")  # ×
    text = text.replace("\u2260", "$\\neq$")  # ≠
    text = text.replace("\u2208", "$\\in$")  # ∈
    text = text.replace("\u221e", "$\\infty$")  # ∞
    # BUG-110: Replace Unicode Greek letters with LaTeX math equivalents.
    # These appear when LLMs emit raw Unicode (e.g. "ε-greedy" instead of
    # "$\epsilon$-greedy") and cause fatal pdflatex errors.
    # NOTE(review): _UNICODE_GREEK_TO_LATEX is defined elsewhere in this
    # module (not visible in this chunk).
    for _uchar, _lcmd in _UNICODE_GREEK_TO_LATEX.items():
        if _uchar in text:
            text = text.replace(_uchar, _lcmd)
    # Protect math and cite from escaping
    protected: list[str] = []

    def _protect(m: re.Match[str]) -> str:
        idx = len(protected)
        protected.append(m.group(0))
        return f"\x00PROT{idx}\x00"

    # Protect inline math: \(...\) and $...$
    text = re.sub(r"\\\(.+?\\\)", _protect, text)
    # NOTE(review): the next line is garbled by the same extraction
    # artifact — it appears to fuse the `$...$`/`\cite` protection
    # patterns with the definition of `_convert_and_protect_link`; the
    # indented lines that follow belong to that lost nested function.
    text = re.sub(r"(? str:
        href = f"\\href{{{m.group(2)}}}{{{m.group(1)}}}"
        idx = len(protected)
        protected.append(href)
        return f"\x00PROT{idx}\x00"
    text = _LINK_RE.sub(_convert_and_protect_link, text)
    # Escape special LaTeX characters
    # NOTE(review): _LATEX_SPECIAL / _LATEX_TILDE / _LATEX_CARET /
    # _LATEX_DOLLAR / _INLINE_CODE_RE / _LINK_RE are module-level regexes
    # defined outside this chunk.
    text = _LATEX_SPECIAL.sub(r"\\\1", text)
    text = _LATEX_TILDE.sub(r"\\textasciitilde{}", text)
    text = _LATEX_CARET.sub(r"\\textasciicircum{}", text)
    text = _LATEX_DOLLAR.sub(r"\\$", text)
    # Convert bold **text** → \textbf{text}
    text = _BOLD_RE.sub(r"\\textbf{\1}", text)
    # Convert italic *text* → \textit{text}
    text = _ITALIC_RE.sub(r"\\textit{\1}", text)
    # Convert inline code `text` → \texttt{text}
    text = _INLINE_CODE_RE.sub(r"\\texttt{\1}", text)
    # Links and images were already converted+protected before escaping.
    # Fallback: convert any remaining [cite_key] patterns to \cite{key}
    # This catches citations that were not converted upstream.
    # BUG-32 fix: key pattern must also match author2017keyword style keys
    # (e.g., roijers2017multiobjective, abels2019dynamic)
    _CITE_KEY_PAT = r"[a-zA-Z][a-zA-Z0-9_-]*\d{4}[a-zA-Z0-9_]*"
    text = re.sub(
        rf"\[({_CITE_KEY_PAT}(?:\s*,\s*{_CITE_KEY_PAT})*)\]",
        r"\\cite{\1}",
        text,
    )
    # Restore protected segments in reverse order so that nested
    # markers (e.g. PROT0 inside PROT1's value) are resolved correctly.
    for idx in range(len(protected) - 1, -1, -1):
        text = text.replace(f"\x00PROT{idx}\x00", protected[idx])
    return text
# ---------------------------------------------------------------------------
# Completeness checking (R10-Fix5)
# ---------------------------------------------------------------------------
# Canonical section names a complete conference paper must contain.
# check_paper_completeness() warns about any that are missing.
_EXPECTED_SECTIONS = {
    "introduction",
    "related work",
    "method",
    "experiment",
    "result",
    "discussion",
    "conclusion",
}
# Maps common heading variants (lower-cased) onto their canonical
# _EXPECTED_SECTIONS entry, so e.g. "Methodology" satisfies "method".
_SECTION_ALIASES: dict[str, str] = {
    "methodology": "method",
    "methods": "method",
    "proposed method": "method",
    "approach": "method",
    "experiments": "experiment",
    "experimental setup": "experiment",
    "experimental results": "result",
    "results": "result",
    "results and discussion": "result",
    "results and analysis": "result",
    "discussion and results": "result",
    "conclusions": "conclusion",
    "conclusion and future work": "conclusion",
    "summary": "conclusion",
    "background": "related work",
    "literature review": "related work",
    "prior work": "related work",
}
def check_paper_completeness(sections: list[_Section]) -> list[str]:
    """Check whether a paper contains all expected sections.

    Runs a series of structural lint checks over the parsed sections and
    returns a list of human-readable warning strings; an empty list means
    the paper structure looks complete.  Checks, in order:

    1. presence of a title (an H1/H2 heading that is not a known section name)
    2. presence of all _EXPECTED_SECTIONS (via _SECTION_ALIASES, then a
       substring fallback)
    3. required extra sections — Limitations (NeurIPS/ICLR mandate, T2.5)
    4. abstract length and raw metric-key dumps (T1.5)
    5. truncation markers anywhere in the body
    6. total and per-section word counts (R10-Fix5)
    7. bullet-point density in prose sections

    NOTE(review): ``_Section`` is a project type; this function reads
    ``sec.level``, ``sec.heading``, ``sec.heading_lower`` and ``sec.body``
    — confirm its definition before relying on other attributes.
    """
    warnings: list[str] = []
    # Check for valid title — look for any H1/H2 heading that could be a title
    _has_title = any(
        sec.level in (1, 2) and sec.heading_lower not in ("abstract", "introduction",
            "related work", "method", "methods", "methodology", "experiments",
            "results", "discussion", "conclusion", "limitations", "references")
        for sec in sections
    )
    if not _has_title:
        warnings.append(
            "No valid title found in paper. The output may lack proper heading structure."
        )
    found_sections: set[str] = set()
    section_headings: list[str] = []
    for sec in sections:
        if sec.level in (1, 2) and sec.heading:
            heading_lower = sec.heading.strip().lower()
            section_headings.append(heading_lower)
            if heading_lower in _EXPECTED_SECTIONS:
                found_sections.add(heading_lower)
            elif heading_lower in _SECTION_ALIASES:
                found_sections.add(_SECTION_ALIASES[heading_lower])
            else:
                # Substring fallback: "Experiments on Atari" → "experiment".
                for expected in _EXPECTED_SECTIONS:
                    if expected in heading_lower:
                        found_sections.add(expected)
                        break
    missing = _EXPECTED_SECTIONS - found_sections
    if missing:
        warnings.append(
            f"Missing sections: {', '.join(sorted(missing))}. "
            f"Found: {', '.join(section_headings)}"
        )
    # T2.5: Check for required conference sections (NeurIPS/ICLR mandate Limitations)
    _required_extras = {"limitations"}
    _extra_aliases = {
        "limitation": "limitations",
        "limitations and future work": "limitations",
        "limitations and broader impact": "limitations",
    }
    found_extras: set[str] = set()
    for sec in sections:
        if sec.level in (1, 2) and sec.heading:
            hl = sec.heading.strip().lower()
            if hl in _required_extras:
                found_extras.add(hl)
            elif hl in _extra_aliases:
                found_extras.add(_extra_aliases[hl])
            elif "limitation" in hl:
                # Catch-all: any heading mentioning "limitation" counts.
                found_extras.add("limitations")
    missing_extras = _required_extras - found_extras
    if missing_extras:
        warnings.append(
            f"Missing required sections for NeurIPS/ICLR: "
            f"{', '.join(sorted(missing_extras))}."
        )
    # T1.5: Abstract length and quality checks
    abstract_text = ""
    for sec in sections:
        if sec.heading_lower == "abstract":
            abstract_text = sec.body
            break
    if abstract_text:
        word_count = len(abstract_text.split())
        if word_count > 300:
            warnings.append(
                f"Abstract is {word_count} words (conference limit: 150-250). "
                f"Must be shortened."
            )
        elif word_count < 150:
            warnings.append(
                f"Abstract is only {word_count} words (expected 150-250 for conferences)."
            )
        # Detect raw variable names / metric key dumps
        # (patterns like "train_loss/epoch_3 =" leaking into prose)
        raw_vars = re.findall(r"\b\w+_\w+/\w+(?:_\w+)*\s*=", abstract_text)
        if raw_vars:
            warnings.append(
                f"Abstract contains raw variable names: {raw_vars[:3]}. "
                f"Replace with human-readable descriptions."
            )
    # Detect truncation markers
    all_body = " ".join(sec.body for sec in sections)
    truncation_markers = [
        "further sections continue",
        "remaining sections unchanged",
        "sections continue unchanged",
        "content continues",
        "[to be continued]",
        "[remaining content]",
    ]
    for marker in truncation_markers:
        if marker in all_body.lower():
            warnings.append(
                f"Truncation marker detected: '{marker}'. "
                f"Paper content may be incomplete."
            )
    # Word count check
    total_words = sum(len(sec.body.split()) for sec in sections)
    if total_words < 2000:
        warnings.append(
            f"Paper body is only {total_words} words "
            f"(expected 5,000-6,500 for conference paper). "
            f"Content may be severely truncated."
        )
    # Per-section word count check (safety net during LaTeX conversion)
    # Local import to avoid a module-level import cycle — presumably;
    # TODO confirm against researchclaw.prompts.
    from researchclaw.prompts import SECTION_WORD_TARGETS, _SECTION_TARGET_ALIASES
    for sec in sections:
        if sec.level not in (1, 2) or not sec.heading:
            continue
        canon = sec.heading_lower
        if canon not in SECTION_WORD_TARGETS:
            canon = _SECTION_TARGET_ALIASES.get(sec.heading_lower, "")
        if not canon or canon not in SECTION_WORD_TARGETS:
            # Heading has no word-count target — skip silently.
            continue
        lo, hi = SECTION_WORD_TARGETS[canon]
        wc = len(sec.body.split())
        if wc < int(lo * 0.6):
            warnings.append(
                f"Section '{sec.heading}' is only {wc} words "
                f"(expected {lo}-{hi}). Content may be severely truncated."
            )
        elif wc > int(hi * 1.5):
            warnings.append(
                f"Section '{sec.heading}' is {wc} words "
                f"(expected {lo}-{hi}). Consider trimming."
            )
    # Bullet density check for body sections
    _bullet_re_cc = re.compile(r"^\s*[-*]\s+", re.MULTILINE)
    _numbered_re_cc = re.compile(r"^\s*\d+\.\s+", re.MULTILINE)
    # Sections where bullet lists are conventionally acceptable.
    _bullet_ok_sections = {"introduction", "limitations", "limitation", "abstract"}
    for sec in sections:
        if sec.level not in (1, 2) or not sec.heading:
            continue
        hl = sec.heading_lower
        if hl in _bullet_ok_sections:
            continue
        if not sec.body:
            continue
        total_lines = len([ln for ln in sec.body.splitlines() if ln.strip()])
        if total_lines < 4:
            # Too short for a meaningful density ratio.
            continue
        bullet_count = (
            len(_bullet_re_cc.findall(sec.body))
            + len(_numbered_re_cc.findall(sec.body))
        )
        density = bullet_count / total_lines
        if density > 0.30:
            warnings.append(
                f"Section '{sec.heading}' has high bullet-point density "
                f"({bullet_count}/{total_lines} lines = {density:.0%}). "
                f"Conference papers should use flowing prose."
            )
    return warnings
def _escape_latex(text: str) -> str:
    """Escape LaTeX special characters in plain text (titles, headings).
    Does NOT escape inside math delimiters or \\commands.
    """
    # Protect math first
    protected: list[str] = []

    def _protect(m: re.Match[str]) -> str:
        idx = len(protected)
        protected.append(m.group(0))
        return f"\x00PROT{idx}\x00"

    text = re.sub(r"\\\(.+?\\\)", _protect, text)
    # NOTE(review): the next line is garbled by extraction — it appears to
    # fuse the remainder of this _escape_latex, a FILE boundary of the
    # dump, the table module's imports/dataclasses (LatexTable,
    # VerifiedRegistry, ConditionResult), and the signature of
    # generate_latex_tables(registry, metric_name, metric_direction,
    # two_column).  Preserved verbatim; restore from the original sources
    # before editing.  The docstring and body below belong to
    # generate_latex_tables.
    text = re.sub(r"(? list[LatexTable]:
    """Generate LaTeX tables from a VerifiedRegistry.

    Parameters
    ----------
    registry:
        The verified registry built from experiment data.
    metric_name:
        Human-readable name for the primary metric column.
    metric_direction:
        ``"maximize"`` or ``"minimize"`` — determines which result is bolded.
    two_column:
        If True, use ``table*`` environment (for 2-column formats like ICML).

    Returns
    -------
    list[LatexTable]
        One or more tables. Usually just one main results table.
    """
    tables: list[LatexTable] = []
    # --- Main results table ---
    conditions = _get_reportable_conditions(registry)
    if not conditions:
        # Nothing finished cleanly — emit no tables rather than empty ones.
        logger.warning("No reportable conditions — skipping table generation")
        return tables
    main_table = _build_main_table(
        conditions,
        metric_name=metric_name,
        metric_direction=metric_direction,
        two_column=two_column,
    )
    tables.append(main_table)
    # --- Per-seed breakdown table (if seeds > 1 for any condition) ---
    has_multi_seed = any(c.n_seeds >= 2 for c in conditions)
    if has_multi_seed:
        seed_table = _build_per_seed_table(
            conditions,
            metric_name=metric_name,
            two_column=two_column,
        )
        tables.append(seed_table)
    return tables
def _get_reportable_conditions(registry: VerifiedRegistry) -> list[ConditionResult]:
    """Return conditions with at least 1 seed and a finite mean,
    sorted alphabetically by name for deterministic table order."""
    reportable = [
        cond
        for cond in registry.conditions.values()
        if cond.n_seeds >= 1 and cond.mean is not None and math.isfinite(cond.mean)
    ]
    return sorted(reportable, key=lambda c: c.name)
def _build_main_table(
    conditions: list[ConditionResult],
    *,
    metric_name: str,
    metric_direction: str,
    two_column: bool,
) -> LatexTable:
    """Build the main results table with mean ± std per condition.

    The best condition (per ``metric_direction``) is bolded; single-seed
    conditions are marked with a dagger explained in a table footnote.
    """
    checked: set[float] = set()
    # Row index to bold, chosen by metric direction (may be None).
    bold_row = _find_best(conditions, metric_direction)
    body_rows: list[str] = []
    for row_no, cond in enumerate(conditions):
        mean_text = _fmt(cond.mean)
        if cond.mean is not None:
            checked.add(round(cond.mean, 4))
        if cond.std is not None and cond.std > 0 and cond.n_seeds >= 2:
            cell = f"{mean_text} $\\pm$ {_fmt(cond.std)}"
            checked.add(round(cond.std, 4))
        elif cond.n_seeds == 1:
            # Single seed: dagger instead of a std column.
            cell = f"{mean_text}$^{{\\ddagger}}$"
        else:
            cell = mean_text
        if row_no == bold_row:
            cell = f"\\textbf{{{cell}}}"
        body_rows.append(
            f"{_escape_latex(cond.name)} & {cell} & {cond.n_seeds} \\\\"
        )
    env = "table*" if two_column else "table"
    footnotes: list[str] = []
    if any(c.n_seeds == 1 for c in conditions):
        footnotes.append(
            "$^{\\ddagger}$Single seed; no standard deviation available."
        )
    notes = "\n".join(footnotes)
    if notes:
        notes = f"\n\\vspace{{2pt}}\\par\\footnotesize {notes}\n"
    latex = (
        f"\\begin{{{env}}}[htbp]\n"
        f"\\centering\n"
        f"\\caption{{Experimental results. "
        f"{len(conditions)} conditions evaluated.}}\n"
        f"\\label{{tab:main_results}}\n"
        f"% AUTO-GENERATED FROM EXPERIMENT DATA — DO NOT MODIFY NUMBERS\n"
        f"\\begin{{tabular}}{{l c r}}\n"
        f"\\toprule\n"
        f"Method & {metric_name} & $n$ \\\\\n"
        f"\\midrule\n"
        + "\n".join(body_rows) + "\n"
        f"\\bottomrule\n"
        f"\\end{{tabular}}{notes}\n"
        f"\\end{{{env}}}"
    )
    return LatexTable(
        label="tab:main_results",
        caption=f"Experimental results. {len(conditions)} conditions evaluated.",
        latex_code=latex,
        verified_values=checked,
        n_conditions=len(conditions),
        n_total_seeds=sum(c.n_seeds for c in conditions),
    )
def _build_per_seed_table(
    conditions: list[ConditionResult],
    *,
    metric_name: str,
    two_column: bool,
) -> LatexTable:
    """Build the per-seed breakdown table: one column per seed plus a mean.

    Conditions with fewer seeds than the widest condition get ``---`` in
    the missing cells.
    """
    checked: set[float] = set()
    # Widest condition determines the number of seed columns.
    n_cols = max(c.n_seeds for c in conditions)
    header_cells = " & ".join(f"Seed {i}" for i in range(n_cols))
    spec = "l " + " ".join(["r"] * n_cols) + " r"
    body_rows: list[str] = []
    for cond in conditions:
        cells: list[str] = []
        for seed_no in range(n_cols):
            # NOTE: per_seed_values is presumably keyed by seed index —
            # missing/non-finite entries render as "---".
            value = cond.per_seed_values.get(seed_no)
            if value is not None and math.isfinite(value):
                cells.append(_fmt(value))
                checked.add(round(value, 4))
            else:
                cells.append("---")
        mean_cell = _fmt(cond.mean) if cond.mean is not None else "---"
        joined = " & ".join(cells)
        body_rows.append(f"{_escape_latex(cond.name)} & {joined} & {mean_cell} \\\\")
    env = "table*" if two_column else "table"
    latex = (
        f"\\begin{{{env}}}[htbp]\n"
        f"\\centering\n"
        f"\\caption{{Per-seed results breakdown.}}\n"
        f"\\label{{tab:per_seed}}\n"
        f"% AUTO-GENERATED FROM EXPERIMENT DATA — DO NOT MODIFY NUMBERS\n"
        f"\\begin{{tabular}}{{{spec}}}\n"
        f"\\toprule\n"
        f"Method & {header_cells} & Mean \\\\\n"
        f"\\midrule\n"
        + "\n".join(body_rows) + "\n"
        f"\\bottomrule\n"
        f"\\end{{tabular}}\n"
        f"\\end{{{env}}}"
    )
    return LatexTable(
        label="tab:per_seed",
        caption="Per-seed results breakdown.",
        latex_code=latex,
        verified_values=checked,
        n_conditions=len(conditions),
        n_total_seeds=sum(c.n_seeds for c in conditions),
    )
def build_condition_whitelist(registry: VerifiedRegistry) -> str:
    """Generate a human-readable condition whitelist for the LLM prompt.

    Example output::

        CONDITION WHITELIST (you may ONLY discuss these conditions):
        - DQN (3 seeds, mean=206.10)
        - DQN+Abstraction (3 seeds, mean=278.93)
        - DQN+RawCount (3 seeds, mean=180.80)
    """
    out = ["CONDITION WHITELIST (you may ONLY discuss these conditions):"]
    for cond in sorted(registry.conditions.values(), key=lambda c: c.name):
        # Skip conditions without any usable (finite) result.
        if cond.n_seeds == 0 or cond.mean is None or not math.isfinite(cond.mean):
            continue
        out.append(f"- {cond.name} ({cond.n_seeds} seed(s), mean={cond.mean:.4f})")
    if len(out) == 1:
        # Only the header — make the emptiness explicit for the LLM.
        out.append("- (no conditions completed)")
    return "\n".join(out)
def _find_best(conditions: list[ConditionResult], direction: str) -> int | None:
    """Return the index of the best condition, or None if the list is empty.

    Conditions with ``mean is None`` are skipped; if the current best has a
    None mean it is displaced by the first condition with a real mean.
    """
    if not conditions:
        return None
    best = 0
    for idx, cond in enumerate(conditions):
        if cond.mean is None:
            continue
        incumbent = conditions[best].mean
        if incumbent is None:
            best = idx
        elif direction == "maximize" and cond.mean > incumbent:
            best = idx
        elif direction == "minimize" and cond.mean < incumbent:
            best = idx
    return best
def _fmt(value: float | None) -> str:
"""Format a number for LaTeX tables with sig-fig-aware rounding."""
if value is None or not math.isfinite(value):
return "---"
# Sig-fig-aware formatting (same approach as BUG-83 fix)
av = abs(value)
if av >= 100:
return f"{value:.2f}"
elif av >= 1:
return f"{value:.4f}"
elif av >= 0.001:
return f"{value:.4f}"
elif av > 0:
# Very small values: use 2 significant figures
import decimal
d = decimal.Decimal(str(value)).normalize()
# Count leading zeros after decimal point
exp = d.adjusted()
sig_digits = max(2, -exp + 1)
return f"{value:.{sig_digits}f}"
else:
return "0.0000"
def _escape_latex(text: str) -> str:
"""Escape special LaTeX characters in condition names."""
# Backslash must be first to avoid double-escaping
replacements = [
("\\", "\\textbackslash{}"),
("&", "\\&"),
("%", "\\%"),
("#", "\\#"),
("_", "\\_"),
("$", "\\$"),
("{", "\\{"),
("}", "\\}"),
("~", "\\textasciitilde{}"),
("^", "\\textasciicircum{}"),
]
for old, new in replacements:
text = text.replace(old, new)
return text
================================================
FILE: researchclaw/templates/styles/iclr_2025/iclr2025_conference.bst
================================================
%% iclr2025_conference.bst — ICLR 2025 bibliography style
%% Symlink-equivalent to iclr2026_conference.bst (same format).
%% Bundled by AutoResearchClaw for offline compilation.
% Declared entry fields; `label` is the single per-entry string variable.
ENTRY
{ author title journal booktitle year volume number pages doi url note publisher address edition eprint archiveprefix primaryclass }
{}
{ label }
% output.state tracks punctuation context between successive output calls.
INTEGERS { output.state before.all mid.sentence after.sentence after.block }
FUNCTION {init.state.consts}
{ #0 'before.all := #1 'mid.sentence := #2 'after.sentence := #3 'after.block := }
STRINGS { s t }
% Core writer: emits pending text with ", ", ". \newblock " or ". "
% separators depending on output.state, then resets state to mid.sentence.
FUNCTION {output.nonnull}
{ 's :=
output.state mid.sentence =
{ ", " * write$ }
{ output.state after.block =
{ add.period$ write$ newline$ "\newblock " write$ }
{ output.state before.all = 'write$ { add.period$ " " * write$ } if$ }
if$
mid.sentence 'output.state :=
}
if$
s
}
% output: silently drop empty strings; output.check: same but warn when a
% required field is missing.
FUNCTION {output}
{ duplicate$ empty$ 'pop$ 'output.nonnull if$ }
FUNCTION {output.check}
{ 't := duplicate$ empty$ { pop$ "empty " t * " in " * cite$ * warning$ } 'output.nonnull if$ }
FUNCTION {fin.entry} { add.period$ write$ newline$ }
FUNCTION {new.block}
{ output.state before.all = 'skip$ { after.block 'output.state := } if$ }
% Boolean helpers operating on 0/1 integers on the stack.
FUNCTION {not} { { #0 } { #1 } if$ }
FUNCTION {and} { 'skip$ { pop$ #0 } if$ }
FUNCTION {or} { { pop$ #1 } 'skip$ if$ }
FUNCTION {field.or.null} { duplicate$ empty$ { pop$ "" } 'skip$ if$ }
FUNCTION {emphasize} { duplicate$ empty$ { pop$ "" } { "\emph{" swap$ * "}" * } if$ }
INTEGERS { nameptr namesleft numnames }
% Author list: "First von Last, Jr" names joined with commas / " and ";
% a literal "others" renders as " et~al.".
FUNCTION {format.names}
{ 's := #1 'nameptr := s num.names$ 'numnames := numnames 'namesleft :=
{ namesleft #0 > }
{ s nameptr "{ff~}{vv~}{ll}{, jj}" format.name$ 't :=
nameptr #1 > { namesleft #1 > { ", " * t * } { numnames #2 > { "," * } 'skip$ if$ t "others" = { " et~al." * } { " and " * t * } if$ } if$ } 't if$
nameptr #1 + 'nameptr := namesleft #1 - 'namesleft :=
}
while$
}
% Field formatters — empty fields collapse to "".
FUNCTION {format.authors} { author empty$ { "" } { author format.names } if$ }
FUNCTION {format.title} { title empty$ { "" } { title } if$ }
FUNCTION {format.btitle} { title emphasize }
FUNCTION {format.date} { year empty$ { "" } { year } if$ }
FUNCTION {format.bvolume} { volume empty$ { "" } { "volume " volume * } if$ }
FUNCTION {format.pages} { pages empty$ { "" } { "pp. " pages * } if$ }
FUNCTION {format.url} { url empty$ { "" } { "\url{" url * "}" * } if$ }
FUNCTION {output.bibitem}
{ newline$ "\bibitem{" write$ cite$ write$ "}" write$ newline$ "" before.all 'output.state := }
% Entry-type drivers.
FUNCTION {article}
{ output.bibitem format.authors "author" output.check new.block format.title "title" output.check new.block journal emphasize "journal" output.check format.bvolume output format.pages output format.date "year" output.check format.url output fin.entry }
FUNCTION {inproceedings}
{ output.bibitem format.authors "author" output.check new.block format.title "title" output.check new.block "In " booktitle emphasize * output format.pages output format.date "year" output.check format.url output fin.entry }
FUNCTION {conference} { inproceedings }
FUNCTION {book}
{ output.bibitem format.authors "author" output.check new.block format.btitle "title" output.check publisher output format.date "year" output.check fin.entry }
FUNCTION {misc}
{ output.bibitem format.authors output new.block format.title output new.block note output format.date output format.url output fin.entry }
% All remaining entry types fall back to the permissive `misc` driver.
FUNCTION {techreport} { misc }
FUNCTION {phdthesis} { misc }
FUNCTION {mastersthesis} { misc }
FUNCTION {unpublished} { misc }
FUNCTION {default.type} { misc }
READ
% Entries are sorted by lower-cased, purified cite key.
FUNCTION {sortify} { purify$ "l" change.case$ }
FUNCTION {presort} { cite$ 'label := label sortify " " * #1 entry.max$ substring$ 'sort.key$ := }
ITERATE {presort}
SORT
% Emit the preamble (if any) and the thebibliography wrapper; "99" is the
% widest-label placeholder for numeric labels.
FUNCTION {begin.bib} { preamble$ empty$ 'skip$ { preamble$ write$ newline$ } if$ "\begin{thebibliography}{99}" write$ newline$ }
FUNCTION {end.bib} { newline$ "\end{thebibliography}" write$ newline$ }
EXECUTE {begin.bib}
EXECUTE {init.state.consts}
ITERATE {call.type$}
EXECUTE {end.bib}
================================================
FILE: researchclaw/templates/styles/iclr_2025/iclr2025_conference.sty
================================================
% iclr2025_conference.sty — ICLR 2025 conference style file
% Based on the official ICLR submission template structure.
% Bundled by AutoResearchClaw for offline compilation.
% Official source: https://github.com/ICLR/Master-Template/raw/master/iclr2025.zip
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{iclr2025_conference}[2025/01/15 ICLR 2025 conference style]
% Option switches: `final` prints the author block; `preprint` is declared
% here but — NOTE(review) — never consulted below; only `final`
% de-anonymizes the title block.
\newif\if@iclr@final \@iclr@finalfalse
\newif\if@iclr@preprint \@iclr@preprintfalse
\DeclareOption{final}{\@iclr@finaltrue}
\DeclareOption{preprint}{\@iclr@preprinttrue}
\ProcessOptions\relax
% Page geometry: 5.5in x 9.0in text block with a 1in top margin.
\RequirePackage{geometry}
\geometry{textwidth=5.5in,textheight=9.0in,top=1.0in,headheight=12pt,headsep=25pt,footskip=30pt}
% Times body font, single spacing, classic indented paragraphs.
\RequirePackage{times}
\renewcommand{\baselinestretch}{1.0}
\setlength{\parskip}{0pt}
\setlength{\parindent}{1em}
% Section heading sizes and vertical spacing (large/normalsize bold).
\renewcommand{\section}{\@startsection{section}{1}{0mm}{-2.0ex plus -0.5ex minus -.2ex}{1.0ex plus .2ex}{\normalfont\large\bfseries}}
\renewcommand{\subsection}{\@startsection{subsection}{2}{0mm}{-1.5ex plus -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalfont\normalsize\bfseries}}
\renewcommand{\subsubsection}{\@startsection{subsubsection}{3}{0mm}{-1.0ex plus -0.5ex minus -.2ex}{0.5ex plus .2ex}{\normalfont\normalsize\bfseries}}
% Title block: authors only under the `final` option; otherwise the
% double-blind "Anonymous authors" banner is shown.
\def\@maketitle{%
\vbox to 0pt{}\vskip -0.5in
\begin{center}%
{\LARGE\bfseries \@title \par}\vskip 0.3in
\if@iclr@final
{\large\lineskip .5em\begin{tabular}[t]{c}\@author\end{tabular}\par}%
\else
{\large Anonymous authors\par}{\normalsize Paper under double-blind review\par}%
\fi
\vskip 0.3in
\end{center}\par\vskip 0.5em
}
% Abstract: centered bold heading above an indented quote block.
\renewenvironment{abstract}{\centerline{\large\bfseries Abstract}\vspace{0.5ex}\begin{quote}}{\par\end{quote}\vskip 1ex}
% Page style: centered folio in the footer, no rule under the header.
\RequirePackage{fancyhdr}
\pagestyle{fancy}\fancyhf{}
\fancyfoot[C]{\thepage}
\renewcommand{\headrulewidth}{0pt}
% Numeric, sorted-and-compressed citations via natbib.
\RequirePackage[numbers,sort&compress]{natbib}
\endinput
================================================
FILE: researchclaw/templates/styles/iclr_2026/iclr2026_conference.bst
================================================
%% iclr2026_conference.bst — ICLR 2026 bibliography style
%% Bundled by AutoResearchClaw for offline compilation.
%% This is a minimal numeric bibliography style compatible with natbib.
%% For full-fidelity formatting, download from https://github.com/ICLR/Master-Template
% Declared entry fields; `label` is the single per-entry string variable.
ENTRY
{ author
title
journal
booktitle
year
volume
number
pages
doi
url
note
publisher
address
edition
eprint
archiveprefix
primaryclass
}
{}
{ label }
% output.state tracks punctuation context between successive output calls.
INTEGERS { output.state before.all mid.sentence after.sentence after.block }
FUNCTION {init.state.consts}
{ #0 'before.all :=
#1 'mid.sentence :=
#2 'after.sentence :=
#3 'after.block :=
}
STRINGS { s t }
% Core writer: emits pending text with ", ", ". \newblock " or ". "
% separators depending on output.state, then resets state to mid.sentence.
FUNCTION {output.nonnull}
{ 's :=
output.state mid.sentence =
{ ", " * write$ }
{ output.state after.block =
{ add.period$ write$
newline$
"\newblock " write$
}
{ output.state before.all =
'write$
{ add.period$ " " * write$ }
if$
}
if$
mid.sentence 'output.state :=
}
if$
s
}
% output: silently drop empty strings.
FUNCTION {output}
{ duplicate$ empty$
'pop$
'output.nonnull
if$
}
% output.check: like output, but warn when a required field is empty.
FUNCTION {output.check}
{ 't :=
duplicate$ empty$
{ pop$ "empty " t * " in " * cite$ * warning$ }
'output.nonnull
if$
}
FUNCTION {fin.entry}
{ add.period$
write$
newline$
}
FUNCTION {new.block}
{ output.state before.all =
'skip$
{ after.block 'output.state := }
if$
}
% Boolean helpers operating on 0/1 integers on the stack.
FUNCTION {not}
{ { #0 }
{ #1 }
if$
}
FUNCTION {and}
{ 'skip$
{ pop$ #0 }
if$
}
FUNCTION {or}
{ { pop$ #1 }
'skip$
if$
}
FUNCTION {field.or.null}
{ duplicate$ empty$
{ pop$ "" }
'skip$
if$
}
FUNCTION {emphasize}
{ duplicate$ empty$
{ pop$ "" }
{ "\emph{" swap$ * "}" * }
if$
}
INTEGERS { nameptr namesleft numnames }
% Author list: "First von Last, Jr" names joined with commas / " and ";
% a literal "others" renders as " et~al.".
FUNCTION {format.names}
{ 's :=
#1 'nameptr :=
s num.names$ 'numnames :=
numnames 'namesleft :=
{ namesleft #0 > }
{ s nameptr "{ff~}{vv~}{ll}{, jj}" format.name$ 't :=
nameptr #1 >
{ namesleft #1 >
{ ", " * t * }
{ numnames #2 >
{ "," * }
'skip$
if$
t "others" =
{ " et~al." * }
{ " and " * t * }
if$
}
if$
}
't
if$
nameptr #1 + 'nameptr :=
namesleft #1 - 'namesleft :=
}
while$
}
% Field formatters — empty fields collapse to "".
FUNCTION {format.authors}
{ author empty$
{ "" }
{ author format.names }
if$
}
FUNCTION {format.title}
{ title empty$
{ "" }
{ title }
if$
}
FUNCTION {format.btitle}
{ title emphasize
}
FUNCTION {format.date}
{ year empty$
{ "" }
{ year }
if$
}
FUNCTION {format.bvolume}
{ volume empty$
{ "" }
{ "volume " volume * }
if$
}
FUNCTION {format.pages}
{ pages empty$
{ "" }
{ "pp. " pages * }
if$
}
FUNCTION {format.url}
{ url empty$
{ "" }
{ "\url{" url * "}" * }
if$
}
% Emit the \bibitem line and reset punctuation state for the new entry.
FUNCTION {output.bibitem}
{ newline$
"\bibitem{" write$
cite$ write$
"}" write$
newline$
""
before.all 'output.state :=
}
% Entry-type drivers.
FUNCTION {article}
{ output.bibitem
format.authors "author" output.check
new.block
format.title "title" output.check
new.block
journal emphasize "journal" output.check
format.bvolume output
format.pages output
format.date "year" output.check
format.url output
fin.entry
}
FUNCTION {inproceedings}
{ output.bibitem
format.authors "author" output.check
new.block
format.title "title" output.check
new.block
"In " booktitle emphasize * output
format.pages output
format.date "year" output.check
format.url output
fin.entry
}
FUNCTION {conference} { inproceedings }
FUNCTION {book}
{ output.bibitem
format.authors "author" output.check
new.block
format.btitle "title" output.check
publisher output
format.date "year" output.check
fin.entry
}
FUNCTION {misc}
{ output.bibitem
format.authors output
new.block
format.title output
new.block
note output
format.date output
format.url output
fin.entry
}
% All remaining entry types fall back to the permissive `misc` driver.
FUNCTION {techreport} { misc }
FUNCTION {phdthesis} { misc }
FUNCTION {mastersthesis} { misc }
FUNCTION {unpublished} { misc }
FUNCTION {default.type} { misc }
READ
% Entries are sorted by lower-cased, purified cite key.
FUNCTION {sortify}
{ purify$
"l" change.case$
}
FUNCTION {presort}
{ cite$ 'label :=
label sortify
" "
*
#1 entry.max$ substring$
'sort.key$ :=
}
ITERATE {presort}
SORT
% Emit the preamble (if any) and the thebibliography wrapper; "99" is the
% widest-label placeholder for numeric labels.
FUNCTION {begin.bib}
{ preamble$ empty$
'skip$
{ preamble$ write$ newline$ }
if$
"\begin{thebibliography}{99}" write$ newline$
}
FUNCTION {end.bib}
{ newline$
"\end{thebibliography}" write$ newline$
}
EXECUTE {begin.bib}
EXECUTE {init.state.consts}
ITERATE {call.type$}
EXECUTE {end.bib}
================================================
FILE: researchclaw/templates/styles/iclr_2026/iclr2026_conference.sty
================================================
% iclr2026_conference.sty — ICLR 2026 conference style file
% Based on the official ICLR submission template structure.
% Bundled by AutoResearchClaw for offline compilation.
% Official source: https://github.com/ICLR/Master-Template
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{iclr2026_conference}[2026/01/15 ICLR 2026 conference style]
% ── Options ──────────────────────────────────────────────────────────
% "final": de-anonymize (show author block, add the published-paper
% running header).  "preprint" is accepted for template compatibility;
% only the "final" flag changes layout in the code below.
\newif\if@iclr@final \@iclr@finalfalse
\newif\if@iclr@preprint \@iclr@preprintfalse
\DeclareOption{final}{\@iclr@finaltrue}
\DeclareOption{preprint}{\@iclr@preprinttrue}
\ProcessOptions\relax
% ── Page geometry ────────────────────────────────────────────────────
\RequirePackage{geometry}
\geometry{
textwidth=5.5in,
textheight=9.0in,
top=1.0in,
headheight=12pt,
headsep=25pt,
footskip=30pt,
}
% ── Fonts ────────────────────────────────────────────────────────────
\RequirePackage{times}
% ── Spacing ──────────────────────────────────────────────────────────
\renewcommand{\baselinestretch}{1.0}
\setlength{\parskip}{0pt}
\setlength{\parindent}{1em}
% ── Section formatting ───────────────────────────────────────────────
\renewcommand{\section}{\@startsection
{section}{1}{0mm}{-2.0ex plus -0.5ex minus -.2ex}%
{1.0ex plus .2ex}{\normalfont\large\bfseries}}
\renewcommand{\subsection}{\@startsection
{subsection}{2}{0mm}{-1.5ex plus -0.5ex minus -.2ex}%
{0.8ex plus .2ex}{\normalfont\normalsize\bfseries}}
\renewcommand{\subsubsection}{\@startsection
{subsubsection}{3}{0mm}{-1.0ex plus -0.5ex minus -.2ex}%
{0.5ex plus .2ex}{\normalfont\normalsize\bfseries}}
% ── Title formatting ────────────────────────────────────────────────
% Anonymous byline by default (double-blind review); the real author
% block is typeset only under the "final" option.
\def\@maketitle{%
\vbox to 0pt{}%
\vskip -0.5in
\begin{center}%
{\LARGE\bfseries \@title \par}%
\vskip 0.3in
\if@iclr@final
{\large
\lineskip .5em
\begin{tabular}[t]{c}%
\@author
\end{tabular}\par}%
\else
{\large Anonymous authors\par}%
{\normalsize Paper under double-blind review\par}%
\fi
\vskip 0.3in
\end{center}%
\par
\vskip 0.5em
}
% ── Abstract ─────────────────────────────────────────────────────────
\renewenvironment{abstract}{%
\centerline{\large\bfseries Abstract}%
\vspace{0.5ex}%
\begin{quote}%
}{%
\par
\end{quote}%
\vskip 1ex
}
% ── Headers ──────────────────────────────────────────────────────────
% Centered running header only in final mode; page number in the footer.
\RequirePackage{fancyhdr}
\pagestyle{fancy}
\fancyhf{}
\if@iclr@final
\fancyhead[C]{Published as a conference paper at ICLR 2026}
\else
\fancyhead[C]{}
\fi
\fancyfoot[C]{\thepage}
\renewcommand{\headrulewidth}{0pt}
% ── Natbib ───────────────────────────────────────────────────────────
% Numeric citations with compressed ranges (e.g. [1-3]).
\RequirePackage[numbers,sort&compress]{natbib}
\endinput
================================================
FILE: researchclaw/templates/styles/icml_2025/icml2025.bst
================================================
%% icml2025.bst — ICML 2025 bibliography style
%% Bundled by AutoResearchClaw for offline compilation.
%% Identical format to icml2026.bst.
%% Minimal numeric style: entries become \bibitem blocks sorted by cite
%% key (see presort below).
ENTRY
{ author title journal booktitle year volume number pages doi url note publisher address edition eprint archiveprefix primaryclass }
{}
{ label }
% output.state drives the punctuation inserted between output calls.
INTEGERS { output.state before.all mid.sentence after.sentence after.block }
FUNCTION {init.state.consts}
{ #0 'before.all := #1 'mid.sentence := #2 'after.sentence := #3 'after.block := }
STRINGS { s t }
% Write the pending string with context-dependent punctuation, then
% leave the new string pending on the stack.
FUNCTION {output.nonnull}
{ 's :=
output.state mid.sentence =
{ ", " * write$ }
{ output.state after.block =
{ add.period$ write$ newline$ "\newblock " write$ }
{ output.state before.all = 'write$ { add.period$ " " * write$ } if$ }
if$
mid.sentence 'output.state :=
}
if$
s
}
% output: like output.nonnull but silently drops empty strings.
FUNCTION {output}
{ duplicate$ empty$ 'pop$ 'output.nonnull if$ }
% output.check: same, but warns when a required field is empty.
FUNCTION {output.check}
{ 't := duplicate$ empty$ { pop$ "empty " t * " in " * cite$ * warning$ } 'output.nonnull if$ }
FUNCTION {fin.entry} { add.period$ write$ newline$ }
FUNCTION {new.block}
{ output.state before.all = 'skip$ { after.block 'output.state := } if$ }
% Boolean helpers over 0/1 stack values.
FUNCTION {not} { { #0 } { #1 } if$ }
FUNCTION {and} { 'skip$ { pop$ #0 } if$ }
FUNCTION {or} { { pop$ #1 } 'skip$ if$ }
FUNCTION {field.or.null} { duplicate$ empty$ { pop$ "" } 'skip$ if$ }
FUNCTION {emphasize} { duplicate$ empty$ { pop$ "" } { "\emph{" swap$ * "}" * } if$ }
INTEGERS { nameptr namesleft numnames }
% Author list as "First Last, ..., and Last"; a literal "others" name
% is rendered as " et~al.".
FUNCTION {format.names}
{ 's := #1 'nameptr := s num.names$ 'numnames := numnames 'namesleft :=
{ namesleft #0 > }
{ s nameptr "{ff~}{vv~}{ll}{, jj}" format.name$ 't :=
nameptr #1 > { namesleft #1 > { ", " * t * } { numnames #2 > { "," * } 'skip$ if$ t "others" = { " et~al." * } { " and " * t * } if$ } if$ } 't if$
nameptr #1 + 'nameptr := namesleft #1 - 'namesleft :=
}
while$
}
% Field formatters: each pushes "" when the field is missing.
FUNCTION {format.authors} { author empty$ { "" } { author format.names } if$ }
FUNCTION {format.title} { title empty$ { "" } { title } if$ }
FUNCTION {format.btitle} { title emphasize }
FUNCTION {format.date} { year empty$ { "" } { year } if$ }
FUNCTION {format.bvolume} { volume empty$ { "" } { "volume " volume * } if$ }
FUNCTION {format.pages} { pages empty$ { "" } { "pp. " pages * } if$ }
FUNCTION {format.url} { url empty$ { "" } { "\url{" url * "}" * } if$ }
% Start an entry: "\bibitem{<key>}" plus punctuation-state reset.
FUNCTION {output.bibitem}
{ newline$ "\bibitem{" write$ cite$ write$ "}" write$ newline$ "" before.all 'output.state := }
% Entry drivers, one per BibTeX entry type.
FUNCTION {article}
{ output.bibitem format.authors "author" output.check new.block format.title "title" output.check new.block journal emphasize "journal" output.check format.bvolume output format.pages output format.date "year" output.check format.url output fin.entry }
FUNCTION {inproceedings}
{ output.bibitem format.authors "author" output.check new.block format.title "title" output.check new.block "In " booktitle emphasize * output format.pages output format.date "year" output.check format.url output fin.entry }
FUNCTION {conference} { inproceedings }
FUNCTION {book}
{ output.bibitem format.authors "author" output.check new.block format.btitle "title" output.check publisher output format.date "year" output.check fin.entry }
FUNCTION {misc}
{ output.bibitem format.authors output new.block format.title output new.block note output format.date output format.url output fin.entry }
% Remaining types fall back to the permissive misc driver.
FUNCTION {techreport} { misc }
FUNCTION {phdthesis} { misc }
FUNCTION {mastersthesis} { misc }
FUNCTION {unpublished} { misc }
FUNCTION {default.type} { misc }
READ
% Sort by purified, lower-cased cite key.
FUNCTION {sortify} { purify$ "l" change.case$ }
FUNCTION {presort} { cite$ 'label := label sortify " " * #1 entry.max$ substring$ 'sort.key$ := }
ITERATE {presort}
SORT
% Emit any @preamble, then the thebibliography wrapper (widest label 99).
FUNCTION {begin.bib} { preamble$ empty$ 'skip$ { preamble$ write$ newline$ } if$ "\begin{thebibliography}{99}" write$ newline$ }
FUNCTION {end.bib} { newline$ "\end{thebibliography}" write$ newline$ }
EXECUTE {begin.bib}
EXECUTE {init.state.consts}
ITERATE {call.type$}
EXECUTE {end.bib}
================================================
FILE: researchclaw/templates/styles/icml_2025/icml2025.sty
================================================
% icml2025.sty — ICML 2025 style file
% Based on the official ICML submission template structure.
% Bundled by AutoResearchClaw for offline compilation.
% Official source: https://icml.cc/Conferences/2025/StyleAuthorInstructions
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{icml2025}[2025/01/15 ICML 2025 submission style]
% Options: "accepted" shows the author block and the proceedings running
% header; "preprint" shows authors but keeps the running-title header.
\newif\if@icml@accepted \@icml@acceptedfalse
\newif\if@icml@preprint \@icml@preprintfalse
\DeclareOption{accepted}{\@icml@acceptedtrue}
\DeclareOption{preprint}{\@icml@preprinttrue}
\ProcessOptions\relax
% Two-column ICML page layout, activated at package load time.
\RequirePackage{geometry}
\geometry{textwidth=6.875in,textheight=9.25in,columnsep=0.25in,top=0.75in,headheight=12pt,headsep=12pt,footskip=20pt}
\twocolumn
\RequirePackage{times}
\renewcommand{\baselinestretch}{1.0}
\setlength{\parskip}{0pt}
\setlength{\parindent}{1em}
\renewcommand{\section}{\@startsection{section}{1}{0mm}{-2.0ex plus -0.5ex minus -.2ex}{0.8ex plus .2ex}{\normalfont\large\bfseries}}
\renewcommand{\subsection}{\@startsection{subsection}{2}{0mm}{-1.5ex plus -0.5ex minus -.2ex}{0.5ex plus .2ex}{\normalfont\normalsize\bfseries}}
\renewcommand{\subsubsection}{\@startsection{subsubsection}{3}{0mm}{-1.0ex plus -0.5ex minus -.2ex}{0.3ex plus .2ex}{\normalfont\normalsize\bfseries}}
% ICML author/affiliation helper macros (superscript affiliation marks).
\newenvironment{icmlauthorlist}{\begin{center}\large}{\end{center}}
\newcommand{\icmlauthor}[2]{#1\textsuperscript{#2}}
\newcommand{\icmlaffiliation}[2]{\par\normalsize\textsuperscript{#1}#2}
\newcommand{\icmltitlerunning}[1]{\def\@icml@runningtitle{#1}}
\def\@icml@runningtitle{}
% Title spans both columns via the optional argument of \twocolumn.
% Authors appear under "accepted" or "preprint"; anonymous otherwise.
\def\@maketitle{%
\twocolumn[%
\vskip -0.3in
\begin{center}%
{\LARGE\bfseries \@title \par}\vskip 0.2in
\if@icml@accepted
{\large\lineskip .5em\begin{tabular}[t]{c}\@author\end{tabular}\par}%
\else\if@icml@preprint
{\large\lineskip .5em\begin{tabular}[t]{c}\@author\end{tabular}\par}%
\else
{\large Anonymous submission\par}%
\fi\fi
\vskip 0.2in
\end{center}%
]%
}
\renewenvironment{abstract}{\centerline{\bfseries Abstract}\vspace{0.5ex}\begin{quote}\small}{\par\end{quote}\vskip 1ex}
% Running header: proceedings line when accepted, running title otherwise.
\RequirePackage{fancyhdr}
\pagestyle{fancy}\fancyhf{}
\if@icml@accepted
\fancyhead[C]{\small Proceedings of the $42^{nd}$ International Conference on Machine Learning, 2025}
\else
\fancyhead[C]{\small\@icml@runningtitle}
\fi
\fancyfoot[C]{\thepage}
\renewcommand{\headrulewidth}{0pt}
\RequirePackage[numbers,sort&compress]{natbib}
\endinput
================================================
FILE: researchclaw/templates/styles/icml_2026/icml2026.bst
================================================
%% icml2026.bst — ICML 2026 bibliography style
%% Bundled by AutoResearchClaw for offline compilation.
%% Minimal numeric bibliography style compatible with natbib.
%% Entries become \bibitem blocks sorted by cite key (see presort).
ENTRY
{ author title journal booktitle year volume number pages doi url note publisher address edition eprint archiveprefix primaryclass }
{}
{ label }
% output.state drives the punctuation inserted between output calls.
INTEGERS { output.state before.all mid.sentence after.sentence after.block }
FUNCTION {init.state.consts}
{ #0 'before.all := #1 'mid.sentence := #2 'after.sentence := #3 'after.block := }
STRINGS { s t }
% Write the pending string with context-dependent punctuation, then
% leave the new string pending on the stack.
FUNCTION {output.nonnull}
{ 's :=
output.state mid.sentence =
{ ", " * write$ }
{ output.state after.block =
{ add.period$ write$ newline$ "\newblock " write$ }
{ output.state before.all = 'write$ { add.period$ " " * write$ } if$ }
if$
mid.sentence 'output.state :=
}
if$
s
}
% output: like output.nonnull but silently drops empty strings.
FUNCTION {output}
{ duplicate$ empty$ 'pop$ 'output.nonnull if$ }
% output.check: same, but warns when a required field is empty.
FUNCTION {output.check}
{ 't := duplicate$ empty$ { pop$ "empty " t * " in " * cite$ * warning$ } 'output.nonnull if$ }
FUNCTION {fin.entry} { add.period$ write$ newline$ }
FUNCTION {new.block}
{ output.state before.all = 'skip$ { after.block 'output.state := } if$ }
% Boolean helpers over 0/1 stack values.
FUNCTION {not} { { #0 } { #1 } if$ }
FUNCTION {and} { 'skip$ { pop$ #0 } if$ }
FUNCTION {or} { { pop$ #1 } 'skip$ if$ }
FUNCTION {field.or.null} { duplicate$ empty$ { pop$ "" } 'skip$ if$ }
FUNCTION {emphasize} { duplicate$ empty$ { pop$ "" } { "\emph{" swap$ * "}" * } if$ }
INTEGERS { nameptr namesleft numnames }
% Author list as "First Last, ..., and Last"; a literal "others" name
% is rendered as " et~al.".
FUNCTION {format.names}
{ 's := #1 'nameptr := s num.names$ 'numnames := numnames 'namesleft :=
{ namesleft #0 > }
{ s nameptr "{ff~}{vv~}{ll}{, jj}" format.name$ 't :=
nameptr #1 > { namesleft #1 > { ", " * t * } { numnames #2 > { "," * } 'skip$ if$ t "others" = { " et~al." * } { " and " * t * } if$ } if$ } 't if$
nameptr #1 + 'nameptr := namesleft #1 - 'namesleft :=
}
while$
}
% Field formatters: each pushes "" when the field is missing.
FUNCTION {format.authors} { author empty$ { "" } { author format.names } if$ }
FUNCTION {format.title} { title empty$ { "" } { title } if$ }
FUNCTION {format.btitle} { title emphasize }
FUNCTION {format.date} { year empty$ { "" } { year } if$ }
FUNCTION {format.bvolume} { volume empty$ { "" } { "volume " volume * } if$ }
FUNCTION {format.pages} { pages empty$ { "" } { "pp. " pages * } if$ }
FUNCTION {format.url} { url empty$ { "" } { "\url{" url * "}" * } if$ }
% Start an entry: "\bibitem{<key>}" plus punctuation-state reset.
FUNCTION {output.bibitem}
{ newline$ "\bibitem{" write$ cite$ write$ "}" write$ newline$ "" before.all 'output.state := }
% Entry drivers, one per BibTeX entry type.
FUNCTION {article}
{ output.bibitem format.authors "author" output.check new.block format.title "title" output.check new.block journal emphasize "journal" output.check format.bvolume output format.pages output format.date "year" output.check format.url output fin.entry }
FUNCTION {inproceedings}
{ output.bibitem format.authors "author" output.check new.block format.title "title" output.check new.block "In " booktitle emphasize * output format.pages output format.date "year" output.check format.url output fin.entry }
FUNCTION {conference} { inproceedings }
FUNCTION {book}
{ output.bibitem format.authors "author" output.check new.block format.btitle "title" output.check publisher output format.date "year" output.check fin.entry }
FUNCTION {misc}
{ output.bibitem format.authors output new.block format.title output new.block note output format.date output format.url output fin.entry }
% Remaining types fall back to the permissive misc driver.
FUNCTION {techreport} { misc }
FUNCTION {phdthesis} { misc }
FUNCTION {mastersthesis} { misc }
FUNCTION {unpublished} { misc }
FUNCTION {default.type} { misc }
READ
% Sort by purified, lower-cased cite key.
FUNCTION {sortify} { purify$ "l" change.case$ }
FUNCTION {presort} { cite$ 'label := label sortify " " * #1 entry.max$ substring$ 'sort.key$ := }
ITERATE {presort}
SORT
% Emit any @preamble, then the thebibliography wrapper (widest label 99).
FUNCTION {begin.bib} { preamble$ empty$ 'skip$ { preamble$ write$ newline$ } if$ "\begin{thebibliography}{99}" write$ newline$ }
FUNCTION {end.bib} { newline$ "\end{thebibliography}" write$ newline$ }
EXECUTE {begin.bib}
EXECUTE {init.state.consts}
ITERATE {call.type$}
EXECUTE {end.bib}
================================================
FILE: researchclaw/templates/styles/icml_2026/icml2026.sty
================================================
% icml2026.sty — ICML 2026 style file
% Based on the official ICML submission template structure.
% Bundled by AutoResearchClaw for offline compilation.
% Official source: https://icml.cc/Conferences/2026/AuthorInstructions
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{icml2026}[2026/01/15 ICML 2026 submission style]
% ── Options ──────────────────────────────────────────────────────────
% "accepted": author block + proceedings running header.
% "preprint": author block, but the running-title header is kept.
% Default (neither): anonymous submission byline.
\newif\if@icml@accepted \@icml@acceptedfalse
\newif\if@icml@preprint \@icml@preprintfalse
\DeclareOption{accepted}{\@icml@acceptedtrue}
\DeclareOption{preprint}{\@icml@preprinttrue}
\ProcessOptions\relax
% ── Page geometry (2-column) ─────────────────────────────────────────
\RequirePackage{geometry}
\geometry{
textwidth=6.875in,
textheight=9.25in,
columnsep=0.25in,
top=0.75in,
headheight=12pt,
headsep=12pt,
footskip=20pt,
}
% Two-column mode is activated at package load time.
\twocolumn
% ── Fonts ────────────────────────────────────────────────────────────
\RequirePackage{times}
% ── Spacing ──────────────────────────────────────────────────────────
\renewcommand{\baselinestretch}{1.0}
\setlength{\parskip}{0pt}
\setlength{\parindent}{1em}
% ── Section formatting ───────────────────────────────────────────────
\renewcommand{\section}{\@startsection
{section}{1}{0mm}{-2.0ex plus -0.5ex minus -.2ex}%
{0.8ex plus .2ex}{\normalfont\large\bfseries}}
\renewcommand{\subsection}{\@startsection
{subsection}{2}{0mm}{-1.5ex plus -0.5ex minus -.2ex}%
{0.5ex plus .2ex}{\normalfont\normalsize\bfseries}}
\renewcommand{\subsubsection}{\@startsection
{subsubsection}{3}{0mm}{-1.0ex plus -0.5ex minus -.2ex}%
{0.3ex plus .2ex}{\normalfont\normalsize\bfseries}}
% ── Title formatting ────────────────────────────────────────────────
% ICML-specific author macros (superscript affiliation marks).
\newenvironment{icmlauthorlist}{\begin{center}\large}{\end{center}}
\newcommand{\icmlauthor}[2]{#1\textsuperscript{#2}}
\newcommand{\icmlaffiliation}[2]{\par\normalsize\textsuperscript{#1}#2}
\newcommand{\icmltitlerunning}[1]{\def\@icml@runningtitle{#1}}
\def\@icml@runningtitle{}
% Title spans both columns via the optional argument of \twocolumn.
\def\@maketitle{%
\twocolumn[%
\vskip -0.3in
\begin{center}%
{\LARGE\bfseries \@title \par}%
\vskip 0.2in
\if@icml@accepted
{\large
\lineskip .5em
\begin{tabular}[t]{c}%
\@author
\end{tabular}\par}%
\else
\if@icml@preprint
{\large
\lineskip .5em
\begin{tabular}[t]{c}%
\@author
\end{tabular}\par}%
\else
{\large Anonymous submission\par}%
\fi
\fi
\vskip 0.2in
\end{center}%
]%
}
% ── Abstract ─────────────────────────────────────────────────────────
\renewenvironment{abstract}{%
\centerline{\bfseries Abstract}%
\vspace{0.5ex}%
\begin{quote}\small%
}{%
\par
\end{quote}%
\vskip 1ex
}
% ── Headers ──────────────────────────────────────────────────────────
% Proceedings line when accepted; otherwise the user-set running title.
\RequirePackage{fancyhdr}
\pagestyle{fancy}
\fancyhf{}
\if@icml@accepted
\fancyhead[C]{\small Proceedings of the $43^{rd}$ International Conference on Machine Learning, 2026}
\else
\fancyhead[C]{\small\@icml@runningtitle}
\fi
\fancyfoot[C]{\thepage}
\renewcommand{\headrulewidth}{0pt}
% ── Natbib ───────────────────────────────────────────────────────────
\RequirePackage[numbers,sort&compress]{natbib}
\endinput
================================================
FILE: researchclaw/templates/styles/neurips_2024/neurips_2024.sty
================================================
% neurips_2024.sty — NeurIPS 2024 style file
% Based on the official NeurIPS submission template structure.
% Bundled by AutoResearchClaw for offline compilation.
% Official source: https://media.neurips.cc/Conferences/NeurIPS2024/Styles.zip
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{neurips_2024}[2024/01/15 NeurIPS 2024 submission style]
% ── Options ──────────────────────────────────────────────────────────
% Preprint mode is the default; "final" switches to camera-ready layout
% (no preprint banner).  "nonatbib" skips loading natbib, for documents
% that use a different citation package.
\newif\if@neurips@preprint \@neurips@preprinttrue
\newif\if@neurips@final \@neurips@finalfalse
\newif\if@neurips@nonatbib \@neurips@nonatbibfalse
\DeclareOption{preprint}{\@neurips@preprinttrue\@neurips@finalfalse}
\DeclareOption{final}{\@neurips@finaltrue\@neurips@preprintfalse}
\DeclareOption{nonatbib}{\@neurips@nonatbibtrue}
\ProcessOptions\relax
% ── Page geometry ────────────────────────────────────────────────────
\RequirePackage{geometry}
\geometry{
textwidth=6.0in,
textheight=9.0in,
top=1.0in,
headheight=12pt,
headsep=25pt,
footskip=30pt,
}
% ── Fonts ────────────────────────────────────────────────────────────
\RequirePackage{times}
% ── Spacing ──────────────────────────────────────────────────────────
\renewcommand{\baselinestretch}{1.0}
\setlength{\parskip}{0pt}
\setlength{\parindent}{1em}
% ── Section formatting ───────────────────────────────────────────────
\renewcommand{\section}{\@startsection
{section}{1}{0mm}{-2.0ex plus -0.5ex minus -.2ex}%
{1.0ex plus .2ex}{\normalfont\large\bfseries}}
\renewcommand{\subsection}{\@startsection
{subsection}{2}{0mm}{-1.5ex plus -0.5ex minus -.2ex}%
{0.8ex plus .2ex}{\normalfont\normalsize\bfseries}}
\renewcommand{\subsubsection}{\@startsection
{subsubsection}{3}{0mm}{-1.0ex plus -0.5ex minus -.2ex}%
{0.5ex plus .2ex}{\normalfont\normalsize\bfseries}}
% ── Title formatting ────────────────────────────────────────────────
% The author block is always shown; preprint mode adds a banner above it.
\def\@maketitle{%
\vbox to 0pt{}%
\vskip -0.5in
\begin{center}%
{\LARGE\bfseries \@title \par}%
\vskip 0.3in
\if@neurips@preprint
{\large\textit{Preprint. Under review.}\par}%
\vskip 0.1in
\fi
{\large
\lineskip .5em
\begin{tabular}[t]{c}%
\@author
\end{tabular}\par}%
\vskip 0.3in
\end{center}%
\par
\vskip 0.5em
}
% ── Abstract ─────────────────────────────────────────────────────────
\renewenvironment{abstract}{%
\centerline{\large\bfseries Abstract}%
\vspace{0.5ex}%
\begin{quote}%
}{%
\par
\end{quote}%
\vskip 1ex
}
% ── Headers ──────────────────────────────────────────────────────────
% No running header; page number centered in the footer.
\RequirePackage{fancyhdr}
\pagestyle{fancy}
\fancyhf{}
\fancyhead[C]{}
\fancyfoot[C]{\thepage}
\renewcommand{\headrulewidth}{0pt}
% ── Natbib ───────────────────────────────────────────────────────────
\if@neurips@nonatbib\else
\RequirePackage[numbers,sort&compress]{natbib}
\fi
\endinput
================================================
FILE: researchclaw/templates/styles/neurips_2025/neurips_2025.sty
================================================
% neurips_2025.sty — NeurIPS 2025 style file
% Based on the official NeurIPS submission template structure.
% Bundled by AutoResearchClaw for offline compilation.
% Official source: https://media.neurips.cc/Conferences/NeurIPS2025/Styles.zip
\NeedsTeXFormat{LaTeX2e}
\ProvidesPackage{neurips_2025}[2025/01/15 NeurIPS 2025 submission style]
% ── Options ──────────────────────────────────────────────────────────
% Preprint mode is the default; "final" switches to camera-ready layout
% (no preprint banner).  "nonatbib" skips loading natbib, for documents
% that use a different citation package.
\newif\if@neurips@preprint \@neurips@preprinttrue
\newif\if@neurips@final \@neurips@finalfalse
\newif\if@neurips@nonatbib \@neurips@nonatbibfalse
\DeclareOption{preprint}{\@neurips@preprinttrue\@neurips@finalfalse}
\DeclareOption{final}{\@neurips@finaltrue\@neurips@preprintfalse}
\DeclareOption{nonatbib}{\@neurips@nonatbibtrue}
\ProcessOptions\relax
% ── Page geometry ────────────────────────────────────────────────────
\RequirePackage{geometry}
\geometry{
textwidth=6.0in,
textheight=9.0in,
top=1.0in,
headheight=12pt,
headsep=25pt,
footskip=30pt,
}
% ── Fonts ────────────────────────────────────────────────────────────
\RequirePackage{times}
% ── Spacing ──────────────────────────────────────────────────────────
\renewcommand{\baselinestretch}{1.0}
\setlength{\parskip}{0pt}
\setlength{\parindent}{1em}
% ── Section formatting ───────────────────────────────────────────────
\renewcommand{\section}{\@startsection
{section}{1}{0mm}{-2.0ex plus -0.5ex minus -.2ex}%
{1.0ex plus .2ex}{\normalfont\large\bfseries}}
\renewcommand{\subsection}{\@startsection
{subsection}{2}{0mm}{-1.5ex plus -0.5ex minus -.2ex}%
{0.8ex plus .2ex}{\normalfont\normalsize\bfseries}}
\renewcommand{\subsubsection}{\@startsection
{subsubsection}{3}{0mm}{-1.0ex plus -0.5ex minus -.2ex}%
{0.5ex plus .2ex}{\normalfont\normalsize\bfseries}}
% ── Title formatting ────────────────────────────────────────────────
% The author block is always shown; preprint mode adds a banner above it.
\def\@maketitle{%
\vbox to 0pt{}%
\vskip -0.5in
\begin{center}%
{\LARGE\bfseries \@title \par}%
\vskip 0.3in
\if@neurips@preprint
{\large\textit{Preprint. Under review.}\par}%
\vskip 0.1in
\fi
{\large
\lineskip .5em
\begin{tabular}[t]{c}%
\@author
\end{tabular}\par}%
\vskip 0.3in
\end{center}%
\par
\vskip 0.5em
}
% ── Abstract ─────────────────────────────────────────────────────────
\renewenvironment{abstract}{%
\centerline{\large\bfseries Abstract}%
\vspace{0.5ex}%
\begin{quote}%
}{%
\par
\end{quote}%
\vskip 1ex
}
% ── Headers ──────────────────────────────────────────────────────────
% No running header; page number centered in the footer.
\RequirePackage{fancyhdr}
\pagestyle{fancy}
\fancyhf{}
\fancyhead[C]{}
\fancyfoot[C]{\thepage}
\renewcommand{\headrulewidth}{0pt}
% ── Natbib ───────────────────────────────────────────────────────────
\if@neurips@nonatbib\else
\RequirePackage[numbers,sort&compress]{natbib}
\fi
% ── Hyperref-friendly ────────────────────────────────────────────────
% Applied at \begin{document}, only when the document loaded hyperref.
\AtBeginDocument{%
\@ifpackageloaded{hyperref}{%
\hypersetup{colorlinks=true,linkcolor=red,citecolor=green,urlcolor=blue}%
}{}%
}
\endinput
================================================
FILE: researchclaw/trends/__init__.py
================================================
"""Research trend tracking and automatic topic generation."""
from researchclaw.trends.daily_digest import DailyDigest
from researchclaw.trends.trend_analyzer import TrendAnalyzer
from researchclaw.trends.opportunity_finder import OpportunityFinder
from researchclaw.trends.auto_topic import AutoTopicGenerator
from researchclaw.trends.feeds import FeedManager
__all__ = [
"AutoTopicGenerator",
"DailyDigest",
"FeedManager",
"OpportunityFinder",
"TrendAnalyzer",
]
================================================
FILE: researchclaw/trends/auto_topic.py
================================================
"""Automatic research topic generation (ClawZero mode)."""
from __future__ import annotations
import logging
from typing import Any
from researchclaw.trends.opportunity_finder import OpportunityFinder
from researchclaw.trends.trend_analyzer import TrendAnalyzer
logger = logging.getLogger(__name__)
class AutoTopicGenerator:
    """Generate and rank candidate research topics automatically.

    Combines a trend analysis with discovered research opportunities,
    scores each candidate on novelty / feasibility / impact, and returns
    them ranked by overall score.
    """

    def __init__(
        self,
        trend_analyzer: TrendAnalyzer,
        opportunity_finder: OpportunityFinder,
        llm_client: Any = None,
    ):
        # Collaborators are stored as-is; the LLM client is optional and
        # currently unused by the scoring logic itself.
        self.trend_analyzer = trend_analyzer
        self.opportunity_finder = opportunity_finder
        self.llm = llm_client

    async def generate_candidates(
        self,
        domains: list[str],
        papers: list[dict[str, Any]] | None = None,
        count: int = 5,
    ) -> list[dict[str, Any]]:
        """Generate ranked candidate research topics.

        Args:
            domains: Research domains to search for opportunities in.
            papers: Optional paper corpus fed to the trend analyzer.
            count: Maximum number of candidates to return.

        Returns:
            Candidate dicts sorted by descending ``overall_score``.
        """
        # Trend analysis feeds both the opportunity search and scoring.
        analysis = self.trend_analyzer.analyze(papers or [])
        found = await self.opportunity_finder.find_opportunities(
            analysis, domains
        )

        # Score only the first `count` opportunities, then rank them.
        ranked: list[dict[str, Any]] = []
        for candidate in found[:count]:
            scores = self._score_candidate(candidate, analysis)
            ranked.append({
                "topic": candidate["topic"],
                "rationale": candidate.get("rationale", ""),
                "feasibility": candidate.get("feasibility", "medium"),
                "novelty_score": scores["novelty"],
                "feasibility_score": scores["feasibility"],
                "impact_score": scores["impact"],
                "overall_score": scores["overall"],
                "source": candidate.get("source", "unknown"),
            })
        ranked.sort(key=lambda item: item["overall_score"], reverse=True)
        return ranked[:count]

    async def auto_select(
        self,
        domains: list[str],
        papers: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any]:
        """Fully automatic topic selection (Zero-Touch mode).

        Returns the single best candidate, or a generic default topic
        when no trend data produced any candidates.
        """
        ranked = await self.generate_candidates(domains, papers, count=5)
        if ranked:
            return ranked[0]
        # No usable trend data: fall back to a generic default topic.
        return {
            "topic": f"Novel approaches in {domains[0] if domains else 'ML'}",
            "rationale": "Default topic (no trends data available)",
            "overall_score": 0.0,
            "source": "default",
        }

    @staticmethod
    def _score_candidate(
        opportunity: dict[str, Any],
        trend_analysis: dict[str, Any],
    ) -> dict[str, float]:
        """Score a candidate topic on novelty, feasibility, and impact.

        Returns a dict with keys ``novelty``, ``feasibility``, ``impact``
        and the weighted ``overall`` score, each rounded to 3 decimals.
        """
        levels = {"high": 0.9, "medium": 0.6, "low": 0.3}
        feasibility = levels.get(
            opportunity.get("feasibility", "medium"), 0.6
        )
        # Novelty drops as the topic shares words with more rising keywords.
        words = set(opportunity.get("topic", "").lower().split())
        overlap = sum(
            1
            for entry in trend_analysis.get("rising_keywords", [])
            if words & set(entry.get("keyword", "").lower().split())
        )
        novelty = max(0.3, 1.0 - overlap * 0.15)
        # Impact scales with recent paper volume, capped at 1.0;
        # neutral 0.5 when there is no paper count at all.
        n_papers = trend_analysis.get("paper_count", 0)
        impact = min(1.0, n_papers / 50) if n_papers > 0 else 0.5
        overall = round(
            0.4 * novelty + 0.3 * feasibility + 0.3 * impact, 3
        )
        return {
            "novelty": round(novelty, 3),
            "feasibility": round(feasibility, 3),
            "impact": round(impact, 3),
            "overall": overall,
        }

    def format_candidates(
        self,
        candidates: list[dict[str, Any]],
    ) -> str:
        """Format candidates as a readable string."""
        if not candidates:
            return "No candidate topics generated."
        out = ["Candidate Research Topics:", "=" * 40, ""]
        for idx, c in enumerate(candidates, 1):
            out.append(f"{idx}. {c['topic']}")
            out.append(
                f" Score: {c['overall_score']:.2f} "
                f"(novelty={c['novelty_score']:.2f}, "
                f"feasibility={c['feasibility_score']:.2f}, "
                f"impact={c['impact_score']:.2f})"
            )
            out.append(f" Rationale: {c.get('rationale', 'N/A')}")
            out.append("")
        return "\n".join(out)
================================================
FILE: researchclaw/trends/daily_digest.py
================================================
"""Daily paper digest generation."""
from __future__ import annotations
import logging
from datetime import date
from pathlib import Path
from typing import Any
from researchclaw.trends.feeds import FeedManager
logger = logging.getLogger(__name__)
class DailyDigest:
    """Generate daily paper digest reports.

    Papers come from the configured FeedManager.  With an LLM client,
    each paper gets a short summary and a 1-5 relevance rating; without
    one, a truncated abstract is shown instead.
    """

    def __init__(
        self,
        feed_manager: FeedManager,
        llm_client: Any = None,
    ):
        # llm_client is optional: generate() picks the rendering mode
        # based on whether it is present.
        self.feeds = feed_manager
        self.llm = llm_client

    async def generate(
        self,
        domains: list[str] | None = None,
        max_papers: int = 20,
        target_date: date | None = None,
    ) -> str:
        """Generate a daily paper digest as Markdown.

        Args:
            domains: Domains to search; defaults to ["machine learning"].
            max_papers: Maximum number of papers in the digest.
            target_date: Date stamped on the digest; defaults to today.

        Returns:
            Markdown text; a short notice when no papers were found.
        """
        effective_domains = domains or ["machine learning"]
        today = target_date or date.today()
        papers = self.feeds.fetch_recent_papers(
            domains=effective_domains,
            max_papers=max_papers,
            since_date=today,
        )
        if not papers:
            return (
                f"## Daily Paper Digest ({today})\n\n"
                f"No new papers found for domains: {', '.join(effective_domains)}\n"
            )
        if self.llm is not None:
            return await self._generate_with_llm(papers, effective_domains, today)
        return self._generate_basic(papers, effective_domains, today)

    @staticmethod
    def _format_authors(authors: Any) -> str:
        """Format an author list as "A, B, C et al.".

        Shared by the LLM and basic renderers (this logic was previously
        duplicated in both).  Accepts a list of names (strings or
        {"name": ...} dicts); any non-list value is stringified as-is.
        """
        if not isinstance(authors, list):
            return str(authors)
        author_str = ", ".join(
            a if isinstance(a, str) else a.get("name", "")
            for a in authors[:3]
        )
        if len(authors) > 3:
            author_str += " et al."
        return author_str

    @staticmethod
    def _header_lines(domains: list[str], today: date, count: int) -> list[str]:
        """Markdown header lines common to both digest renderers."""
        return [
            f"## Daily Paper Digest ({today})",
            f"Domains: {', '.join(domains)}",
            f"Papers found: {count}",
            "",
        ]

    async def _generate_with_llm(
        self,
        papers: list[dict[str, Any]],
        domains: list[str],
        today: date,
    ) -> str:
        """Generate digest with LLM-enhanced summaries."""
        lines = self._header_lines(domains, today, len(papers))
        for i, paper in enumerate(papers, 1):
            title = paper.get("title", "Untitled")
            url = paper.get("url", "")
            abstract = paper.get("abstract", "")[:500]
            author_str = self._format_authors(paper.get("authors", []))
            # Best-effort LLM call: any failure falls back to the raw
            # abstract and a neutral relevance of 3 (never crash the digest).
            try:
                prompt = (
                    f"Summarize this paper in 2 sentences and rate its relevance "
                    f"to {', '.join(domains)} on a scale of 1-5 stars.\n\n"
                    f"Title: {title}\nAbstract: {abstract}\n\n"
                    f"Format: SUMMARY: | RELEVANCE: <1-5>"
                )
                response = await self.llm.chat_async(prompt)
                summary, relevance = self._parse_summary(response)
            except Exception:
                summary = abstract[:200] + "..." if len(abstract) > 200 else abstract
                relevance = 3
            stars = "*" * relevance
            link = f"[{title}]({url})" if url else title
            lines.extend([
                f"### {i}. {link}",
                f"**Authors**: {author_str}",
                f"**Relevance**: {stars}",
                f"**Summary**: {summary}",
                "",
            ])
        return "\n".join(lines)

    def _generate_basic(
        self,
        papers: list[dict[str, Any]],
        domains: list[str],
        today: date,
    ) -> str:
        """Generate basic digest without LLM."""
        lines = self._header_lines(domains, today, len(papers))
        for i, paper in enumerate(papers, 1):
            title = paper.get("title", "Untitled")
            url = paper.get("url", "")
            abstract = paper.get("abstract", "")
            author_str = self._format_authors(paper.get("authors", []))
            short_abstract = (
                abstract[:200] + "..." if len(abstract) > 200 else abstract
            )
            link = f"[{title}]({url})" if url else title
            lines.extend([
                f"### {i}. {link}",
                f"**Authors**: {author_str}",
                f"**Abstract**: {short_abstract}",
                "",
            ])
        return "\n".join(lines)

    @staticmethod
    def _parse_summary(response: str) -> tuple[str, int]:
        """Parse an LLM response of the form "SUMMARY: ... | RELEVANCE: <n>".

        Returns (summary, relevance) with relevance clamped to 1-5.
        Responses without the expected markers yield the raw text with a
        neutral rating of 3.
        """
        summary = response
        relevance = 3
        if "SUMMARY:" in response:
            parts = response.split("|")
            summary = parts[0].split("SUMMARY:", 1)[-1].strip()
            if len(parts) > 1 and "RELEVANCE:" in parts[1]:
                try:
                    rel_str = parts[1].split("RELEVANCE:", 1)[-1].strip()
                    relevance = int(rel_str.strip("* "))
                    relevance = max(1, min(5, relevance))
                except (ValueError, IndexError):
                    pass  # malformed rating: keep the neutral default
        return summary, relevance

    async def generate_and_save(
        self,
        output_dir: Path,
        domains: list[str] | None = None,
        max_papers: int = 20,
    ) -> Path:
        """Generate today's digest and save it as ``digest_<date>.md``.

        Creates ``output_dir`` (and parents) if needed and returns the
        path of the written file.
        """
        today = date.today()
        content = await self.generate(domains, max_papers, today)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"digest_{today}.md"
        output_file.write_text(content, encoding="utf-8")
        return output_file
================================================
FILE: researchclaw/trends/feeds.py
================================================
"""ArXiv / Semantic Scholar / OpenAlex feed management."""
from __future__ import annotations
import logging
from datetime import date, datetime
from typing import Any
logger = logging.getLogger(__name__)
class FeedManager:
    """Manage paper feeds from multiple sources.

    Fetches recent papers from any of :data:`SUPPORTED_SOURCES` and merges
    them into one title-deduplicated list.
    """

    # Source names this manager knows how to query.
    SUPPORTED_SOURCES = ("arxiv", "semantic_scholar", "openalex")

    def __init__(
        self,
        sources: tuple[str, ...] = ("arxiv", "semantic_scholar"),
        s2_api_key: str = "",
    ):
        # Unknown source names are silently dropped.
        self.sources = tuple(
            name for name in sources if name in self.SUPPORTED_SOURCES
        )
        self.s2_api_key = s2_api_key

    def fetch_recent_papers(
        self,
        domains: list[str],
        max_papers: int = 20,
        since_date: date | None = None,
    ) -> list[dict[str, Any]]:
        """Fetch recent papers from configured sources.

        Returns a list of paper dicts with: title, authors, abstract,
        url, source, published_date (plus per-source extras).
        """
        target_date = since_date or date.today()
        fetchers = {
            "arxiv": self._fetch_arxiv,
            "semantic_scholar": self._fetch_s2,
            "openalex": self._fetch_openalex,
        }
        combined: list[dict[str, Any]] = []
        for source in self.sources:
            fetch = fetchers.get(source)
            if fetch is None:
                continue
            try:
                combined.extend(fetch(domains, max_papers, target_date))
            except Exception as exc:
                logger.warning("Feed fetch failed for %s: %s", source, exc)
        # Deduplicate on the lowercased, stripped title (exact match).
        seen_titles: set[str] = set()
        unique: list[dict[str, Any]] = []
        for paper in combined:
            key = paper.get("title", "").lower().strip()
            if key and key not in seen_titles:
                seen_titles.add(key)
                unique.append(paper)
        return unique[:max_papers]

    def _fetch_arxiv(
        self,
        domains: list[str],
        max_papers: int,
        since_date: date,
    ) -> list[dict[str, Any]]:
        """Fetch papers from the arXiv API; [] when unavailable or failing."""
        try:
            from researchclaw.literature.arxiv_client import search_arxiv
        except ImportError:
            logger.debug("arxiv_client not available")
            return []
        query = " OR ".join(domains) if domains else "machine learning"
        try:
            papers: list[dict[str, Any]] = []
            for entry in search_arxiv(query, limit=max_papers):
                papers.append({
                    "title": entry.get("title", ""),
                    "authors": entry.get("authors", []),
                    "abstract": entry.get("abstract", ""),
                    "url": entry.get("url", ""),
                    "source": "arxiv",
                    "published_date": entry.get("published", since_date.isoformat()),
                    "arxiv_id": entry.get("arxiv_id", ""),
                })
            return papers
        except Exception as exc:
            logger.warning("ArXiv fetch failed: %s", exc)
            return []

    def _fetch_s2(
        self,
        domains: list[str],
        max_papers: int,
        since_date: date,
    ) -> list[dict[str, Any]]:
        """Fetch papers from the Semantic Scholar API; [] on any failure."""
        try:
            from researchclaw.literature.semantic_scholar import search_s2
        except ImportError:
            logger.debug("semantic_scholar client not available")
            return []
        query = " ".join(domains) if domains else "machine learning"
        try:
            papers: list[dict[str, Any]] = []
            for entry in search_s2(
                query,
                limit=max_papers,
                year_min=since_date.year,
                api_key=self.s2_api_key,
            ):
                papers.append({
                    "title": entry.get("title", ""),
                    "authors": [
                        author.get("name", "") for author in entry.get("authors", [])
                    ],
                    "abstract": entry.get("abstract", ""),
                    "url": entry.get("url", ""),
                    "source": "semantic_scholar",
                    "published_date": str(entry.get("year", since_date.year)),
                    "citation_count": entry.get("citationCount", 0),
                })
            return papers
        except Exception as exc:
            logger.warning("S2 fetch failed: %s", exc)
            return []

    def _fetch_openalex(
        self,
        domains: list[str],
        max_papers: int,
        since_date: date,
    ) -> list[dict[str, Any]]:
        """Fetch papers from the OpenAlex API; [] on any failure."""
        try:
            from researchclaw.literature.openalex_client import search_openalex
        except ImportError:
            logger.debug("openalex_client not available")
            return []
        query = " ".join(domains) if domains else "machine learning"
        try:
            papers: list[dict[str, Any]] = []
            for entry in search_openalex(query, limit=max_papers):
                papers.append({
                    "title": entry.get("title", ""),
                    "authors": entry.get("authors", []),
                    "abstract": entry.get("abstract", ""),
                    "url": entry.get("url", ""),
                    "source": "openalex",
                    "published_date": entry.get("publication_date", ""),
                    "citation_count": entry.get("cited_by_count", 0),
                })
            return papers
        except Exception as exc:
            logger.warning("OpenAlex fetch failed: %s", exc)
            return []
================================================
FILE: researchclaw/trends/opportunity_finder.py
================================================
"""Research opportunity discovery."""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
class OpportunityFinder:
    """Identify research opportunities from trend analysis."""

    def __init__(self, llm_client: Any = None):
        # Optional LLM client; heuristics are used when it is absent.
        self.llm = llm_client

    async def find_opportunities(
        self,
        trend_analysis: dict[str, Any],
        domains: list[str],
    ) -> list[dict[str, Any]]:
        """Identify research gaps and opportunities."""
        if self.llm is None:
            return self._heuristic_find_opportunities(trend_analysis, domains)
        return await self._llm_find_opportunities(trend_analysis, domains)

    async def _llm_find_opportunities(
        self,
        trend_analysis: dict[str, Any],
        domains: list[str],
    ) -> list[dict[str, Any]]:
        """Ask the LLM for opportunities; fall back to heuristics on error."""
        top_keywords = trend_analysis.get("rising_keywords", [])[:10]
        top_methods = trend_analysis.get("method_trends", [])[:5]
        prompt = (
            "Based on the following research trends, identify 5 promising "
            "research opportunities:\n\n"
            f"Domains: {', '.join(domains)}\n"
            f"Trending keywords: {[k['keyword'] for k in top_keywords]}\n"
            f"Popular methods: {[m['method'] for m in top_methods]}\n\n"
            "For each opportunity, provide:\n"
            "1. A concise research question\n"
            "2. Why it's promising (1 sentence)\n"
            "3. Feasibility estimate (high/medium/low)\n\n"
            "Format each as: TOPIC: ... | WHY: ... | FEASIBILITY: ..."
        )
        try:
            reply = await self.llm.chat_async(prompt)
            return self._parse_opportunities(reply)
        except Exception as exc:
            logger.warning("LLM opportunity finding failed: %s", exc)
            return self._heuristic_find_opportunities(trend_analysis, domains)

    @staticmethod
    def _parse_opportunities(response: str) -> list[dict[str, Any]]:
        """Parse LLM response into structured opportunities (max 5)."""
        markers = ("TOPIC:", "topic:", "1.", "2.", "3.")
        parsed: list[dict[str, Any]] = []
        for raw_line in response.strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            if all(marker not in line for marker in markers):
                continue
            parts = line.split("|")
            topic = parts[0].split(":", 1)[-1].strip() if parts else line
            why = parts[1].split(":", 1)[-1].strip() if len(parts) > 1 else ""
            if len(parts) > 2:
                feasibility = parts[2].split(":", 1)[-1].strip().lower()
            else:
                feasibility = "medium"
            if topic:
                parsed.append({
                    "topic": topic,
                    "rationale": why,
                    "feasibility": feasibility,
                    "source": "llm",
                })
        return parsed[:5]

    @staticmethod
    def _heuristic_find_opportunities(
        trend_analysis: dict[str, Any],
        domains: list[str],
    ) -> list[dict[str, Any]]:
        """Cross top trending keywords with popular methods (max 5 items)."""
        found: list[dict[str, Any]] = []
        primary_domain = domains[0] if domains else "ML"
        for kw in trend_analysis.get("rising_keywords", [])[:3]:
            for method in trend_analysis.get("method_trends", [])[:2]:
                if len(found) >= 5:
                    return found
                found.append({
                    "topic": (
                        f"Applying {method['method']} to "
                        f"{kw['keyword']} in {primary_domain}"
                    ),
                    "rationale": (
                        f"'{kw['keyword']}' is trending ({kw['count']} mentions) "
                        f"and '{method['method']}' is a popular method"
                    ),
                    "feasibility": "medium",
                    "source": "heuristic",
                })
        return found
================================================
FILE: researchclaw/trends/trend_analyzer.py
================================================
"""Research trend analysis engine."""
from __future__ import annotations
import re
import logging
from collections import Counter
from typing import Any
logger = logging.getLogger(__name__)
# Common stopwords to exclude from keyword analysis
_STOPWORDS = frozenset({
    # Articles, conjunctions, prepositions, and auxiliary/modal verbs.
    "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
    "of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
    "being", "have", "has", "had", "do", "does", "did", "will", "would",
    "could", "should", "may", "might", "shall", "can", "need", "must",
    # Pronouns, determiners, and question words.
    "that", "this", "these", "those", "it", "its", "we", "our", "their",
    "which", "what", "how", "when", "where", "who", "whom", "why",
    "not", "no", "nor", "as", "if", "then", "than", "both", "each",
    "all", "any", "few", "more", "most", "some", "such", "only", "very",
    "also", "about", "up", "out", "so", "into", "over", "after", "before",
    "between", "under", "through", "during", "using", "based", "via",
    # Research-paper boilerplate too generic to signal a trend.
    "paper", "propose", "proposed", "method", "approach", "results", "show",
    "new", "novel", "model", "models", "data", "dataset", "task", "tasks",
    "performance", "learning", "training",
})
class TrendAnalyzer:
    """Analyze research trends from paper collections.

    Papers are plain dicts; this class reads the ``title``, ``abstract``,
    ``authors`` and ``source`` keys (all optional).
    """

    def __init__(self, min_keyword_length: int = 3):
        # Tokens shorter than this are ignored for keyword statistics.
        self.min_keyword_length = min_keyword_length

    def analyze(
        self,
        papers: list[dict[str, Any]],
        window_days: int = 30,
    ) -> dict[str, Any]:
        """Analyze trends in a collection of papers.

        ``window_days`` is currently unused and kept for interface
        stability. Returns a dict with ``rising_keywords``,
        ``hot_authors``, ``popular_datasets``, ``method_trends``,
        ``paper_count`` and (for non-empty input) ``source_distribution``.
        """
        if not papers:
            return {
                "rising_keywords": [],
                "hot_authors": [],
                "popular_datasets": [],
                "method_trends": [],
                "paper_count": 0,
            }
        keywords = self._extract_keywords(papers)
        authors = self._extract_authors(papers)
        datasets = self._extract_datasets(papers)
        methods = self._extract_methods(papers)
        return {
            "rising_keywords": keywords[:20],
            "hot_authors": authors[:10],
            "popular_datasets": datasets[:10],
            "method_trends": methods[:10],
            "paper_count": len(papers),
            "source_distribution": self._source_distribution(papers),
        }

    def _extract_keywords(
        self,
        papers: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Extract and rank keywords (unigrams + bigrams) from titles/abstracts."""
        word_counts: Counter[str] = Counter()
        bigram_counts: Counter[str] = Counter()
        min_len = self.min_keyword_length
        for paper in papers:
            text = f"{paper.get('title', '')} {paper.get('abstract', '')}"
            words = self._tokenize(text)
            for w in words:
                if w not in _STOPWORDS and len(w) >= min_len:
                    word_counts[w] += 1
            for w1, w2 in zip(words, words[1:]):
                # Fix: require BOTH bigram words to meet the length floor
                # (previously only w1 was checked).
                if (
                    w1 not in _STOPWORDS
                    and w2 not in _STOPWORDS
                    and len(w1) >= min_len
                    and len(w2) >= min_len
                ):
                    bigram_counts[f"{w1} {w2}"] += 1
        results = []
        for keyword, count in bigram_counts.most_common(30):
            if count >= 2:
                results.append({"keyword": keyword, "count": count, "type": "bigram"})
        for keyword, count in word_counts.most_common(30):
            if count >= 2:
                results.append({"keyword": keyword, "count": count, "type": "unigram"})
        # Stable sort keeps bigrams ahead of equal-count unigrams.
        results.sort(key=lambda x: -x["count"])
        return results[:20]

    def _extract_authors(
        self,
        papers: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Extract most prolific authors (appearing on 2+ papers)."""
        author_counts: Counter[str] = Counter()
        for paper in papers:
            authors = paper.get("authors", [])
            if isinstance(authors, list):
                for author in authors:
                    # Authors may be plain strings or {"name": ...} dicts.
                    name = author if isinstance(author, str) else author.get("name", "")
                    if name:
                        author_counts[name] += 1
        return [
            {"author": name, "paper_count": count}
            for name, count in author_counts.most_common(10)
            if count >= 2
        ]

    @staticmethod
    def _count_mentions(
        papers: list[dict[str, Any]],
        patterns: list[str],
    ) -> Counter[str]:
        """Count papers mentioning each pattern as a whole word.

        Uses word-boundary matching (with an optional plural "s") so that
        e.g. "ARC" no longer matches inside "search"/"architecture" and
        "GAN" no longer matches inside "organization" — the previous
        substring check produced such false positives. Each paper counts
        at most once per pattern.
        """
        compiled = [
            (name, re.compile(rf"\b{re.escape(name)}s?\b", re.IGNORECASE))
            for name in patterns
        ]
        counts: Counter[str] = Counter()
        for paper in papers:
            text = f"{paper.get('title', '')} {paper.get('abstract', '')}"
            for name, regex in compiled:
                if regex.search(text):
                    counts[name] += 1
        return counts

    def _extract_datasets(
        self,
        papers: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Extract commonly mentioned datasets."""
        dataset_patterns = [
            "ImageNet", "CIFAR", "MNIST", "COCO", "SQuAD", "GLUE",
            "SuperGLUE", "WikiText", "Penn Treebank", "WMT",
            "OpenWebText", "Common Crawl", "BookCorpus",
            "MMLU", "HumanEval", "GSM8K", "ARC", "HellaSwag",
        ]
        dataset_counts = self._count_mentions(papers, dataset_patterns)
        return [
            {"dataset": ds, "mention_count": count}
            for ds, count in dataset_counts.most_common(10)
            if count >= 1
        ]

    def _extract_methods(
        self,
        papers: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Extract commonly mentioned methods/architectures."""
        method_patterns = [
            "transformer", "attention", "diffusion", "GAN", "VAE",
            "reinforcement learning", "contrastive learning",
            "self-supervised", "few-shot", "zero-shot", "in-context",
            "fine-tuning", "pre-training", "RLHF", "DPO",
            "chain-of-thought", "retrieval-augmented", "RAG",
            "mixture of experts", "MoE", "LoRA", "quantization",
            "knowledge distillation", "pruning", "graph neural",
        ]
        method_counts = self._count_mentions(papers, method_patterns)
        return [
            {"method": method, "mention_count": count}
            for method, count in method_counts.most_common(10)
            if count >= 1
        ]

    @staticmethod
    def _source_distribution(
        papers: list[dict[str, Any]],
    ) -> dict[str, int]:
        """Count papers by source (missing source → "unknown")."""
        dist: Counter[str] = Counter()
        for paper in papers:
            dist[paper.get("source", "unknown")] += 1
        return dict(dist)

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Lowercased word tokenization; keeps internal hyphens/apostrophes."""
        return [w.lower() for w in re.findall(r"[a-zA-Z]+(?:[-'][a-zA-Z]+)*", text)]

    def generate_trend_report(
        self,
        analysis: dict[str, Any],
    ) -> str:
        """Format trend analysis as a readable plain-text report."""
        lines = [
            f"Research Trend Analysis ({analysis.get('paper_count', 0)} papers)",
            "=" * 50,
            "",
        ]
        keywords = analysis.get("rising_keywords", [])
        if keywords:
            lines.append("Top Keywords:")
            for kw in keywords[:10]:
                lines.append(f"  - {kw['keyword']} ({kw['count']} mentions)")
            lines.append("")
        authors = analysis.get("hot_authors", [])
        if authors:
            lines.append("Most Active Authors:")
            for a in authors[:5]:
                lines.append(f"  - {a['author']} ({a['paper_count']} papers)")
            lines.append("")
        methods = analysis.get("method_trends", [])
        if methods:
            lines.append("Method Trends:")
            for m in methods[:5]:
                lines.append(f"  - {m['method']} ({m['mention_count']} mentions)")
            lines.append("")
        return "\n".join(lines)
================================================
FILE: researchclaw/utils/__init__.py
================================================
"""ResearchClaw utility functions."""
from researchclaw.utils.sanitize import sanitize_figure_id
__all__ = ["sanitize_figure_id"]
================================================
FILE: researchclaw/utils/sanitize.py
================================================
"""Input sanitization utilities for untrusted LLM-generated values."""
from __future__ import annotations
import re
def sanitize_figure_id(raw_id: str, *, fallback: str = "figure") -> str:
    """Sanitize a figure ID for safe use in file paths and Docker names.

    Strips path separators, dotdot sequences, and shell metacharacters.
    Returns *fallback* if the sanitized result is empty.

    >>> sanitize_figure_id("../../etc/evil")
    'etc_evil'
    >>> sanitize_figure_id("fig test (v2)")
    'fig_test_v2'
    >>> sanitize_figure_id("")
    'figure'
    """
    # Drop ".." traversal sequences, then neutralize both separator styles.
    without_dots = raw_id.replace("..", "")
    without_seps = without_dots.replace("/", "_").replace("\\", "_")
    # Anything outside [A-Za-z0-9_.-] becomes an underscore.
    safe = re.sub(r"[^a-zA-Z0-9_.-]", "_", without_seps)
    # Squash underscore runs and trim leading/trailing "_" and ".".
    squashed = re.sub(r"_+", "_", safe).strip("_.")
    return squashed if squashed else fallback
================================================
FILE: researchclaw/utils/thinking_tags.py
================================================
"""Strip reasoning artifacts from LLM output before they leak into papers.
Handles ALL known thinking/reasoning formats:
- ``... `` -- DeepSeek-R1, QwQ, Gemini 2.5 format
- ``[thinking] ...`` -- Claude Code / ACP output format (bracket-style)
- Insight blocks -- Claude Code explanatory mode decorators
- ``[plan] ...`` -- Claude Code plan mode markers
- ``[tool] ...`` -- ACP tool invocation output
- ``[client] ...``, ``[acpx] ...``, ``[done] ...`` -- acpx metadata
Without this stripping, these artifacts contaminate:
- Paper drafts (LaTeX / Markdown)
- Generated experiment code
- YAML/JSON responses (search plans, experiment plans)
- Citation references
Usage::
from researchclaw.utils.thinking_tags import strip_thinking_tags
clean = strip_thinking_tags(raw_llm_output)
"""
from __future__ import annotations
import re
# ---------------------------------------------------------------------------
# Pattern 1: XML-style <think> ... </think> (DeepSeek-R1, QwQ, Gemini)
#
# NOTE: the literal tags were reconstructed from the surrounding comments —
# an earlier text-extraction pass had stripped the angle-bracket tags from
# these regexes, leaving patterns that deleted arbitrary text.
# ---------------------------------------------------------------------------
_THINK_BLOCK_RE = re.compile(
    r"<think>.*?</think>",
    re.DOTALL | re.IGNORECASE,
)
# Model stopped mid-reasoning: an opening tag with no close consumes the rest.
_THINK_UNCLOSED_RE = re.compile(
    r"<think>.*",
    re.DOTALL | re.IGNORECASE,
)
_THINK_STRAY_CLOSE_RE = re.compile(
    r"</think>",
    re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Pattern 2: [thinking] blocks (Claude Code / ACP)
# ---------------------------------------------------------------------------
_BRACKET_THINKING_RE = re.compile(
    r"\[thinking\].*?(?=\n\n(?!\[thinking\])|\n(?:#{1,3}\s)|\n```|\Z)",
    re.DOTALL | re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Pattern 3: Insight blocks (Claude Code explanatory style)
# ---------------------------------------------------------------------------
_INSIGHT_BLOCK_RE = re.compile(
    r"`[*\u2605]\s*Insight[^`]*`\s*\n.*?`[\u2500-]+`",
    re.DOTALL,
)
_INSIGHT_ASCII_RE = re.compile(
    r"`\*\s*Insight[-]+`\s*\n.*?`[-]+`",
    re.DOTALL,
)
# ---------------------------------------------------------------------------
# Pattern 4: [plan] blocks (Claude Code plan mode)
# ---------------------------------------------------------------------------
_PLAN_BLOCK_RE = re.compile(
    r"\[plan\].*?(?=\n\n|\Z)",
    re.DOTALL,
)
# ---------------------------------------------------------------------------
# Pattern 5: ACP/acpx metadata lines
# ---------------------------------------------------------------------------
_ACPX_LINE_RE = re.compile(
    r"^\[(client|acpx|tool|done)\](?!\().*$",
    re.MULTILINE | re.IGNORECASE,
)


def strip_thinking_tags(text: str) -> str:
    """Remove all reasoning artifacts from LLM output.

    Handles XML ``<think>`` tags, bracket ``[thinking]`` blocks, insight
    decorators, plan markers, and acpx metadata lines.
    Returns cleaned text suitable for paper drafts, code, or YAML/JSON.
    """
    if not text:
        return text
    result = text
    # Phase 1: XML <think> ... </think> blocks.
    # Cheap substring pre-filter covers opening, unclosed, and stray tags.
    if "think" in result.lower():
        result = _THINK_BLOCK_RE.sub("", result)
        result = _THINK_UNCLOSED_RE.sub("", result)
        result = _THINK_STRAY_CLOSE_RE.sub("", result)
    # Phase 2: [thinking] blocks (ACP/Claude Code)
    if "[thinking]" in result.lower():
        result = _BRACKET_THINKING_RE.sub("", result)
        # Catch single-line [thinking] remnants the block regex missed.
        result = re.sub(
            r"^\[thinking\].*$", "", result,
            flags=re.MULTILINE | re.IGNORECASE,
        )
    # Phase 3: Insight blocks
    result = _INSIGHT_BLOCK_RE.sub("", result)
    result = _INSIGHT_ASCII_RE.sub("", result)
    # Phase 4: [plan] blocks
    if "[plan]" in result.lower():
        result = _PLAN_BLOCK_RE.sub("", result)
    # Phase 5: acpx metadata lines
    result = _ACPX_LINE_RE.sub("", result)
    # Phase 6: Clean up leftover horizontal-rule decorations
    result = re.sub(r"^`[\u2500-]+`\s*$", "", result, flags=re.MULTILINE)
    result = re.sub(r"^`[-]{20,}`\s*$", "", result, flags=re.MULTILINE)
    # Collapse excessive blank lines
    result = re.sub(r"\n{3,}", "\n\n", result)
    return result.strip()
================================================
FILE: researchclaw/voice/__init__.py
================================================
"""Voice interaction modules."""
================================================
FILE: researchclaw/voice/commands.py
================================================
"""Voice command parsing."""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
class VoiceCommand(str, Enum):
    """Recognized voice commands."""

    START = "start"
    STOP = "stop"
    PAUSE = "pause"
    RESUME = "resume"
    STATUS = "status"
    NONE = "none"  # Not a command; forward the text to chat instead.


@dataclass
class ParsedVoiceInput:
    """Result of parsing voice input."""

    command: VoiceCommand  # recognized command, or NONE
    text: str  # stripped original transcription


# Ordered (command, pattern) pairs; the first match wins. Patterns cover
# English and Chinese trigger words and are anchored to the input start.
_COMMAND_PATTERNS: list[tuple[VoiceCommand, re.Pattern[str]]] = [
    (VoiceCommand.START, re.compile(r"^(?:start|run|开始|启动)", re.IGNORECASE)),
    (VoiceCommand.STOP, re.compile(r"^(?:stop|停止|结束|终止)", re.IGNORECASE)),
    (VoiceCommand.PAUSE, re.compile(r"^(?:pause|暂停|等一下)", re.IGNORECASE)),
    (VoiceCommand.RESUME, re.compile(r"^(?:resume|continue|继续|恢复)", re.IGNORECASE)),
    (VoiceCommand.STATUS, re.compile(r"^(?:status|progress|进度|到哪了|查看)", re.IGNORECASE)),
]


def parse_voice_input(text: str) -> ParsedVoiceInput:
    """Parse transcribed voice input into a command plus the stripped text."""
    cleaned = text.strip()
    matched = next(
        (cmd for cmd, pattern in _COMMAND_PATTERNS if pattern.search(cleaned)),
        VoiceCommand.NONE,
    )
    return ParsedVoiceInput(command=matched, text=cleaned)
================================================
FILE: researchclaw/voice/synthesizer.py
================================================
"""Text-to-speech synthesis."""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
class VoiceSynthesizer:
    """Convert text to speech audio."""

    def __init__(self, server_config: Any) -> None:
        # Server configuration object; not read by synthesize() itself,
        # kept for parity with VoiceTranscriber.
        self._config = server_config

    async def synthesize(
        self,
        text: str,
        voice: str = "alloy",
        speed: float = 1.0,
    ) -> bytes:
        """Synthesize text to audio bytes using OpenAI TTS API.

        Parameters
        ----------
        text:
            Text to render as speech.
        voice:
            OpenAI TTS voice name.
        speed:
            Playback speed multiplier passed through to the API.

        Raises RuntimeError when httpx is missing or OPENAI_API_KEY is
        unset; httpx raises on a non-2xx API response.
        """
        try:
            import httpx
        except ImportError:
            raise RuntimeError("httpx required for TTS")
        import os
        # API key comes from the environment, not from server_config.
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY not set for TTS")
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                "https://api.openai.com/v1/audio/speech",
                headers={"Authorization": f"Bearer {api_key}"},
                json={
                    "model": "tts-1",
                    "input": text,
                    "voice": voice,
                    "speed": speed,
                },
            )
            response.raise_for_status()
            # Raw audio bytes from the non-streaming endpoint.
            return response.content
================================================
FILE: researchclaw/voice/transcriber.py
================================================
"""Voice transcription via Whisper API."""
from __future__ import annotations
import logging
from typing import Any, AsyncIterator
logger = logging.getLogger(__name__)
class VoiceTranscriber:
    """Transcribe audio to text using Whisper API."""

    def __init__(self, server_config: Any) -> None:
        # Whisper model name and optional custom endpoint from server config.
        self._model = server_config.whisper_model
        self._api_url = server_config.whisper_api_url

    async def transcribe(
        self,
        audio_bytes: bytes,
        language: str = "zh",
    ) -> str:
        """Transcribe audio bytes to text.

        Uses OpenAI Whisper API or compatible endpoint.

        Raises RuntimeError when httpx is missing or OPENAI_API_KEY is
        unset; httpx raises on a non-2xx API response. Returns "" when
        the API reply carries no "text" field.
        """
        try:
            import httpx
        except ImportError:
            raise RuntimeError(
                "httpx is required for voice transcription. "
                "Install with: pip install httpx"
            )
        # Fall back to the public OpenAI endpoint when no custom URL is set.
        url = self._api_url or "https://api.openai.com/v1/audio/transcriptions"
        import os
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY not set for Whisper API")
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                url,
                headers={"Authorization": f"Bearer {api_key}"},
                # NOTE(review): payload is always labeled webm — confirm
                # callers only send webm audio.
                files={"file": ("audio.webm", audio_bytes, "audio/webm")},
                data={
                    "model": self._model,
                    "language": language,
                },
            )
            response.raise_for_status()
            result = response.json()
            return result.get("text", "")

    async def transcribe_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        language: str = "zh",
    ) -> AsyncIterator[str]:
        """Stream transcription (collects chunks then transcribes).

        Not incrementally streamed: buffers the entire stream in memory,
        then yields a single transcript. Yields nothing for an empty
        stream.
        """
        chunks: list[bytes] = []
        async for chunk in audio_stream:
            chunks.append(chunk)
        if chunks:
            full_audio = b"".join(chunks)
            text = await self.transcribe(full_audio, language=language)
            yield text
================================================
FILE: researchclaw/web/__init__.py
================================================
"""Web search, crawling, and content extraction layer.
Provides unified access to:
- **Crawl4AI**: Web page → Markdown extraction
- **Tavily**: AI-native web search API
- **scholarly**: Google Scholar search
- **PDF extraction**: Full-text from PDF files
Public API
----------
- ``WebSearchAgent`` — orchestrates all web capabilities
- ``WebCrawler`` — Crawl4AI wrapper
- ``WebSearchClient`` — Tavily search wrapper
- ``GoogleScholarClient`` — scholarly wrapper
- ``PDFExtractor`` — PDF text extraction
- ``check_url_ssrf`` — SSRF validation for URLs
"""
from researchclaw.web._ssrf import check_url_ssrf
from researchclaw.web.crawler import WebCrawler
from researchclaw.web.search import WebSearchClient
from researchclaw.web.scholar import GoogleScholarClient
from researchclaw.web.pdf_extractor import PDFExtractor
from researchclaw.web.agent import WebSearchAgent
__all__ = [
"check_url_ssrf",
"WebCrawler",
"WebSearchClient",
"GoogleScholarClient",
"PDFExtractor",
"WebSearchAgent",
]
================================================
FILE: researchclaw/web/_ssrf.py
================================================
"""SSRF validation for URLs fetched by the web layer."""
from __future__ import annotations
import ipaddress
import socket
from urllib.parse import urlparse
def check_url_ssrf(url: str) -> str | None:
    """Return an error message if *url* targets a private/internal host.

    Validates the scheme (http/https only), then checks EVERY address the
    hostname resolves to — not just the first — against loopback, private
    (RFC 1918), link-local, reserved, multicast, and unspecified ranges
    via :mod:`ipaddress`. IPv4-mapped IPv6 literals (``::ffff:a.b.c.d``)
    are unwrapped first so they cannot smuggle a private IPv4 target past
    the check on Python versions where ``is_private`` ignores the mapping.

    Returns ``None`` if the URL is safe to fetch.
    """
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        return f"Unsupported URL scheme: {parsed.scheme}"
    hostname = parsed.hostname or ""
    if not hostname:
        return "URL has no hostname"
    # Try parsing hostname as a literal IP address first.
    try:
        addresses = [ipaddress.ip_address(hostname)]
    except ValueError:
        # Domain name — resolve ALL its addresses (checking only the first
        # record would let a multi-record domain slip an internal IP through).
        try:
            info = socket.getaddrinfo(
                hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
            )
            addresses = [ipaddress.ip_address(entry[4][0]) for entry in info]
        except (socket.gaierror, OSError, ValueError, IndexError):
            # Can't resolve — let the actual request fail naturally.
            return None
    for addr in addresses:
        # Unwrap ::ffff:a.b.c.d to its IPv4 form before classification.
        if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
            addr = addr.ipv4_mapped
        if (
            addr.is_private
            or addr.is_loopback
            or addr.is_link_local
            or addr.is_reserved
            or addr.is_multicast
            or addr.is_unspecified
        ):
            return f"Blocked internal/private URL: {hostname}"
    return None
================================================
FILE: researchclaw/web/agent.py
================================================
"""Unified Web Search Agent.
Orchestrates all web capabilities (Tavily, Google Scholar, Crawl4AI,
PDF extraction) into a single search-and-extract pipeline.
Usage::
agent = WebSearchAgent()
result = agent.search_and_extract(
topic="knowledge distillation for vision transformers",
search_queries=["knowledge distillation survey", "ViT compression"],
)
# result.papers — Google Scholar papers
# result.web_results — Tavily/DDG web search results
# result.crawled_pages — full-text from crawled URLs
"""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from researchclaw.web.crawler import CrawlResult, WebCrawler
from researchclaw.web.pdf_extractor import PDFContent, PDFExtractor
from researchclaw.web.scholar import GoogleScholarClient, ScholarPaper
from researchclaw.web.search import SearchResult, WebSearchClient, WebSearchResponse
logger = logging.getLogger(__name__)
@dataclass
class WebSearchAgentResult:
    """Combined result from all web search sources."""

    topic: str
    web_results: list[SearchResult] = field(default_factory=list)
    scholar_papers: list[ScholarPaper] = field(default_factory=list)
    crawled_pages: list[CrawlResult] = field(default_factory=list)
    pdf_extractions: list[PDFContent] = field(default_factory=list)
    search_answer: str = ""  # Tavily AI answer if available
    elapsed_seconds: float = 0.0

    @property
    def total_results(self) -> int:
        """Total number of items collected across all four sources."""
        buckets = (
            self.web_results,
            self.scholar_papers,
            self.crawled_pages,
            self.pdf_extractions,
        )
        return sum(len(bucket) for bucket in buckets)

    def to_context_string(self, *, max_length: int = 30_000) -> str:
        """Convert all results to a single context string for LLM injection.

        The output is structured Markdown suitable for prompt injection;
        anything beyond *max_length* characters is cut with a truncation
        marker.
        """
        chunks: list[str] = []
        # Tavily AI answer
        if self.search_answer:
            chunks += ["## AI Search Summary", self.search_answer, ""]
        # Web search results (top 15)
        if self.web_results:
            chunks.append("## Web Search Results")
            for idx, hit in enumerate(self.web_results[:15], 1):
                chunks.append(f"### [{idx}] {hit.title}")
                chunks.append(f"URL: {hit.url}")
                if hit.snippet:
                    chunks.append(hit.snippet)
                chunks.append("")
        # Google Scholar papers (top 10)
        if self.scholar_papers:
            chunks.append("## Google Scholar Papers")
            for paper in self.scholar_papers[:10]:
                byline = ", ".join(paper.authors[:3])
                if len(paper.authors) > 3:
                    byline += " et al."
                chunks.append(
                    f"- **{paper.title}** ({byline}, {paper.year}) "
                    f"[{paper.citation_count} citations]"
                )
                if paper.abstract:
                    chunks.append(f"  {paper.abstract[:200]}...")
            chunks.append("")
        # Crawled page content (capped per page)
        if self.crawled_pages:
            chunks.append("## Crawled Page Content")
            for page in self.crawled_pages:
                if page.has_content:
                    chunks.append(f"### {page.title or page.url}")
                    chunks.append(page.markdown[:3000])
                    chunks.append("")
        # PDF extractions (capped per document)
        if self.pdf_extractions:
            chunks.append("## PDF Full-Text Extractions")
            for doc in self.pdf_extractions:
                if doc.has_content:
                    chunks.append(f"### {doc.title or doc.path}")
                    if doc.abstract:
                        chunks.append(f"**Abstract:** {doc.abstract}")
                    chunks.append(doc.text[:3000])
                    chunks.append("")
        rendered = "\n".join(chunks)
        if len(rendered) > max_length:
            return rendered[:max_length] + "\n\n[... truncated]"
        return rendered

    def to_dict(self) -> dict[str, Any]:
        """Serialize to dict for JSON output."""
        return {
            "topic": self.topic,
            "web_results_count": len(self.web_results),
            "scholar_papers_count": len(self.scholar_papers),
            "crawled_pages_count": len(self.crawled_pages),
            "pdf_extractions_count": len(self.pdf_extractions),
            "has_search_answer": bool(self.search_answer),
            "elapsed_seconds": self.elapsed_seconds,
            "web_results": [item.to_dict() for item in self.web_results[:20]],
            "scholar_papers": [item.to_dict() for item in self.scholar_papers[:20]],
        }
class WebSearchAgent:
"""Orchestrates all web search and content extraction capabilities.
Parameters
----------
tavily_api_key:
Tavily API key (optional, falls back to env var or DuckDuckGo).
enable_scholar:
Whether to include Google Scholar search.
enable_crawling:
Whether to crawl top URLs for full content.
enable_pdf:
Whether to extract PDF content.
max_web_results:
Maximum web search results.
max_scholar_results:
Maximum Google Scholar results.
max_crawl_urls:
Maximum URLs to crawl for full content.
"""
def __init__(
    self,
    *,
    tavily_api_key: str = "",
    enable_scholar: bool = True,
    enable_crawling: bool = True,
    enable_pdf: bool = True,
    max_web_results: int = 10,
    max_scholar_results: int = 10,
    max_crawl_urls: int = 5,
) -> None:
    """Initialize all web clients; see the class docstring for parameters."""
    self.web_client = WebSearchClient(api_key=tavily_api_key)
    # GoogleScholarClient raises ImportError when its backend package is
    # missing; degrade gracefully by disabling scholar search.
    try:
        self.scholar_client = GoogleScholarClient()
    except ImportError:
        self.scholar_client = None  # type: ignore[assignment]
    self.crawler = WebCrawler()
    self.pdf_extractor = PDFExtractor()
    # Feature toggles and per-source result caps.
    self.enable_scholar = enable_scholar
    self.enable_crawling = enable_crawling
    self.enable_pdf = enable_pdf
    self.max_web_results = max_web_results
    self.max_scholar_results = max_scholar_results
    self.max_crawl_urls = max_crawl_urls
def search_and_extract(
    self,
    topic: str,
    *,
    search_queries: list[str] | None = None,
    crawl_urls: list[str] | None = None,
    pdf_urls: list[str] | None = None,
) -> WebSearchAgentResult:
    """Run the full search + extraction pipeline.

    Parameters
    ----------
    topic:
        Research topic string.
    search_queries:
        Custom search queries. If None, auto-generates from topic.
    crawl_urls:
        Specific URLs to crawl. If None, crawls top search result URLs.
    pdf_urls:
        Specific PDF URLs to extract. If None, extracts PDFs from search.

    Returns
    -------
    WebSearchAgentResult populated with whatever each enabled stage
    produced; the _run_* helpers log stage failures instead of raising.
    """
    t0 = time.monotonic()
    result = WebSearchAgentResult(topic=topic)
    # 1. Generate search queries if not provided
    if search_queries is None:
        search_queries = self._generate_queries(topic)
    # 2. Web search (Tavily / DuckDuckGo)
    self._run_web_search(result, search_queries)
    # 3. Google Scholar search (skipped when the optional client failed to load)
    if self.enable_scholar and self.scholar_client and self.scholar_client.available:
        self._run_scholar_search(result, topic)
    # 4. Crawl top URLs for full content
    if self.enable_crawling:
        urls_to_crawl = crawl_urls or self._select_urls_to_crawl(result)
        if urls_to_crawl:
            self._run_crawling(result, urls_to_crawl)
    # 5. Extract PDFs
    if self.enable_pdf:
        pdf_targets = pdf_urls or self._find_pdf_urls(result)
        if pdf_targets:
            self._run_pdf_extraction(result, pdf_targets)
    result.elapsed_seconds = time.monotonic() - t0
    logger.info(
        "[WebSearchAgent] Done: %d web, %d scholar, %d crawled, %d PDFs (%.1fs)",
        len(result.web_results),
        len(result.scholar_papers),
        len(result.crawled_pages),
        len(result.pdf_extractions),
        result.elapsed_seconds,
    )
    return result
# ------------------------------------------------------------------
# Pipeline steps
# ------------------------------------------------------------------
def _run_web_search(
self, result: WebSearchAgentResult, queries: list[str]
) -> None:
"""Run web search across all queries."""
try:
responses = self.web_client.search_multi(
queries, max_results=self.max_web_results
)
for resp in responses:
result.web_results.extend(resp.results)
if resp.answer and not result.search_answer:
result.search_answer = resp.answer
except Exception as exc: # noqa: BLE001
logger.warning("Web search failed: %s", exc)
def _run_scholar_search(
self, result: WebSearchAgentResult, topic: str
) -> None:
"""Run Google Scholar search."""
try:
papers = self.scholar_client.search(
topic, limit=self.max_scholar_results
)
result.scholar_papers.extend(papers)
except Exception as exc: # noqa: BLE001
logger.warning("Scholar search failed: %s", exc)
def _run_crawling(
self, result: WebSearchAgentResult, urls: list[str]
) -> None:
"""Crawl URLs for full content."""
try:
loop = None
try:
loop = asyncio.get_running_loop()
except RuntimeError:
pass
if loop and loop.is_running():
# We're inside an async context — use sync fallback
for url in urls[: self.max_crawl_urls]:
cr = self.crawler.crawl_sync(url)
if cr.has_content:
result.crawled_pages.append(cr)
else:
crawl_results = asyncio.run(
self.crawler.crawl_many(urls[: self.max_crawl_urls])
)
result.crawled_pages.extend(
cr for cr in crawl_results if cr.has_content
)
except Exception as exc: # noqa: BLE001
logger.warning("Crawling failed: %s", exc)
def _run_pdf_extraction(
self, result: WebSearchAgentResult, urls: list[str]
) -> None:
"""Extract text from PDF URLs."""
for url in urls[:5]:
try:
pdf = self.pdf_extractor.extract_from_url(url)
if pdf.has_content:
result.pdf_extractions.append(pdf)
except Exception as exc: # noqa: BLE001
logger.warning("PDF extraction failed for %s: %s", url, exc)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _generate_queries(topic: str) -> list[str]:
"""Generate search queries from a topic string."""
queries = [
topic,
f"{topic} survey",
f"{topic} benchmark state of the art",
]
return queries
def _select_urls_to_crawl(self, result: WebSearchAgentResult) -> list[str]:
"""Select top URLs from search results for crawling."""
urls = []
seen = set()
for r in result.web_results:
if r.url and r.url not in seen:
# Skip PDF URLs (handled separately) and common non-content sites
if r.url.endswith(".pdf"):
continue
seen.add(r.url)
urls.append(r.url)
if len(urls) >= self.max_crawl_urls:
break
return urls
@staticmethod
def _find_pdf_urls(result: WebSearchAgentResult) -> list[str]:
"""Find PDF URLs from search results."""
pdf_urls = []
seen = set()
for r in result.web_results:
if r.url and r.url.endswith(".pdf") and r.url not in seen:
seen.add(r.url)
pdf_urls.append(r.url)
if len(pdf_urls) >= 3:
break
return pdf_urls
================================================
FILE: researchclaw/web/crawler.py
================================================
"""Web page → Markdown extraction powered by Crawl4AI.
Crawl4AI is the primary extraction engine (installed as a dependency).
A lightweight urllib fallback exists for environments where Crawl4AI's
browser dependency is not set up.
Usage::
crawler = WebCrawler()
result = await crawler.crawl("https://arxiv.org/abs/2301.00001")
print(result.markdown)
"""
from __future__ import annotations
import asyncio
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any
from urllib.request import Request, urlopen
from researchclaw.web._ssrf import check_url_ssrf
logger = logging.getLogger(__name__)
@dataclass
class CrawlResult:
    """Result of crawling a single URL.

    Produced by :class:`WebCrawler`. ``success`` means a backend completed
    without error; :attr:`has_content` additionally checks that a usable
    amount of Markdown was extracted.
    """
    # The URL that was fetched (also set on failed results).
    url: str
    # Extracted page content as Markdown (may be truncated by the crawler).
    markdown: str = ""
    # Page title, when a backend could determine one.
    title: str = ""
    # True when a backend completed without raising or reporting an error.
    success: bool = False
    # Error description when ``success`` is False.
    error: str = ""
    # Backend-provided extra data (e.g. Crawl4AI result metadata).
    metadata: dict[str, Any] = field(default_factory=dict)
    # Wall-clock seconds spent crawling this URL.
    elapsed_seconds: float = 0.0
    @property
    def has_content(self) -> bool:
        """True when the stripped Markdown is longer than 50 characters."""
        return bool(self.markdown and len(self.markdown.strip()) > 50)
class WebCrawler:
    """Web page → Markdown crawler powered by Crawl4AI.

    Crawl4AI (headless-browser extraction) is the primary backend; a
    lightweight urllib + regex fallback is used when Crawl4AI is not
    installed or fails for a URL. Every entry point first validates the
    URL against the SSRF policy and returns a failed :class:`CrawlResult`
    instead of raising.

    Parameters
    ----------
    timeout:
        Request timeout in seconds (urllib fallback).
    max_content_length:
        Maximum content length in characters (truncate beyond this).
    user_agent:
        User-Agent header sent by the urllib fallback.
    """
    def __init__(
        self,
        *,
        timeout: int = 30,
        max_content_length: int = 50_000,
        user_agent: str = "ResearchClaw/0.5 (Academic Research Bot)",
    ) -> None:
        self.timeout = timeout
        self.max_content_length = max_content_length
        self.user_agent = user_agent
    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def crawl(self, url: str) -> CrawlResult:
        """Crawl a URL and return Markdown content (async).

        Tries Crawl4AI first, then the urllib fallback; when both fail, a
        CrawlResult with ``success=False`` and the last error is returned
        rather than raising.
        """
        err = check_url_ssrf(url)
        if err:
            return CrawlResult(url=url, success=False, error=err, elapsed_seconds=0.0)
        t0 = time.monotonic()
        try:
            return await self._crawl_with_crawl4ai(url, t0)
        except Exception as exc:  # noqa: BLE001
            logger.debug("Crawl4AI failed for %s (%s), trying urllib fallback", url, exc)
            try:
                return self._crawl_with_urllib(url, t0)
            except Exception as exc2:  # noqa: BLE001
                elapsed = time.monotonic() - t0
                logger.warning("All crawl backends failed for %s: %s", url, exc2)
                return CrawlResult(url=url, success=False, error=str(exc2), elapsed_seconds=elapsed)
    def crawl_sync(self, url: str) -> CrawlResult:
        """Synchronous crawl — tries Crawl4AI via asyncio.run, falls back to urllib.

        When called from inside a running event loop, ``asyncio.run`` raises
        and the urllib fallback is used instead.
        """
        err = check_url_ssrf(url)
        if err:
            return CrawlResult(url=url, success=False, error=err, elapsed_seconds=0.0)
        t0 = time.monotonic()
        try:
            return asyncio.run(self._crawl_with_crawl4ai(url, t0))
        except Exception:  # noqa: BLE001
            try:
                return self._crawl_with_urllib(url, t0)
            except Exception as exc:  # noqa: BLE001
                elapsed = time.monotonic() - t0
                return CrawlResult(url=url, success=False, error=str(exc), elapsed_seconds=elapsed)
    async def crawl_many(self, urls: list[str]) -> list[CrawlResult]:
        """Crawl multiple URLs, reusing a single Crawl4AI browser session.

        Falls back to per-URL urllib fetches when Crawl4AI is not
        importable. One CrawlResult is appended per input URL, including
        failures and URLs rejected by the SSRF check.
        """
        results = []
        try:
            from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, BrowserConfig
            browser_config = BrowserConfig(headless=True)
            run_config = CrawlerRunConfig(
                word_count_threshold=10,
                excluded_tags=["nav", "footer", "header", "sidebar"],
                remove_overlay_elements=True,
            )
            async with AsyncWebCrawler(config=browser_config) as crawler:
                for url in urls:
                    err = check_url_ssrf(url)
                    if err:
                        results.append(CrawlResult(url=url, success=False, error=err, elapsed_seconds=0.0))
                        continue
                    t0 = time.monotonic()
                    try:
                        raw = await crawler.arun(url=url, config=run_config)
                        elapsed = time.monotonic() - t0
                        if raw.success:
                            md = self._extract_markdown(raw)
                            results.append(CrawlResult(
                                url=url, markdown=md,
                                title=getattr(raw, "title", "") or "",
                                success=True, elapsed_seconds=elapsed,
                                metadata=raw.metadata if hasattr(raw, "metadata") and raw.metadata else {},
                            ))
                        else:
                            results.append(CrawlResult(
                                url=url, success=False,
                                error=getattr(raw, "error_message", "crawl failed"),
                                elapsed_seconds=elapsed,
                            ))
                    except Exception as exc:  # noqa: BLE001
                        elapsed = time.monotonic() - t0
                        results.append(CrawlResult(url=url, success=False, error=str(exc), elapsed_seconds=elapsed))
        except ImportError:
            # Crawl4AI browser not set up — use urllib for each URL.
            for url in urls:
                err = check_url_ssrf(url)
                if err:
                    results.append(CrawlResult(url=url, success=False, error=err, elapsed_seconds=0.0))
                    continue
                t0 = time.monotonic()
                try:
                    results.append(self._crawl_with_urllib(url, t0))
                except Exception as exc:  # noqa: BLE001
                    elapsed = time.monotonic() - t0
                    results.append(CrawlResult(url=url, success=False, error=str(exc), elapsed_seconds=elapsed))
        return results
    # ------------------------------------------------------------------
    # Crawl4AI backend (primary)
    # ------------------------------------------------------------------
    async def _crawl_with_crawl4ai(self, url: str, t0: float) -> CrawlResult:
        """Use Crawl4AI for high-quality extraction.

        Raises ImportError when crawl4ai is not installed — callers treat
        any exception as a signal to fall back to urllib.
        """
        from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, BrowserConfig
        browser_config = BrowserConfig(headless=True)
        run_config = CrawlerRunConfig(
            word_count_threshold=10,
            excluded_tags=["nav", "footer", "header", "sidebar"],
            remove_overlay_elements=True,
        )
        async with AsyncWebCrawler(config=browser_config) as crawler:
            raw = await crawler.arun(url=url, config=run_config)
            elapsed = time.monotonic() - t0
            if raw.success:
                md = self._extract_markdown(raw)
                return CrawlResult(
                    url=url, markdown=md,
                    title=getattr(raw, "title", "") or "",
                    success=True, elapsed_seconds=elapsed,
                    metadata=raw.metadata if hasattr(raw, "metadata") and raw.metadata else {},
                )
            return CrawlResult(
                url=url, success=False,
                error=getattr(raw, "error_message", "Unknown crawl4ai error"),
                elapsed_seconds=elapsed,
            )
    def _extract_markdown(self, raw: Any) -> str:
        """Extract markdown text from a Crawl4AI result object.

        Prefers ``markdown_v2.raw_markdown`` (Crawl4AI v0.8+) and falls
        back to the plain ``markdown`` attribute; the output is truncated
        to ``max_content_length`` characters.
        """
        md = ""
        if hasattr(raw, "markdown_v2") and raw.markdown_v2:
            md = getattr(raw.markdown_v2, "raw_markdown", "") or ""
        if not md and hasattr(raw, "markdown"):
            md = raw.markdown or ""
        if len(md) > self.max_content_length:
            md = md[: self.max_content_length] + "\n\n[... truncated]"
        return md
    # ------------------------------------------------------------------
    # urllib fallback (lightweight, no browser needed)
    # ------------------------------------------------------------------
    def _crawl_with_urllib(self, url: str, t0: float) -> CrawlResult:
        """Lightweight fallback: fetch HTML with urllib and strip tags.

        FIX: the ``<title>`` pattern had lost its tag tokens and could
        never match; restored. The response is now closed via a context
        manager instead of being left for the garbage collector.
        """
        req = Request(url, headers={"User-Agent": self.user_agent})
        with urlopen(req, timeout=self.timeout) as resp:  # noqa: S310
            content_type = resp.headers.get("Content-Type", "")
            raw = resp.read()
        # Honor the declared charset when present; replace undecodable bytes.
        encoding = "utf-8"
        if "charset=" in content_type:
            encoding = content_type.split("charset=")[-1].split(";")[0].strip()
        html = raw.decode(encoding, errors="replace")
        title_match = re.search(r"<title[^>]*>(.*?)</title>", html, re.DOTALL | re.IGNORECASE)
        title = title_match.group(1).strip() if title_match else ""
        markdown = self._html_to_markdown(html)
        if len(markdown) > self.max_content_length:
            markdown = markdown[: self.max_content_length] + "\n\n[... truncated]"
        elapsed = time.monotonic() - t0
        return CrawlResult(
            url=url, markdown=markdown, title=title,
            success=bool(markdown.strip()), elapsed_seconds=elapsed,
        )
    @staticmethod
    def _html_to_markdown(html: str) -> str:
        """Best-effort HTML → Markdown conversion via regex.

        FIX: every tag pattern here had been corrupted (the opening ``<tag``
        and closing ``</tag>`` tokens were lost, leaving broken string
        literals); the intended patterns are reconstructed from the
        replacement templates (#, ##, ###, -, [text](url)).
        """
        # Drop non-content elements entirely.
        text = re.sub(r"<(script|style|noscript)[^>]*>.*?</\1>", "", html, flags=re.DOTALL | re.IGNORECASE)
        # Headings → Markdown heading markers.
        text = re.sub(r"<h1[^>]*>(.*?)</h1>", r"\n# \1\n", text, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r"<h2[^>]*>(.*?)</h2>", r"\n## \1\n", text, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r"<h3[^>]*>(.*?)</h3>", r"\n### \1\n", text, flags=re.DOTALL | re.IGNORECASE)
        # List items → bullet lines.
        text = re.sub(r"<li[^>]*>(.*?)</li>", r"\n- \1", text, flags=re.DOTALL | re.IGNORECASE)
        # Paragraphs and line breaks → newlines.
        text = re.sub(r"<p[^>]*>(.*?)</p>", r"\n\1\n", text, flags=re.DOTALL | re.IGNORECASE)
        text = re.sub(r"<br\s*/?>", "\n", text, flags=re.IGNORECASE)
        # Anchors → Markdown links.
        text = re.sub(r"<a[^>]*href=[\"']([^\"']*)[\"'][^>]*>(.*?)</a>", r"[\2](\1)", text, flags=re.DOTALL | re.IGNORECASE)
        # Strip any remaining tags, then decode HTML entities.
        text = re.sub(r"<[^>]+>", "", text)
        import html as _html
        text = _html.unescape(text)
        # Collapse excessive blank lines and runs of spaces.
        text = re.sub(r"\n{3,}", "\n\n", text)
        text = re.sub(r" {2,}", " ", text)
        return text.strip()
================================================
FILE: researchclaw/web/pdf_extractor.py
================================================
"""PDF full-text extraction powered by PyMuPDF (fitz).
PyMuPDF is installed as a dependency and provides fast, high-quality
PDF text extraction with metadata, section detection, and table support.
Usage::
extractor = PDFExtractor()
result = extractor.extract("/path/to/paper.pdf")
print(result.text[:1000])
"""
from __future__ import annotations
import logging
import re
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from urllib.request import Request, urlopen
from researchclaw.web._ssrf import check_url_ssrf
try:
import fitz # PyMuPDF
HAS_FITZ = True
except ImportError:
fitz = None # type: ignore[assignment]
HAS_FITZ = False
logger = logging.getLogger(__name__)
@dataclass
class PDFContent:
    """Extracted content from a PDF file.

    Produced by :class:`PDFExtractor`; ``success`` reflects whether
    extraction completed, while :attr:`has_content` checks that a usable
    amount of text was actually recovered.
    """
    # Source path (or the URL when extracted via ``extract_from_url``).
    path: str
    # Full extracted text, pages joined by newlines.
    text: str = ""
    # Title from the PDF metadata (may be empty).
    title: str = ""
    # Author names parsed from the comma-separated metadata field.
    authors: list[str] = field(default_factory=list)
    # Abstract detected heuristically from the text.
    abstract: str = ""
    # Detected sections as {"heading": ..., "text": ...} dicts.
    sections: list[dict[str, str]] = field(default_factory=list)
    # Total page count of the document (not just the pages read).
    page_count: int = 0
    # True when extraction completed without error.
    success: bool = False
    # Error description when ``success`` is False.
    error: str = ""
    # Extraction backend identifier.
    backend: str = "pymupdf"
    # Raw PDF metadata dict as reported by the backend.
    metadata: dict[str, Any] = field(default_factory=dict)
    @property
    def has_content(self) -> bool:
        """True when the stripped text is longer than 100 characters."""
        return bool(self.text and len(self.text.strip()) > 100)
class PDFExtractor:
    """PDF text extraction using PyMuPDF.

    All methods return a :class:`PDFContent` (with ``error`` set on
    failure) instead of raising.

    Parameters
    ----------
    max_pages:
        Maximum pages to extract (0 = all).
    extract_sections:
        Whether to attempt section boundary detection.
    """
    def __init__(
        self,
        *,
        max_pages: int = 0,
        extract_sections: bool = True,
    ) -> None:
        self.max_pages = max_pages
        self.extract_sections = extract_sections
    @property
    def backend(self) -> str:
        """Name of the extraction backend in use."""
        return "pymupdf"
    def extract(self, path: str | Path) -> PDFContent:
        """Extract text from a local PDF file using PyMuPDF.

        Reads at most ``max_pages`` pages (all when 0), pulls title/author
        from the PDF metadata, and heuristically detects the abstract and
        section boundaries. Returns a failed PDFContent when PyMuPDF is
        missing, the file is unreadable, or parsing raises.
        """
        if not HAS_FITZ:
            return PDFContent(
                path=str(path),
                error="PyMuPDF not installed. Install: pip install 'researchclaw[pdf]'",
            )
        path = Path(path)
        # Treat permission errors during the existence check as "not found".
        try:
            _exists = path.exists()
        except (PermissionError, OSError):
            _exists = False
        if not _exists:
            return PDFContent(path=str(path), error=f"File not found: {path}")
        try:
            with fitz.open(str(path)) as doc:
                pages_to_read = doc.page_count
                if self.max_pages > 0:
                    pages_to_read = min(pages_to_read, self.max_pages)
                all_text = []
                for i in range(pages_to_read):
                    page = doc[i]
                    all_text.append(page.get_text())
                full_text = "\n".join(all_text)
                meta = doc.metadata or {}
                title = meta.get("title", "")
                # Metadata author is a single comma-separated string.
                author = meta.get("author", "")
                authors = [a.strip() for a in author.split(",")] if author else []
                abstract = self._extract_abstract(full_text)
                sections = self._detect_sections(full_text) if self.extract_sections else []
                page_count = doc.page_count
            return PDFContent(
                path=str(path),
                text=full_text,
                title=title,
                authors=authors,
                abstract=abstract,
                sections=sections,
                page_count=page_count,
                success=True,
                metadata=meta,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("PDF extraction failed for %s: %s", path, exc)
            return PDFContent(path=str(path), error=str(exc))
    def extract_from_url(self, url: str) -> PDFContent:
        """Download a PDF from URL and extract text.

        The URL is SSRF-checked first; the download goes to a temporary
        file that is always removed afterwards. ``result.path`` is set to
        the URL so callers can attribute the content.

        FIX: the HTTP response is now closed via a context manager instead
        of being left open until garbage collection.
        """
        err = check_url_ssrf(url)
        if err:
            return PDFContent(path=url, error=err)
        tmp_path = None
        try:
            req = Request(url, headers={
                "User-Agent": "ResearchClaw/0.5 (Academic Research Bot)"
            })
            with urlopen(req, timeout=30) as resp:  # noqa: S310
                data = resp.read()
            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
                f.write(data)
                tmp_path = f.name
            result = self.extract(tmp_path)
            result.path = url
            return result
        except Exception as exc:  # noqa: BLE001
            logger.warning("PDF download failed for %s: %s", url, exc)
            return PDFContent(path=url, error=str(exc))
        finally:
            if tmp_path:
                Path(tmp_path).unlink(missing_ok=True)
    # ------------------------------------------------------------------
    # Section detection
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_abstract(text: str) -> str:
        """Extract the abstract from paper text.

        First looks for an "Abstract" heading followed by an
        Introduction-style heading; failing that, takes the text between
        "Abstract" and the next blank line. Returns "" when neither
        pattern matches.
        """
        match = re.search(
            r"(?:^|\n)\s*Abstract\s*\n(.*?)(?=\n\s*(?:\d+\.?\s+)?(?:Introduction|1\s))",
            text, re.DOTALL | re.IGNORECASE,
        )
        if match:
            return match.group(1).strip()
        match = re.search(
            r"(?:^|\n)\s*Abstract[:\s]*\n?(.*?)(?:\n\n|\n\s*\n)",
            text, re.DOTALL | re.IGNORECASE,
        )
        if match:
            return match.group(1).strip()
        return ""
    @staticmethod
    def _detect_sections(text: str) -> list[dict[str, str]]:
        """Detect numbered section boundaries in paper text.

        Matches headings of the form "1 Title" / "2. Title" on their own
        line; each section's body runs to the next heading (or end of
        text) and is capped at 5000 characters.
        """
        sections: list[dict[str, str]] = []
        pattern = re.compile(r"(?:^|\n)\s*(\d+\.?\s+[A-Z][^\n]{2,50})\s*\n", re.MULTILINE)
        matches = list(pattern.finditer(text))
        for i, match in enumerate(matches):
            heading = match.group(1).strip()
            start = match.end()
            end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
            body = text[start:end].strip()
            sections.append({"heading": heading, "text": body[:5000]})
        return sections
================================================
FILE: researchclaw/web/scholar.py
================================================
"""Google Scholar search powered by the ``scholarly`` library.
scholarly is installed as a dependency and provides direct access to
Google Scholar search, citation graph traversal, and author lookup.
Usage::
client = GoogleScholarClient()
papers = client.search("attention is all you need", limit=5)
citing = client.get_citations(papers[0].scholar_id, limit=10)
"""
from __future__ import annotations
import hashlib
import logging
import time
from dataclasses import dataclass, field
from typing import Any
try:
from scholarly import scholarly, ProxyGenerator
HAS_SCHOLARLY = True
except ImportError:
scholarly = None # type: ignore[assignment]
ProxyGenerator = None # type: ignore[assignment,misc]
HAS_SCHOLARLY = False
logger = logging.getLogger(__name__)
@dataclass
class ScholarPaper:
    """A paper result from Google Scholar."""
    title: str
    # Author names; a single "A and B" string from scholarly is split
    # into a list by the client before construction.
    authors: list[str] = field(default_factory=list)
    # Publication year (0 when missing or unparseable).
    year: int = 0
    abstract: str = ""
    citation_count: int = 0
    url: str = ""
    # Google Scholar identifier (author_pub_id, or the first cites_id).
    scholar_id: str = ""
    venue: str = ""
    # Origin tag so mixed result lists can be attributed.
    source: str = "google_scholar"
    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly plain dict."""
        return {
            "title": self.title,
            "authors": self.authors,
            "year": self.year,
            "abstract": self.abstract,
            "citation_count": self.citation_count,
            "url": self.url,
            "scholar_id": self.scholar_id,
            "venue": self.venue,
            "source": self.source,
        }
    def to_literature_paper(self) -> Any:
        """Convert to researchclaw.literature.models.Paper.

        When ``scholar_id`` is empty, a deterministic ``gs-`` id is derived
        from the SHA-256 of the title so the paper still has a stable key.
        """
        from researchclaw.literature.models import Author, Paper
        authors_tuple = tuple(Author(name=a) for a in self.authors)
        return Paper(
            paper_id=self.scholar_id or f"gs-{hashlib.sha256(self.title.encode()).hexdigest()[:8]}",
            title=self.title,
            authors=authors_tuple,
            year=self.year,
            abstract=self.abstract,
            venue=self.venue,
            citation_count=self.citation_count,
            url=self.url,
            source="google_scholar",
        )
class GoogleScholarClient:
    """Google Scholar search client using the ``scholarly`` library.

    All query methods throttle themselves via :meth:`_rate_limit` and
    swallow errors (logging a warning and returning an empty/partial
    result) rather than raising.

    Parameters
    ----------
    inter_request_delay:
        Seconds between requests to avoid rate limiting.
    use_proxy:
        Whether to set up a free proxy to reduce blocking risk.

    Raises
    ------
    ImportError
        If the ``scholarly`` package is not installed.
    """
    def __init__(
        self,
        *,
        inter_request_delay: float = 2.0,
        use_proxy: bool = False,
    ) -> None:
        if not HAS_SCHOLARLY:
            raise ImportError(
                "scholarly is required for Google Scholar search. "
                "Install: pip install 'researchclaw[web]'"
            )
        self.delay = inter_request_delay
        # Monotonic timestamp of the last outgoing request (see _rate_limit).
        self._last_request_time: float = 0.0
        if use_proxy:
            # Best effort: free proxies are flaky, so any setup failure is
            # logged and the client proceeds without a proxy.
            try:
                pg = ProxyGenerator()
                pg.FreeProxies()
                scholarly.use_proxy(pg)
                logger.info("Google Scholar: proxy enabled")
            except Exception as exc:  # noqa: BLE001
                logger.warning("Failed to set up proxy: %s", exc)
    @property
    def available(self) -> bool:
        """Always True — scholarly is installed as a dependency."""
        return True
    def search(self, query: str, *, limit: int = 10) -> list[ScholarPaper]:
        """Search Google Scholar for papers matching query.

        Returns at most ``limit`` papers. An exception mid-iteration is
        logged and the papers parsed so far are returned.
        """
        self._rate_limit()
        results: list[ScholarPaper] = []
        try:
            search_gen = scholarly.search_pubs(query)
            for i, pub in enumerate(search_gen):
                if i >= limit:
                    break
                results.append(self._parse_pub(pub))
                # Throttle between fetches, but not after the final one.
                if i < limit - 1:
                    self._rate_limit()
            logger.info("Google Scholar: found %d papers for %r", len(results), query)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Google Scholar search failed: %s", exc)
        return results
    def get_citations(self, scholar_id: str, *, limit: int = 20) -> list[ScholarPaper]:
        """Get papers that cite the given paper (citation graph traversal).

        Looks the paper up first, then iterates its citing publications;
        returns at most ``limit`` entries, or whatever was collected
        before an error.
        """
        self._rate_limit()
        results: list[ScholarPaper] = []
        try:
            pub = scholarly.search_single_pub(scholar_id)
            if pub:
                citations = scholarly.citedby(pub)
                for i, cit in enumerate(citations):
                    if i >= limit:
                        break
                    results.append(self._parse_pub(cit))
                    # Throttle between fetches, but not after the final one.
                    if i < limit - 1:
                        self._rate_limit()
            logger.info("Google Scholar: found %d citations for %s", len(results), scholar_id)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Citation retrieval failed for %s: %s", scholar_id, exc)
        return results
    def search_author(self, name: str) -> list[dict[str, Any]]:
        """Search for an author on Google Scholar.

        Returns up to five author dicts with name, affiliation, id,
        citation count and interests; an empty list on failure.
        """
        self._rate_limit()
        try:
            results = []
            for author in scholarly.search_author(name):
                results.append({
                    "name": author.get("name", ""),
                    "affiliation": author.get("affiliation", ""),
                    "scholar_id": author.get("scholar_id", ""),
                    "citedby": author.get("citedby", 0),
                    "interests": author.get("interests", []),
                })
                if len(results) >= 5:
                    break
            return results
        except Exception as exc:  # noqa: BLE001
            logger.warning("Author search failed for %s: %s", name, exc)
            return []
    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _rate_limit(self) -> None:
        """Sleep just long enough that requests are ``delay`` seconds apart."""
        now = time.monotonic()
        elapsed = now - self._last_request_time
        if elapsed < self.delay:
            time.sleep(self.delay - elapsed)
        self._last_request_time = time.monotonic()
    @staticmethod
    def _parse_pub(pub: Any) -> ScholarPaper:
        """Parse a scholarly publication object into ScholarPaper.

        Accepts both plain dicts and objects exposing a ``bib`` attribute;
        missing fields fall back to empty/zero values.
        """
        bib = pub.get("bib", {}) if isinstance(pub, dict) else getattr(pub, "bib", {})
        info = pub if isinstance(pub, dict) else pub.__dict__ if hasattr(pub, "__dict__") else {}
        authors = bib.get("author", [])
        if isinstance(authors, str):
            # "A and B and C" → ["A", "B", "C"]
            authors = [a.strip() for a in authors.split(" and ")]
        year = 0
        year_raw = bib.get("pub_year", bib.get("year", 0))
        try:
            year = int(year_raw)
        except (ValueError, TypeError):
            pass  # leave year as 0 when unparseable
        cites_id = info.get("cites_id", [])
        # Prefer the stable author_pub_id; fall back to the first cites_id.
        scholar_id = info.get("author_pub_id", "") or (
            cites_id[0] if isinstance(cites_id, list) and cites_id else ""
        )
        return ScholarPaper(
            title=bib.get("title", ""),
            authors=authors,
            year=year,
            abstract=bib.get("abstract", ""),
            citation_count=info.get("num_citations", 0),
            url=info.get("pub_url", info.get("eprint_url", "")),
            scholar_id=scholar_id,
            venue=bib.get("venue", bib.get("journal", "")),
        )
================================================
FILE: researchclaw/web/search.py
================================================
"""Web search powered by Tavily AI Search API.
Tavily is the primary search engine (installed as a dependency).
A DuckDuckGo HTML scrape fallback exists for when no API key is set.
Usage::
client = WebSearchClient(api_key="tvly-...")
results = client.search("knowledge distillation survey 2024")
"""
from __future__ import annotations
import logging
import os
import re
import time
from dataclasses import dataclass, field
from typing import Any
from urllib.request import Request, urlopen
from urllib.parse import quote_plus
logger = logging.getLogger(__name__)
@dataclass
class SearchResult:
    """A single web search result."""
    title: str
    url: str
    # Short excerpt; the Tavily backend fills this with the first 500
    # characters of ``content``.
    snippet: str = ""
    # Full content text when the backend provides it (Tavily does;
    # the DuckDuckGo fallback leaves it empty).
    content: str = ""
    # Backend relevance score (0.0 when not provided).
    score: float = 0.0
    source: str = ""  # "tavily" | "duckduckgo"
    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly plain dict."""
        return {
            "title": self.title,
            "url": self.url,
            "snippet": self.snippet,
            "content": self.content,
            "score": self.score,
            "source": self.source,
        }
@dataclass
class WebSearchResponse:
    """Response from a web search query."""
    # The query string that produced this response.
    query: str
    # Result list (possibly deduplicated across queries by search_multi).
    results: list[SearchResult] = field(default_factory=list)
    answer: str = ""  # Tavily can provide a direct AI answer
    # Wall-clock seconds spent on this query.
    elapsed_seconds: float = 0.0
    source: str = ""  # "tavily" | "duckduckgo"
    @property
    def has_results(self) -> bool:
        """True when at least one result was returned."""
        return len(self.results) > 0
class WebSearchClient:
    """General-purpose web search client.

    Uses Tavily (installed) as primary engine. Falls back to DuckDuckGo
    HTML scraping when no Tavily API key is available or the Tavily call
    fails. ``search`` never raises — failures degrade to an empty
    response.

    Parameters
    ----------
    api_key:
        Tavily API key. Falls back to ``TAVILY_API_KEY`` env var.
    max_results:
        Default number of results per query.
    search_depth:
        Tavily search depth: "basic" or "advanced".
    include_answer:
        Whether to request Tavily's AI-generated answer.
    """
    def __init__(
        self,
        *,
        api_key: str = "",
        max_results: int = 10,
        search_depth: str = "advanced",
        include_answer: bool = True,
    ) -> None:
        self.api_key = api_key or os.environ.get("TAVILY_API_KEY", "")
        self.max_results = max_results
        self.search_depth = search_depth
        self.include_answer = include_answer
    def search(
        self,
        query: str,
        *,
        max_results: int | None = None,
        include_domains: list[str] | None = None,
        exclude_domains: list[str] | None = None,
    ) -> WebSearchResponse:
        """Search the web for a query.

        The domain filters apply only to the Tavily backend; the
        DuckDuckGo fallback ignores them.
        """
        limit = max_results or self.max_results
        t0 = time.monotonic()
        # Tavily is the primary engine
        if self.api_key:
            try:
                return self._search_tavily(query, limit, include_domains, exclude_domains, t0)
            except Exception as exc:  # noqa: BLE001
                logger.warning("Tavily search failed, falling back to DuckDuckGo: %s", exc)
        return self._search_duckduckgo(query, limit, t0)
    def search_multi(
        self,
        queries: list[str],
        *,
        max_results: int | None = None,
        inter_query_delay: float = 1.0,
    ) -> list[WebSearchResponse]:
        """Run multiple search queries with cross-query deduplication.

        A URL that already appeared in an earlier query's results is
        dropped from later responses; ``inter_query_delay`` seconds pass
        between consecutive queries.
        """
        responses = []
        seen_urls: set[str] = set()
        for i, query in enumerate(queries):
            if i > 0:
                time.sleep(inter_query_delay)
            resp = self.search(query, max_results=max_results)
            unique_results = [r for r in resp.results if r.url not in seen_urls]
            seen_urls.update(r.url for r in unique_results)
            resp.results = unique_results
            responses.append(resp)
        return responses
    # ------------------------------------------------------------------
    # Tavily backend (primary — uses installed tavily-python SDK)
    # ------------------------------------------------------------------
    def _search_tavily(
        self,
        query: str,
        limit: int,
        include_domains: list[str] | None,
        exclude_domains: list[str] | None,
        t0: float,
    ) -> WebSearchResponse:
        """Search using Tavily API (installed SDK).

        Raises on any SDK/network error — the caller handles the fallback.
        """
        from tavily import TavilyClient
        client = TavilyClient(api_key=self.api_key)
        kwargs: dict[str, Any] = {
            "query": query,
            "max_results": limit,
            "search_depth": self.search_depth,
            "include_answer": self.include_answer,
        }
        if include_domains:
            kwargs["include_domains"] = include_domains
        if exclude_domains:
            kwargs["exclude_domains"] = exclude_domains
        response = client.search(**kwargs)
        elapsed = time.monotonic() - t0
        results = []
        for item in response.get("results", []):
            results.append(SearchResult(
                title=item.get("title", ""),
                url=item.get("url", ""),
                # Snippet is the first 500 chars of the full content.
                snippet=item.get("content", "")[:500],
                content=item.get("content", ""),
                score=item.get("score", 0.0),
                source="tavily",
            ))
        return WebSearchResponse(
            query=query,
            results=results,
            answer=response.get("answer", ""),
            elapsed_seconds=elapsed,
            source="tavily",
        )
    # ------------------------------------------------------------------
    # DuckDuckGo fallback (no API key needed)
    # ------------------------------------------------------------------
    def _search_duckduckgo(
        self, query: str, limit: int, t0: float
    ) -> WebSearchResponse:
        """Fallback: scrape DuckDuckGo HTML search results.

        Returns an empty response (never raises) when the fetch fails.
        FIX: the HTTP response is now closed via a context manager.
        """
        encoded = quote_plus(query)
        url = f"https://html.duckduckgo.com/html/?q={encoded}"
        req = Request(url, headers={
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
                          "AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/120.0.0.0 Safari/537.36",
        })
        try:
            with urlopen(req, timeout=15) as resp:  # noqa: S310
                html = resp.read().decode("utf-8", errors="replace")
        except Exception as exc:  # noqa: BLE001
            elapsed = time.monotonic() - t0
            logger.warning("DuckDuckGo search failed: %s", exc)
            return WebSearchResponse(query=query, elapsed_seconds=elapsed, source="duckduckgo")
        results = self._parse_ddg_html(html, limit)
        elapsed = time.monotonic() - t0
        return WebSearchResponse(query=query, results=results, elapsed_seconds=elapsed, source="duckduckgo")
    @staticmethod
    def _parse_ddg_html(html: str, limit: int) -> list[SearchResult]:
        """Parse DuckDuckGo HTML results page.

        FIX: both anchor patterns had been corrupted (the opening ``<a``
        and closing ``</a>`` tokens were lost, so they could never match a
        real result link); the intended patterns are reconstructed here.
        """
        results = []
        # Result links: <a class="result__a" href="...">Title</a>
        link_pattern = re.compile(
            r'<a[^>]*class="result__a"[^>]*href="([^"]*)"[^>]*>(.*?)</a>', re.DOTALL,
        )
        # Snippets: <a class="result__snippet" ...>text</a>
        snippet_pattern = re.compile(
            r'<a[^>]*class="result__snippet"[^>]*>(.*?)</a>', re.DOTALL,
        )
        links = link_pattern.findall(html)
        snippets = snippet_pattern.findall(html)
        for i, (url, title_html) in enumerate(links[:limit]):
            # Strip any markup inside the title/snippet text.
            title = re.sub(r"<[^>]+>", "", title_html).strip()
            snippet = re.sub(r"<[^>]+>", "", snippets[i]).strip() if i < len(snippets) else ""
            if "duckduckgo.com" in url:
                # Extract actual URL from DDG redirect: //duckduckgo.com/l/?uddg=https%3A...
                from urllib.parse import urlparse as _urlparse, parse_qs as _parse_qs, unquote as _unquote
                _parsed_ddg = _urlparse(url)
                _uddg = _parse_qs(_parsed_ddg.query).get("uddg")
                if _uddg:
                    url = _unquote(_uddg[0])
                else:
                    # Redirect without a target — drop the entry.
                    continue
            results.append(SearchResult(title=title, url=url, snippet=snippet, source="duckduckgo"))
        return results
================================================
FILE: researchclaw/wizard/__init__.py
================================================
"""Setup wizard modules."""
================================================
FILE: researchclaw/wizard/quickstart.py
================================================
"""Quick-start interactive setup wizard."""
from __future__ import annotations
import sys
from typing import Any
from researchclaw.wizard.templates import TEMPLATES
class QuickStartWizard:
"""Interactive configuration generator."""
def run_interactive(self, template: str | None = None) -> dict[str, Any]:
"""CLI interactive wizard — returns a config dict."""
print("\n=== ResearchClaw Setup Wizard ===\n")
if template:
return self._apply_template(template)
config: dict[str, Any] = {}
# 1. Project name
name = self._ask("Project name", default="my-research")
config["project"] = {"name": name, "mode": "full-auto"}
# 2. Research topic
topic = self._ask("Research topic (describe in one sentence)")
if not topic:
print("Topic is required.")
return {}
config["research"] = {"topic": topic}
# 3. Research domain
domains_str = self._ask(
"Research domains (comma-separated: cv, nlp, rl, ml, ai4science)",
default="ml",
)
config["research"]["domains"] = [
d.strip() for d in domains_str.split(",") if d.strip()
]
# 4. Experiment mode
mode = self._choose(
"Experiment mode",
["simulated", "docker", "sandbox"],
default="docker",
)
config["experiment"] = {"mode": mode}
if mode == "docker":
gpu = self._ask_yn("Enable GPU?", default=True)
config["experiment"]["docker"] = {
"gpu_enabled": gpu,
"network_policy": "setup_only",
}
budget = self._ask("Time budget (seconds)", default="600")
config["experiment"]["time_budget_sec"] = int(budget)
# 5. LLM provider
print("\n--- LLM Configuration ---")
provider = self._choose(
"LLM provider",
["openai-compatible", "acp"],
default="openai-compatible",
)
config["llm"] = {"provider": provider}
if provider == "openai-compatible":
base_url = self._ask("API base URL", default="https://api.openai.com/v1")
api_key_env = self._ask("API key env var", default="OPENAI_API_KEY")
model = self._ask("Model name", default="gpt-4o")
config["llm"].update({
"base_url": base_url,
"api_key_env": api_key_env,
"primary_model": model,
})
# 6. Output format
conference = self._choose(
"Target conference format",
["neurips_2025", "iclr_2025", "icml_2025", "arxiv"],
default="neurips_2025",
)
config["export"] = {"target_conference": conference}
# 7. Runtime
config["runtime"] = {"timezone": "UTC"}
config["notifications"] = {"channel": "console"}
config["knowledge_base"] = {"backend": "markdown", "root": "knowledge"}
print("\n--- Configuration Summary ---")
self._print_summary(config)
confirm = self._ask_yn("\nSave this configuration?", default=True)
if not confirm:
print("Cancelled.")
return {}
return config
def run_web(self, steps: list[dict[str, Any]]) -> dict[str, Any]:
"""Process wizard steps from web interface."""
config: dict[str, Any] = {}
for step in steps:
key = step.get("key", "")
value = step.get("value", "")
if key == "project_name":
config.setdefault("project", {})["name"] = value
elif key == "topic":
config.setdefault("research", {})["topic"] = value
elif key == "mode":
config.setdefault("experiment", {})["mode"] = value
elif key == "model":
config.setdefault("llm", {})["primary_model"] = value
return config
def _apply_template(self, name: str) -> dict[str, Any]:
"""Apply a preset template."""
mapping = {
"quick": "quick-demo",
"standard": "standard-cv",
"advanced": "deep-nlp",
}
tpl_name = mapping.get(name, name)
tpl = TEMPLATES.get(tpl_name)
if not tpl:
print(f"Unknown template: {name}")
return {}
config = self._template_to_config(tpl)
print(f"Applied template: {tpl_name}")
print(f" Description: {tpl.get('description', '')}")
self._print_summary(config)
return config
def _template_to_config(self, tpl: dict[str, Any]) -> dict[str, Any]:
"""Convert a flat template to nested config dict."""
config: dict[str, Any] = {
"project": {"name": "wizard-project", "mode": "full-auto"},
"runtime": {"timezone": "UTC"},
"notifications": {"channel": "console"},
"knowledge_base": {"backend": "markdown", "root": "knowledge"},
"research": {"topic": "Generated by wizard"},
"llm": {"provider": "openai-compatible", "api_key_env": "OPENAI_API_KEY"},
}
for key, value in tpl.items():
if key == "description":
continue
parts = key.split(".")
d = config
for p in parts[:-1]:
d = d.setdefault(p, {})
d[parts[-1]] = value
return config
def _ask(self, prompt: str, default: str = "") -> str:
suffix = f" [{default}]" if default else ""
try:
answer = input(f" {prompt}{suffix}: ").strip()
except (EOFError, KeyboardInterrupt):
print()
return default
return answer or default
def _ask_yn(self, prompt: str, default: bool = True) -> bool:
suffix = " [Y/n]" if default else " [y/N]"
try:
answer = input(f" {prompt}{suffix}: ").strip().lower()
except (EOFError, KeyboardInterrupt):
print()
return default
if not answer:
return default
return answer in ("y", "yes", "1", "true")
def _choose(
self,
prompt: str,
options: list[str],
default: str = "",
) -> str:
print(f" {prompt}:")
for i, opt in enumerate(options, 1):
marker = " *" if opt == default else ""
print(f" {i}. {opt}{marker}")
try:
answer = input(f" Choice [default={default}]: ").strip()
except (EOFError, KeyboardInterrupt):
print()
return default
if not answer:
return default
try:
idx = int(answer) - 1
if 0 <= idx < len(options):
return options[idx]
except ValueError:
if answer in options:
return answer
return default
def _print_summary(self, config: dict[str, Any], indent: int = 2) -> None:
    """Pretty-print *config* as block-style YAML to stdout.

    Parameters
    ----------
    config:
        Nested configuration dict to display.
    indent:
        Indentation width forwarded to ``yaml.dump``.  Fix: this parameter
        was previously accepted but silently ignored; it is now honoured.
        The default of 2 matches PyYAML's own default, so existing callers
        see identical output.
    """
    import yaml
    print(yaml.dump(config, default_flow_style=False, allow_unicode=True, indent=indent))
================================================
FILE: researchclaw/wizard/templates.py
================================================
"""Preset research configuration templates."""
from __future__ import annotations
from typing import Any
# Registry of built-in wizard presets.  Keys are template names; values are
# flat mappings from dotted config paths (e.g. "experiment.mode") to values,
# which the wizard expands into the nested config layout.  The "description"
# entry is display metadata only and is not written into the config.
TEMPLATES: dict[str, dict[str, Any]] = {
    "quick-demo": {
        "description": "5-minute quick demo (simulated mode, no GPU needed)",
        "experiment.mode": "simulated",
        "experiment.time_budget_sec": 60,
        "experiment.max_iterations": 3,
    },
    "standard-cv": {
        "description": "Standard Computer Vision paper (Docker + CIFAR-10)",
        "research.domains": ["computer-vision"],
        "experiment.mode": "docker",
        "experiment.time_budget_sec": 600,
        "experiment.docker.gpu_enabled": True,
        "experiment.docker.network_policy": "setup_only",
    },
    "deep-nlp": {
        "description": "Deep NLP research (Docker + GPU + transformers)",
        "research.domains": ["nlp", "transformers"],
        "experiment.mode": "docker",
        "experiment.time_budget_sec": 1200,
        "experiment.docker.gpu_enabled": True,
        "experiment.docker.memory_limit_mb": 16384,
    },
    "rl-research": {
        "description": "Reinforcement Learning research (Docker + custom env)",
        "research.domains": ["reinforcement-learning"],
        "experiment.mode": "docker",
        "experiment.time_budget_sec": 900,
        "experiment.docker.gpu_enabled": True,
    },
    "ai4science": {
        "description": "AI for Science (large compute budget)",
        "research.domains": ["ai4science"],
        "experiment.mode": "docker",
        "experiment.time_budget_sec": 1800,
        "experiment.docker.gpu_enabled": True,
        "experiment.docker.memory_limit_mb": 32768,
    },
}
def get_template(name: str) -> dict[str, Any] | None:
    """Return the preset template registered under *name*, or ``None``."""
    try:
        return TEMPLATES[name]
    except KeyError:
        return None
def list_templates() -> list[dict[str, str]]:
    """Enumerate all templates as ``{"name", "description"}`` records."""
    catalogue: list[dict[str, str]] = []
    for name, spec in TEMPLATES.items():
        catalogue.append({"name": name, "description": spec.get("description", "")})
    return catalogue
================================================
FILE: researchclaw/wizard/validator.py
================================================
"""Environment detection and recommendation for the setup wizard."""
from __future__ import annotations
import os
import shutil
from dataclasses import dataclass, field
from typing import Any
@dataclass
class EnvironmentReport:
    """Report of detected environment capabilities.

    Populated by detect_environment(); all fields default to the
    "not detected" state so a bare instance is a valid empty report.
    """

    has_gpu: bool = False            # CUDA device reachable via torch
    gpu_name: str = ""               # name of device 0, if any
    gpu_vram_gb: float = 0.0         # total VRAM of device 0, in GiB
    has_docker: bool = False         # `docker` executable found on PATH
    docker_version: str = ""         # output of `docker --version`
    has_python: bool = True          # trivially true when this code runs
    python_version: str = ""         # "major.minor.micro"
    has_latex: bool = False          # `pdflatex` executable found on PATH
    available_memory_gb: float = 0.0 # free RAM in GiB (via psutil)
    recommendations: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialise the report; memory is rounded to one decimal place."""
        snapshot: dict[str, Any] = dict(
            has_gpu=self.has_gpu,
            gpu_name=self.gpu_name,
            gpu_vram_gb=self.gpu_vram_gb,
            has_docker=self.has_docker,
            docker_version=self.docker_version,
            has_python=self.has_python,
            python_version=self.python_version,
            has_latex=self.has_latex,
            available_memory_gb=round(self.available_memory_gb, 1),
        )
        snapshot["recommendations"] = self.recommendations
        return snapshot
def detect_environment() -> EnvironmentReport:
    """Detect local environment and generate recommendations.

    Best-effort probing: every optional check is wrapped so that a missing
    or broken dependency (docker, torch, psutil) degrades to "not detected"
    instead of raising into the caller.
    """
    import sys
    import subprocess
    report = EnvironmentReport()
    report.python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
    # Docker: presence via PATH lookup; version string via the CLI (5 s cap).
    if shutil.which("docker"):
        report.has_docker = True
        try:
            result = subprocess.run(
                ["docker", "--version"],
                capture_output=True, text=True, timeout=5
            )
            report.docker_version = result.stdout.strip()
        except Exception:
            pass  # version string is cosmetic; presence is already recorded
    # GPU: torch.cuda can raise RuntimeError (not just ImportError) on a
    # broken CUDA install, so catch broadly — probing must never crash.
    try:
        import torch
        if torch.cuda.is_available():
            report.has_gpu = True
            report.gpu_name = torch.cuda.get_device_name(0)
            report.gpu_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    except Exception:
        pass
    # LaTeX toolchain (needed for PDF paper export).
    report.has_latex = shutil.which("pdflatex") is not None
    # Available RAM via psutil, if installed; broad catch for the same reason.
    try:
        import psutil
        report.available_memory_gb = psutil.virtual_memory().available / (1024**3)
    except Exception:
        pass
    # Recommendations derived from the findings above.
    if not report.has_docker:
        report.recommendations.append(
            "Install Docker for experiment isolation (recommended)"
        )
    if not report.has_gpu:
        report.recommendations.append(
            "No GPU detected — use 'simulated' mode or remote GPU server"
        )
    if not report.has_latex:
        report.recommendations.append(
            "Install LaTeX (texlive) for PDF paper export"
        )
    if report.has_gpu and report.has_docker:
        report.recommendations.append(
            "Environment ready for full Docker GPU experiments"
        )
    return report
================================================
FILE: researchclaw/writing_guide.py
================================================
"""Conference-grade writing knowledge base.
Structured tips from NeurIPS/ICML/ICLR best practices, reviewer feedback
analysis, and accepted paper patterns. Can be loaded and injected into
prompts at runtime, allowing updates without modifying prompt YAML.
"""
from __future__ import annotations
# Category -> list of tips.  Category names are looked up by
# format_writing_tips(); most map to paper sections, while
# "common_rejections" and "rebuttal" cover the review process itself.
CONFERENCE_WRITING_TIPS: dict[str, list[str]] = {
    "title": [
        "Signal novelty — title should hint at what is new",
        "Be specific and concrete, under 15 words",
        "No abbreviations unless universally known",
        "Pattern: '[Finding]: [Evidence]' or '[Method]: [What it does]'",
        "Memeability test: would a reader enjoy telling a colleague about this?",
    ],
    "abstract": [
        # Adjacent literals below are implicitly concatenated into one tip.
        "5-sentence structure: (1) problem, (2) prior approaches + limitations, "
        "(3) your approach + novelty, (4) key results with numbers, (5) implication",
        "150-250 words for ML conferences",
        "Include at least 2 specific quantitative results",
    ],
    "figure_1": [
        "Most important figure in the paper — many readers look at Figure 1 first",
        "Should convey the key idea or main result at a glance",
        "Invest significant time in this figure",
    ],
    "introduction": [
        "State contributions clearly as bullet points",
        "Many reviewers stop reading carefully after the intro",
        "Include paper organization paragraph at the end",
    ],
    "experiments": [
        "Strong baselines: tune baselines with the same effort as your method",
        "Ablations: remove one component at a time and measure the effect",
        "Reproducibility: include hyperparameters, seeds, hardware specs",
        "Statistical rigor: report variance, run multiple seeds",
    ],
    "common_rejections": [
        "Weak baselines (79% of rejected papers)",
        "Missing ablations",
        "Overclaiming beyond evidence",
        "Poor reproducibility details",
        "Ignoring limitations",
    ],
    "rebuttal": [
        "Start with positives reviewers identified",
        "Quote reviewers directly, then respond",
        "Provide new data/experiments rather than arguing",
        "Do not promise — deliver",
    ],
}
def format_writing_tips(categories: list[str] | None = None) -> str:
    """Format writing tips as a prompt-injectable string.

    Parameters
    ----------
    categories:
        Subset of tip categories to include.  Any falsy value (``None`` or
        an empty list) selects all categories; unknown names are skipped.

    Returns
    -------
    str
        Formatted markdown-style tips block.
    """
    out: list[str] = ["## Conference Writing Best Practices"]
    for category in (categories or list(CONFERENCE_WRITING_TIPS.keys())):
        entries = CONFERENCE_WRITING_TIPS.get(category, [])
        if not entries:
            continue
        out.append(f"\n### {category.replace('_', ' ').title()}")
        out.extend(f"- {tip}" for tip in entries)
    return "\n".join(out)
================================================
FILE: scripts/metaclaw_start.sh
================================================
#!/bin/bash
# Start MetaClaw proxy for AutoResearchClaw integration.
#
# Usage:
#   ./scripts/metaclaw_start.sh               # skills_only mode (default)
#   ./scripts/metaclaw_start.sh madmax        # madmax mode (with RL training)
#   ./scripts/metaclaw_start.sh skills_only   # skills_only mode (explicit)
#
# Environment:
#   METACLAW_DIR  Override the MetaClaw checkout location
#                 (default: /home/jqliu/projects/MetaClaw)
set -e
MODE="${1:-skills_only}"
PORT="${2:-30000}"
# Fix: previously hard-coded to one developer's home directory; now
# overridable via METACLAW_DIR while keeping the old path as the default.
METACLAW_DIR="${METACLAW_DIR:-/home/jqliu/projects/MetaClaw}"
VENV="$METACLAW_DIR/.venv"
if [ ! -d "$VENV" ]; then
  echo "ERROR: MetaClaw venv not found at $VENV"
  echo "Run: cd $METACLAW_DIR && python -m venv .venv && source .venv/bin/activate && pip install -e '.[evolve,embedding]'"
  exit 1
fi
echo "Starting MetaClaw in ${MODE} mode on port ${PORT}..."
# Activate venv and start (exec replaces this shell with the proxy process)
source "$VENV/bin/activate"
exec metaclaw start --mode "$MODE" --port "$PORT"
================================================
FILE: scripts/plot_iteration_showcase.py
================================================
"""Generate promotional figure: Pipeline iterative improvement showcase.
Shows two experiment cases side-by-side demonstrating how the AutoResearchClaw
pipeline progressively improves experimental methods through self-iteration.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from pathlib import Path
# ── Styling ──────────────────────────────────────────────────────────────────
# Global matplotlib style: serif fonts, light dashed grid, white figure
# background — applied once before any axes are created (Agg backend was
# selected at import time, so this renders off-screen).
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 11,
    "axes.titlesize": 13,
    "axes.labelsize": 11,
    "figure.facecolor": "white",
    "axes.facecolor": "#FAFAFA",
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
})
# Named colour palette (hex) used for lines, markers, and annotations below.
BLUE = "#1565C0"
GREEN = "#2E7D32"
RED = "#C62828"
ORANGE = "#E65100"
PURPLE = "#6A1B9A"
GRAY = "#757575"
# ── Data ─────────────────────────────────────────────────────────────────────
# Case 1: Continual Meta-Learning for Few-Shot Adaptation
case1_iters = [0, 1, 2, 3, 4]
case1_labels = [
    "Baseline\n(Initial Code)",
    "Iter 1\n(Deep Encoder\n+ Meta-SGD)",
    "Iter 2\n(Prototype Net\n— Regression)",
    "Iter 3\n(Linear Clf\n+ L2 Anchor)",
    "Iter 4\n(Converged)",
]
# Raw error rate per iteration; plotted as accuracy (%) = 100 * (1 - error).
case1_error = [0.7411, 0.1883, 0.2249, 0.0663, 0.0656]
case1_accuracy = [100 * (1 - e) for e in case1_error]
# Marker styles: green=improved, red=regressed, gray=no change
case1_colors = [GRAY, GREEN, RED, GREEN, GRAY]
# (Fix: removed the unused `case1_improved` list — it was assigned but never
# read anywhere in this script; status is already encoded by `case1_colors`.)

# Case 2: RLHF + Curriculum-Based Reward Shaping
case2_iters = [0, 1, 2, 3, 4]
case2_labels = [
    "Baseline\n(Vanilla PPO)",
    "Iter 1\n(No Change)",
    "Iter 2\n(+Reward Model\n+Curriculum)",
    "Iter 3\n(+Rank-Norm\n+Policy EMA)",
    "Iter 4\n(+Confidence\nGating)",
]
case2_error = [0.6443, 0.6443, 0.3843, 0.3696, 0.3344]
case2_alignment = [100 * (1 - e) for e in case2_error]
case2_colors = [GRAY, GRAY, GREEN, GREEN, GREEN]
# ── Figure ───────────────────────────────────────────────────────────────────
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
# ── Case 1: Meta-Learning ───────────────────────────────────────────────────
# Main line
ax1.plot(case1_iters, case1_accuracy, "o-", color=BLUE, linewidth=2.5,
         markersize=10, zorder=5, label="Few-Shot Accuracy")
# Colored markers for improvement status.
# (Fix: the enumerate index was never used — iterate the triples directly.)
for x, y, c in zip(case1_iters, case1_accuracy, case1_colors):
    ax1.scatter(x, y, s=120, color=c, zorder=6, edgecolors="white", linewidths=1.5)
# Annotate key improvements
ax1.annotate(
    "+55.3 pts\nDeep encoder\n+ context-gated replay",
    xy=(1, case1_accuracy[1]), xytext=(1.3, 55),
    fontsize=8.5, color=GREEN, fontweight="bold",
    arrowprops=dict(arrowstyle="->", color=GREEN, lw=1.5),
    ha="left",
)
ax1.annotate(
    "Prototype net\ntoo simple",
    xy=(2, case1_accuracy[2]), xytext=(2.25, 65),
    fontsize=8, color=RED, fontstyle="italic",
    arrowprops=dict(arrowstyle="->", color=RED, lw=1.2),
    ha="left",
)
ax1.annotate(
    "+15.9 pts\nLinear clf + L2 anchor\n+ cosine gating",
    xy=(3, case1_accuracy[3]), xytext=(2.5, 98),
    fontsize=8.5, color=GREEN, fontweight="bold",
    arrowprops=dict(arrowstyle="->", color=GREEN, lw=1.5),
    ha="left",
)
# Reference line for "ideal" performance
ax1.axhline(y=100, color=ORANGE, linestyle=":", alpha=0.6, linewidth=1.5)
ax1.text(4.3, 99, "Oracle (100%)", fontsize=8, color=ORANGE, ha="right",
         fontstyle="italic", va="top")
# Shaded improvement region (only where accuracy beats the baseline value)
ax1.fill_between(case1_iters, case1_accuracy, case1_accuracy[0],
                 where=[a >= case1_accuracy[0] for a in case1_accuracy],
                 alpha=0.08, color=BLUE)
ax1.set_xlabel("Self-Iteration Round", fontsize=12)
ax1.set_ylabel("Few-Shot Accuracy (%)", fontsize=12)
ax1.set_title("Case A: Continual Meta-Learning\nfor Few-Shot Adaptation", fontsize=13,
              fontweight="bold", pad=12)
ax1.set_ylim(15, 105)
ax1.set_xticks(case1_iters)
ax1.set_xticklabels(case1_labels, fontsize=7.5, ha="center")
# Summary box
summary1 = f"Baseline: {case1_accuracy[0]:.1f}% → Best: {case1_accuracy[3]:.1f}%\nImprovement: +{case1_accuracy[3]-case1_accuracy[0]:.1f} pts ({(case1_accuracy[3]-case1_accuracy[0])/case1_accuracy[0]*100:.0f}% rel.)"
ax1.text(0.02, 0.97, summary1, transform=ax1.transAxes, fontsize=9,
         verticalalignment="top", fontfamily="monospace",
         bbox=dict(boxstyle="round,pad=0.5", facecolor="#E3F2FD", alpha=0.9,
                   edgecolor=BLUE, linewidth=1.2))
# ── Case 2: RLHF ────────────────────────────────────────────────────────────
ax2.plot(case2_iters, case2_alignment, "s-", color=PURPLE, linewidth=2.5,
         markersize=10, zorder=5, label="Alignment Score")
# Status-coloured square markers.
# (Fix: the enumerate index was never used — iterate the triples directly.)
for x, y, c in zip(case2_iters, case2_alignment, case2_colors):
    ax2.scatter(x, y, s=120, color=c, zorder=6, edgecolors="white", linewidths=1.5,
                marker="s")
# Annotate
ax2.annotate(
    "No improvement\n(minor code fix)",
    xy=(1, case2_alignment[1]), xytext=(1.3, 30),
    fontsize=8, color=GRAY, fontstyle="italic",
    arrowprops=dict(arrowstyle="->", color=GRAY, lw=1.2),
    ha="left",
)
ax2.annotate(
    "+26.0 pts\n+Learned reward model\n+Curriculum scheduling",
    xy=(2, case2_alignment[2]), xytext=(1.8, 75),
    fontsize=8.5, color=GREEN, fontweight="bold",
    arrowprops=dict(arrowstyle="->", color=GREEN, lw=1.5),
    ha="left",
)
ax2.annotate(
    "+1.4 pts\n+Rank-norm\n+Policy EMA",
    xy=(3, case2_alignment[3]), xytext=(3.2, 73),
    fontsize=8, color=GREEN,
    arrowprops=dict(arrowstyle="->", color=GREEN, lw=1.2),
    ha="left",
)
ax2.annotate(
    "+3.6 pts\n+Confidence gating\n+Mini-batch RM",
    xy=(4, case2_alignment[4]), xytext=(3.5, 80),
    fontsize=8.5, color=GREEN, fontweight="bold",
    arrowprops=dict(arrowstyle="->", color=GREEN, lw=1.5),
    ha="left",
)
# Shaded improvement (only where alignment beats the baseline value)
ax2.fill_between(case2_iters, case2_alignment, case2_alignment[0],
                 where=[a >= case2_alignment[0] for a in case2_alignment],
                 alpha=0.08, color=PURPLE)
ax2.set_xlabel("Self-Iteration Round", fontsize=12)
ax2.set_ylabel("LLM Alignment Score (%)", fontsize=12)
ax2.set_title("Case B: RLHF with Curriculum-Based\nReward Shaping for LLM Alignment", fontsize=13,
              fontweight="bold", pad=12)
ax2.set_ylim(15, 105)
ax2.set_xticks(case2_iters)
ax2.set_xticklabels(case2_labels, fontsize=7.5, ha="center")
summary2 = f"Baseline: {case2_alignment[0]:.1f}% → Best: {case2_alignment[4]:.1f}%\nImprovement: +{case2_alignment[4]-case2_alignment[0]:.1f} pts ({(case2_alignment[4]-case2_alignment[0])/case2_alignment[0]*100:.0f}% rel.)"
ax2.text(0.02, 0.97, summary2, transform=ax2.transAxes, fontsize=9,
         verticalalignment="top", fontfamily="monospace",
         bbox=dict(boxstyle="round,pad=0.5", facecolor="#F3E5F5", alpha=0.9,
                   edgecolor=PURPLE, linewidth=1.2))
# ── Legend ───────────────────────────────────────────────────────────────────
# Shared figure-level legend explaining the marker colour coding of both panels.
legend_elements = [
    mpatches.Patch(facecolor=GREEN, edgecolor="white", label="Improved"),
    mpatches.Patch(facecolor=RED, edgecolor="white", label="Regressed (auto-recovered)"),
    mpatches.Patch(facecolor=GRAY, edgecolor="white", label="No change / Baseline"),
]
fig.legend(handles=legend_elements, loc="lower center", ncol=3,
           fontsize=10, frameon=True, fancybox=True, framealpha=0.9,
           bbox_to_anchor=(0.5, -0.02))
# ── Suptitle ─────────────────────────────────────────────────────────────────
fig.suptitle(
    "AutoResearchClaw: Autonomous Self-Iterating Experiment Optimization",
    fontsize=15, fontweight="bold", y=1.02,
)
# Reserve a bottom margin so the figure-level legend does not overlap axes.
fig.tight_layout(rect=[0, 0.04, 1, 0.98])
# ── Save ─────────────────────────────────────────────────────────────────────
# Outputs go to <repo>/docs/figures, created on demand.
out_dir = Path(__file__).resolve().parent.parent / "docs" / "figures"
out_dir.mkdir(parents=True, exist_ok=True)
out_path = out_dir / "iteration_improvement_showcase.png"
fig.savefig(out_path, dpi=200, bbox_inches="tight", facecolor="white")
print(f"Saved: {out_path}")
# Also save a PDF version for papers
pdf_path = out_dir / "iteration_improvement_showcase.pdf"
fig.savefig(pdf_path, bbox_inches="tight", facecolor="white")
print(f"Saved: {pdf_path}")
plt.close(fig)
================================================
FILE: scripts/test_beast_mode_e2e.py
================================================
#!/usr/bin/env python3
"""End-to-end integration test for OpenCode Beast Mode.
Simulates Pipeline stages 1-9 artifacts, then invokes Beast Mode
to generate experiment code via OpenCode CLI.
Usage:
python scripts/test_beast_mode_e2e.py
"""
from __future__ import annotations
import json
import sys
import textwrap
import time
from pathlib import Path
# Add project root to path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from researchclaw.pipeline.opencode_bridge import (
OpenCodeBridge,
count_historical_failures,
score_complexity,
)
# ============================================================
# Simulated Pipeline Artifacts
# ============================================================
# Research topic fed to both score_complexity() and bridge.generate() below.
TOPIC = (
    "Adaptive Mixtures of Local Experts for Image Classification: "
    "Dynamic Gating with Load-Balanced Sparse Routing on CIFAR-10"
)
# Simulated Stage 9 output: exp_plan.yaml content
EXP_PLAN = textwrap.dedent("""\
topic: >
Adaptive Mixtures of Local Experts for Image Classification:
Dynamic Gating with Load-Balanced Sparse Routing on CIFAR-10
objectives:
- Investigate whether sparse Mixture-of-Experts (MoE) routing improves
accuracy over dense baselines under a fixed parameter budget
- Compare top-k routing vs soft routing vs hash-based routing
- Ablate the load-balancing auxiliary loss
datasets:
- CIFAR-10 (pre-cached at /opt/datasets/cifar10)
baselines:
- name: dense_resnet18
description: Standard ResNet-18 with all parameters active
implementation_spec:
class_name: DenseResNet18Trainer
key_hyperparameters:
batch_size: 128
learning_rate: 0.1
epochs: 20
weight_decay: 5e-4
- name: dense_wider_resnet
description: Wider ResNet with ~same FLOPs as MoE model
implementation_spec:
class_name: DenseWiderResNetTrainer
key_hyperparameters:
batch_size: 128
learning_rate: 0.1
epochs: 20
proposed_methods:
- name: topk_sparse_moe
description: >
Sparse MoE with top-2 gating. Each MoE layer has 4 expert MLPs,
a gating network selects top-2 per token. Load-balancing loss
ensures even expert utilization.
implementation_spec:
class_name: TopKSparseMoETrainer
algorithm_steps:
- Build backbone CNN (first 3 ResNet blocks)
- Replace final block with MoE layer (4 experts, top-2 gating)
- Gating network: linear projection → softmax → top-k selection
- Load-balance loss: CV of expert load across batch
- Total loss = CE + lambda_lb * load_balance_loss
key_hyperparameters:
batch_size: 128
learning_rate: 0.05
epochs: 20
num_experts: 4
top_k: 2
lambda_lb: 0.01
- name: soft_routing_moe
description: >
Soft MoE where all experts contribute with learned weights
(no hard top-k). Softer gradient flow but higher compute.
implementation_spec:
class_name: SoftRoutingMoETrainer
key_hyperparameters:
batch_size: 128
learning_rate: 0.05
epochs: 20
num_experts: 4
ablations:
- name: topk_moe_no_load_balance
description: TopK MoE without load-balancing loss (lambda_lb=0)
what_is_removed: Load-balancing auxiliary loss
expected_effect: Expert collapse — one expert dominates, accuracy drops
how_it_differs:
- Set lambda_lb = 0
- Everything else identical to topk_sparse_moe
- name: topk_moe_single_expert
description: TopK MoE with top_k=1 (only one expert per sample)
what_is_removed: Multi-expert routing (reduced to single expert)
expected_effect: Reduced model capacity per sample, likely lower accuracy
how_it_differs:
- Set top_k = 1 instead of 2
- Keep load-balancing loss active
metrics:
primary_metric:
name: test_accuracy
direction: maximize
description: Classification accuracy on CIFAR-10 test set
secondary_metrics:
- name: expert_utilization_cv
description: Coefficient of variation of expert usage (lower = more balanced)
- name: training_time_sec
description: Wall-clock training time
compute_budget:
effective_time_seconds: 240
estimated_seconds_per_run: 40
seeds_per_condition: 3
total_conditions: 6
notes:
- Use small models (< 5M params) to fit within budget
- Use 20 epochs max
- Early stopping if no improvement for 5 epochs
""")
PKG_HINT = textwrap.dedent("""\
AVAILABLE PACKAGES (docker mode): Python stdlib, numpy, torch, sklearn, scipy, pandas,
torchvision, torchaudio, matplotlib, seaborn, scipy, tqdm, transformers, datasets,
timm, einops, torchmetrics, and additional pip-installable packages via requirements.txt.
GPU: NVIDIA RTX 6000 Ada (cuda). You MAY use PyTorch with GPU acceleration.
Use `device = torch.device('cuda')` for tensor operations.
## Compute Budget Constraint
- Total execution time limit: 240 seconds
- Design experiments that complete within this budget
- Implement a time guard: stop gracefully at 80% of budget (192 seconds)
""")
EXTRA_GUIDANCE = textwrap.dedent("""\
## Dataset Guidance
CIFAR-10 is pre-cached at /opt/datasets/cifar10.
Use: torchvision.datasets.CIFAR10(root='/opt/datasets/cifar10', download=False)
## Multi-Seed Enforcement
Run each condition with seeds [0, 1, 2]. Report mean ± std for all metrics.
## Hyperparameter Reporting
Print all hyperparameters at the start of each condition run.
""")
def main() -> None:
    """Run the Beast Mode end-to-end test.

    Steps: (1) score experiment-plan complexity, (2) verify the OpenCode CLI
    is installed, (3) invoke OpenCode via OpenCodeBridge in a fresh stage
    directory, (4) report success/failure and persist logs, (5) grade the
    generated code with heuristic string checks and an AST parse.  Exits
    with status 1 when the CLI is missing or generation fails.
    """
    print("=" * 70)
    print("OpenCode Beast Mode — End-to-End Integration Test")
    print("=" * 70)
    # Step 1: Complexity scoring
    cplx = score_complexity(
        exp_plan=EXP_PLAN,
        topic=TOPIC,
        historical_failures=0,
        threshold=0.4,  # Lower threshold to ensure trigger for this test
    ) if False else None  # placeholder — see real call below
    # (The guard above is never taken; real flow starts here.)
    print("\n[Step 1] Complexity scoring...")
    cplx = score_complexity(
        exp_plan=EXP_PLAN,
        topic=TOPIC,
        historical_failures=0,
        threshold=0.4,  # Lower threshold to ensure trigger for this test
    )
    print(f" Score: {cplx.score:.4f}")
    print(f" Signals: {json.dumps(cplx.signals, indent=4)}")
    print(f" Recommendation: {cplx.recommendation}")
    print(f" Reason: {cplx.reason}")
    if cplx.recommendation != "beast_mode":
        print("\n [!] Score below threshold. Forcing beast mode for test purposes.\n")
    # Step 2: Check OpenCode availability
    print("\n[Step 2] Checking OpenCode availability...")
    available = OpenCodeBridge.check_available()
    if not available:
        print(" [FATAL] OpenCode CLI not installed. Cannot proceed.")
        sys.exit(1)
    print(" OpenCode CLI: OK")
    # Step 3: Create test workspace and invoke
    print("\n[Step 3] Invoking OpenCode beast mode...")
    test_dir = PROJECT_ROOT / "test_outputs_beast_mode"
    test_dir.mkdir(parents=True, exist_ok=True)
    # Timestamped stage dir so repeated runs never collide.
    stage_dir = test_dir / f"stage-10_{int(time.time())}"
    stage_dir.mkdir(parents=True, exist_ok=True)
    # Write complexity analysis
    (stage_dir / "complexity_analysis.json").write_text(
        json.dumps({
            "score": cplx.score,
            "signals": cplx.signals,
            "recommendation": cplx.recommendation,
            "reason": cplx.reason,
        }, indent=2),
        encoding="utf-8",
    )
    # NOTE: Azure AI Services endpoints don't support OpenCode's Responses API.
    # The bridge auto-detects Azure and falls back to Anthropic provider.
    bridge = OpenCodeBridge(
        model="anthropic/claude-sonnet-4-6",  # Direct Anthropic model
        llm_base_url="https://huaxi-mlg4x1rk-eastus2.services.ai.azure.com/openai/v1",
        api_key_env="AZURE_OPENAI_API_KEY",
        llm_provider="azure",
        timeout_sec=300,
        max_retries=1,
        workspace_cleanup=False,  # Keep workspace for inspection
    )
    t0 = time.time()
    result = bridge.generate(
        stage_dir=stage_dir,
        topic=TOPIC,
        exp_plan=EXP_PLAN,
        metric="test_accuracy",
        pkg_hint=PKG_HINT,
        extra_guidance=EXTRA_GUIDANCE,
        time_budget_sec=240,
    )
    elapsed = time.time() - t0
    # Step 4: Evaluate results
    print(f"\n[Step 4] Results (elapsed: {elapsed:.1f}s)")
    print(f" Success: {result.success}")
    print(f" Error: {result.error or 'None'}")
    print(f" Files: {list(result.files.keys())}")
    print(f" OpenCode elapsed: {result.elapsed_sec:.1f}s")
    if not result.success:
        print(f"\n [FAILED] Beast mode failed: {result.error}")
        print(f" Log (last 1000 chars):\n{result.opencode_log[-1000:]}")
        # Write log for debugging
        (stage_dir / "opencode_log.txt").write_text(
            result.opencode_log, encoding="utf-8",
        )
        (stage_dir / "beast_mode_log.json").write_text(
            json.dumps({
                "success": False,
                "error": result.error,
                "elapsed_sec": result.elapsed_sec,
            }, indent=2),
            encoding="utf-8",
        )
        sys.exit(1)
    # Write generated files
    exp_dir = stage_dir / "experiment"
    exp_dir.mkdir(parents=True, exist_ok=True)
    for fname, code in result.files.items():
        fpath = exp_dir / fname
        fpath.parent.mkdir(parents=True, exist_ok=True)
        fpath.write_text(code, encoding="utf-8")
    print(f"\n Files written to: {exp_dir}")
    # Write beast mode log
    (stage_dir / "beast_mode_log.json").write_text(
        json.dumps({
            "success": True,
            "elapsed_sec": result.elapsed_sec,
            "files": list(result.files.keys()),
        }, indent=2),
        encoding="utf-8",
    )
    # Step 5: Quality evaluation — heuristic substring checks on main.py.
    print("\n[Step 5] Quality evaluation...")
    checks = {
        "main.py exists": "main.py" in result.files,
        "main.py is non-empty": len(result.files.get("main.py", "")) > 100,
        "Has metric print": "test_accuracy" in result.files.get("main.py", ""),
        "Has seed loop": "seed" in result.files.get("main.py", "").lower(),
        "Has CIFAR-10": "cifar" in result.files.get("main.py", "").lower(),
        "Has torch import": "import torch" in result.files.get("main.py", ""),
        "No argparse": "argparse" not in result.files.get("main.py", ""),
        "Has multiple conditions": any(
            kw in result.files.get("main.py", "").lower()
            for kw in ["baseline", "dense", "moe", "expert", "condition"]
        ),
        "Has time guard": any(
            kw in result.files.get("main.py", "")
            for kw in ["time.time", "time.monotonic", "time_budget", "time_limit"]
        ),
    }
    all_pass = True
    for check_name, passed in checks.items():
        status = "PASS" if passed else "FAIL"
        if not passed:
            all_pass = False
        print(f" [{status}] {check_name}")
    # Count lines of code
    total_loc = sum(len(code.splitlines()) for code in result.files.values())
    py_files = [f for f in result.files if f.endswith(".py")]
    print(f"\n Total files: {len(result.files)}")
    print(f" Python files: {len(py_files)}")
    print(f" Total lines of code: {total_loc}")
    # Try AST parsing main.py
    import ast
    try:
        ast.parse(result.files["main.py"])
        print(" [PASS] main.py AST parse: valid Python")
    except SyntaxError as e:
        print(f" [FAIL] main.py AST parse error: {e}")
        all_pass = False
    # Print first 50 lines of main.py for manual inspection
    main_lines = result.files.get("main.py", "").splitlines()
    print(f"\n --- main.py preview (first 50 of {len(main_lines)} lines) ---")
    for i, line in enumerate(main_lines[:50], 1):
        print(f" {i:4d} | {line}")
    if len(main_lines) > 50:
        print(f" ... ({len(main_lines) - 50} more lines)")
    # Final verdict
    print("\n" + "=" * 70)
    pass_count = sum(1 for v in checks.values() if v)
    total = len(checks)
    if all_pass:
        print(f"VERDICT: ALL CHECKS PASSED ({pass_count}/{total})")
    else:
        print(f"VERDICT: {pass_count}/{total} checks passed")
    print(f"Stage dir: {stage_dir}")
    print("=" * 70)
# Standalone entry point: python scripts/test_beast_mode_e2e.py
if __name__ == "__main__":
    main()
================================================
FILE: scripts/test_code_agent_live.py
================================================
#!/usr/bin/env python3
"""Live test of CodeAgent with real LLM — evaluates code generation quality.
This script directly invokes the CodeAgent with real experiment plans
and evaluates the quality of generated code. No full pipeline needed.
Usage:
python scripts/test_code_agent_live.py [--model gpt-4.1] [--test-id 1]
"""
from __future__ import annotations
import argparse
import ast
import json
import os
import sys
import time
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from researchclaw.llm.client import LLMClient, LLMConfig
from researchclaw.pipeline.code_agent import CodeAgent, CodeAgentConfig
from researchclaw.prompts import PromptManager
# ---------------------------------------------------------------------------
# Test cases — progressively harder experiment scenarios
# ---------------------------------------------------------------------------
TEST_CASES = {
1: {
"name": "Vision Transformer on CIFAR-10",
"topic": (
"Comparing Vision Transformer (ViT) variants for image classification: "
"investigate how patch size, number of attention heads, and positional "
"encoding strategies affect classification accuracy on CIFAR-10"
),
"exp_plan": """
objectives:
- Compare ViT-Tiny variants with different patch sizes (4, 8, 16)
- Evaluate multi-head self-attention with different head counts (2, 4, 8)
- Compare learnable vs sinusoidal positional encodings
datasets:
- name: CIFAR-10
source: torchvision.datasets.CIFAR10
train_size: 50000
test_size: 10000
baselines:
- name: SimpleViT-P16
description: Standard ViT with patch_size=16, 4 heads, learnable pos encoding
proposed_methods:
- name: SmallPatch-ViT
implementation_spec:
class_name: SmallPatchViT
key_methods: [forward, _create_patches, _attention]
differentiator: Uses patch_size=4 for finer-grained spatial features
- name: ManyHead-ViT
implementation_spec:
class_name: ManyHeadViT
key_methods: [forward, _multi_head_attention]
differentiator: Uses 8 attention heads instead of 4
ablations:
- name: SinusoidalPos-ViT
description: Replace learnable positional encoding with sinusoidal
metrics:
- accuracy (higher is better)
- training_loss
compute_budget:
time_limit_sec: 300
epochs: 10
""",
"metric": "accuracy",
"min_files": 2,
"min_classes": 3,
"required_imports": ["torch", "torchvision"],
},
2: {
"name": "Distribution Shift Detection via Uncertainty",
"topic": (
"Detecting distribution shift in deployed ML models using "
"uncertainty estimation: comparing Monte Carlo Dropout, "
"Deep Ensembles, and Spectral-Normalized Neural GP (SNGP) "
"for out-of-distribution detection on corrupted CIFAR-10"
),
"exp_plan": """
objectives:
- Implement 3 uncertainty estimation methods for OOD detection
- Evaluate on CIFAR-10 vs CIFAR-10-C (corrupted) as OOD
- Compare AUROC for separating in-distribution from OOD samples
datasets:
- name: CIFAR-10
source: torchvision.datasets.CIFAR10
role: in-distribution
- name: CIFAR-10-C
source: Generated via Gaussian noise corruption
role: out-of-distribution
baselines:
- name: MCDropout
description: Monte Carlo Dropout with 30 forward passes, mean+std of softmax
implementation_spec:
class_name: MCDropoutDetector
key_methods: [predict_with_uncertainty, _mc_forward, compute_auroc]
differentiator: Standard MC Dropout baseline
proposed_methods:
- name: DeepEnsemble
implementation_spec:
class_name: DeepEnsembleDetector
key_methods: [train_ensemble, predict_with_uncertainty, _member_forward]
differentiator: Trains 3 independent models, uses prediction disagreement
- name: SNGP
implementation_spec:
class_name: SNGPDetector
key_methods: [forward, _spectral_norm_layer, _gp_output_layer]
differentiator: Spectral normalization + GP output layer for distance-aware uncertainty
ablations:
- name: MCDropout-10passes
description: MC Dropout with only 10 forward passes (reduced compute)
metrics:
- auroc (higher is better)
- ece (expected calibration error, lower is better)
compute_budget:
time_limit_sec: 300
epochs: 5
""",
"metric": "auroc",
"min_files": 2,
"min_classes": 4,
"required_imports": ["torch", "numpy"],
},
3: {
"name": "Meta-Learning Few-Shot with MAML",
"topic": (
"Few-shot learning with gradient-based meta-learning: comparing "
"MAML, Reptile, and Prototypical Networks on Omniglot-style "
"synthetic tasks with 5-way 1-shot and 5-way 5-shot settings"
),
"exp_plan": """
objectives:
- Implement 3 few-shot learning algorithms from scratch
- Evaluate on synthetic few-shot tasks (5-way, 1-shot and 5-shot)
- Compare accuracy and convergence speed
datasets:
- name: SyntheticFewShot
source: Generated in-code (random linear classification tasks)
n_classes: 20
samples_per_class: 20
baselines:
- name: ProtoNet
description: Prototypical Networks — learn embedding, classify by nearest class prototype
implementation_spec:
class_name: PrototypicalNetwork
key_methods: [embed, compute_prototypes, classify, meta_train_step]
differentiator: Non-gradient meta-learning baseline using metric space
proposed_methods:
- name: MAML
implementation_spec:
class_name: MAMLLearner
key_methods: [inner_loop, outer_loop, meta_train_step, adapt]
differentiator: Second-order gradient-based meta-learning with inner loop adaptation
- name: Reptile
implementation_spec:
class_name: ReptileLearner
key_methods: [inner_loop, meta_update, meta_train_step]
differentiator: First-order approximation — SGD on tasks, move toward task-optimal weights
ablations:
- name: MAML-FirstOrder
description: MAML with first-order approximation (no second derivatives)
metrics:
- accuracy (higher is better)
- meta_train_loss
compute_budget:
time_limit_sec: 300
meta_epochs: 200
inner_steps: 5
inner_lr: 0.01
""",
"metric": "accuracy",
"min_files": 2,
"min_classes": 3,
"required_imports": ["torch"],
},
}
# ---------------------------------------------------------------------------
# Code quality analysis
# ---------------------------------------------------------------------------
def analyze_code_quality(files: dict[str, str], test_case: dict) -> dict:
"""Analyze the quality of generated code."""
report = {
"test_name": test_case["name"],
"num_files": len(files),
"file_names": list(files.keys()),
"total_lines": 0,
"effective_lines": 0,
"classes_found": [],
"functions_found": [],
"imports_found": [],
"issues": [],
"scores": {},
}
all_code = ""
for fname, code in files.items():
all_code += code + "\n"
lines = code.split("\n")
report["total_lines"] += len(lines)
effective = [
l for l in lines
if l.strip() and not l.strip().startswith("#") and not l.strip().startswith("import") and not l.strip().startswith("from")
]
report["effective_lines"] += len(effective)
# AST analysis
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
methods = [
n.name for n in node.body
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
]
method_lines = sum(
n.end_lineno - n.lineno + 1
for n in node.body
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
and n.end_lineno
)
report["classes_found"].append({
"name": node.name,
"file": fname,
"methods": methods,
"method_count": len(methods),
"total_method_lines": method_lines,
})
elif isinstance(node, ast.FunctionDef) and node.col_offset == 0:
report["functions_found"].append({
"name": node.name,
"file": fname,
"lines": (node.end_lineno or node.lineno) - node.lineno + 1,
})
elif isinstance(node, (ast.Import, ast.ImportFrom)):
if isinstance(node, ast.Import):
for alias in node.names:
report["imports_found"].append(alias.name.split(".")[0])
else:
if node.module:
report["imports_found"].append(node.module.split(".")[0])
except SyntaxError as e:
report["issues"].append(f"SyntaxError in {fname}: {e}")
report["imports_found"] = sorted(set(report["imports_found"]))
# Scoring
# 1. File count (target: min_files)
file_score = min(10, (len(files) / test_case["min_files"]) * 10)
report["scores"]["file_structure"] = round(file_score, 1)
# 2. Class count (target: min_classes)
class_score = min(10, (len(report["classes_found"]) / test_case["min_classes"]) * 10)
report["scores"]["class_coverage"] = round(class_score, 1)
# 3. Code depth (effective lines)
depth_score = min(10, report["effective_lines"] / 30) # 300 lines = 10
report["scores"]["code_depth"] = round(depth_score, 1)
# 4. Method richness (average methods per class)
if report["classes_found"]:
avg_methods = sum(c["method_count"] for c in report["classes_found"]) / len(report["classes_found"])
method_score = min(10, avg_methods / 0.5) # 5 methods/class = 10
report["scores"]["method_richness"] = round(method_score, 1)
else:
report["scores"]["method_richness"] = 0
# 5. Import coverage
required = set(test_case.get("required_imports", []))
found = set(report["imports_found"])
if required:
import_score = len(required & found) / len(required) * 10
else:
import_score = 10
report["scores"]["import_coverage"] = round(import_score, 1)
# 6. Syntax validity
syntax_score = 10 if not any("SyntaxError" in i for i in report["issues"]) else 0
report["scores"]["syntax_valid"] = syntax_score
# Overall score
scores = report["scores"]
report["overall_score"] = round(
sum(scores.values()) / len(scores), 1
)
# Quality checks
if len(files) < test_case["min_files"]:
report["issues"].append(
f"Too few files: {len(files)} < {test_case['min_files']}"
)
if len(report["classes_found"]) < test_case["min_classes"]:
report["issues"].append(
f"Too few classes: {len(report['classes_found'])} < {test_case['min_classes']}"
)
for cls in report["classes_found"]:
if cls["total_method_lines"] < 10:
report["issues"].append(
f"Class {cls['name']} has only {cls['total_method_lines']} method lines (too thin)"
)
return report
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: generate code for one or all test cases and score it.

    Reads OPENAI_BASE_URL / OPENAI_API_KEY from the environment, runs
    ``CodeAgent.generate()`` per selected test case (no sandbox), writes the
    generated files plus a ``quality_report.json`` under ``--output-dir``,
    and prints a per-case and aggregate summary. Exits with status 1 on
    missing credentials, failed LLM preflight, or an unknown ``--test-id``.
    """
    parser = argparse.ArgumentParser(description="Live test CodeAgent quality")
    parser.add_argument("--model", default="gpt-4.1", help="Model to use")
    parser.add_argument("--test-id", type=int, default=0, help="Test case ID (0=all)")
    parser.add_argument("--no-sandbox", action="store_true", help="Skip sandbox exec-fix")
    parser.add_argument("--tree-search", action="store_true", help="Enable tree search")
    parser.add_argument("--output-dir", default="test_outputs", help="Output directory")
    args = parser.parse_args()
    # Setup LLM client from environment credentials.
    base_url = os.environ.get("OPENAI_BASE_URL", "")
    api_key = os.environ.get("OPENAI_API_KEY", "")
    if not base_url or not api_key:
        print("ERROR: Set OPENAI_BASE_URL and OPENAI_API_KEY environment variables")
        sys.exit(1)
    llm_config = LLMConfig(
        base_url=base_url,
        api_key=api_key,
        primary_model=args.model,
        fallback_models=[],
        max_tokens=16384,
        temperature=0.7,
        timeout_sec=300,
    )
    llm = LLMClient(llm_config)
    # Quick connectivity test before doing any expensive generation.
    print(f"Testing LLM connectivity ({args.model})... ", end="", flush=True)
    ok, msg = llm.preflight()
    if not ok:
        print(f"FAILED: {msg}")
        sys.exit(1)
    print("OK")
    pm = PromptManager()
    # Select test cases: a positive --test-id picks one; 0 runs them all.
    if args.test_id > 0:
        if args.test_id not in TEST_CASES:
            print(f"ERROR: Unknown test ID {args.test_id}. Available: {list(TEST_CASES.keys())}")
            sys.exit(1)
        cases = {args.test_id: TEST_CASES[args.test_id]}
    else:
        cases = TEST_CASES
    # Output directory (one sub-directory per test case below).
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    all_reports = []
    for test_id, tc in cases.items():
        print(f"\n{'='*60}")
        print(f"Test {test_id}: {tc['name']}")
        print(f"{'='*60}")
        stage_dir = output_dir / f"test_{test_id}"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # --no-sandbox disables the exec-fix loop by setting its budget to 0.
        config = CodeAgentConfig(
            architecture_planning=True,
            exec_fix_max_iterations=0 if args.no_sandbox else 3,
            tree_search_enabled=args.tree_search,
            review_max_rounds=2,
        )
        agent = CodeAgent(
            llm=llm,
            prompts=pm,
            config=config,
            stage_dir=stage_dir,
            sandbox_factory=None,  # No sandbox for quick test
        )
        t0 = time.time()
        result = agent.generate(
            topic=tc["topic"],
            exp_plan=tc["exp_plan"],
            metric=tc["metric"],
            pkg_hint=(
                "\nAVAILABLE PACKAGES (docker mode): Python stdlib, numpy, "
                "torch, torchvision, sklearn, scipy, pandas, matplotlib.\n"
                "GPU: NVIDIA RTX 6000 Ada (49GB VRAM). "
                "Use `device = torch.device('cuda')` for tensor operations.\n"
            ),
            max_tokens=16384,
        )
        elapsed = time.time() - t0
        print(f"\nGeneration time: {elapsed:.1f}s")
        print(f"LLM calls: {result.total_llm_calls}")
        print(f"Review rounds: {result.review_rounds}")
        print(f"Architecture spec: {len(result.architecture_spec)} chars")
        # Write generated files to disk (supports nested paths).
        for fname, code in result.files.items():
            fpath = stage_dir / fname
            fpath.parent.mkdir(parents=True, exist_ok=True)
            fpath.write_text(code, encoding="utf-8")
            lines = len(code.split("\n"))
            print(f" {fname}: {lines} lines")
        # Write architecture spec, if the agent produced one.
        if result.architecture_spec:
            (stage_dir / "architecture_spec.yaml").write_text(
                result.architecture_spec, encoding="utf-8"
            )
        # Analyze quality of the generated files and attach run metadata.
        report = analyze_code_quality(result.files, tc)
        report["generation_time_sec"] = round(elapsed, 1)
        report["llm_calls"] = result.total_llm_calls
        report["review_rounds"] = result.review_rounds
        report["architecture_spec_chars"] = len(result.architecture_spec)
        # Print report
        print(f"\n--- Quality Report ---")
        print(f"Files: {report['num_files']}")
        print(f"Total lines: {report['total_lines']}")
        print(f"Effective lines: {report['effective_lines']}")
        print(f"Classes: {len(report['classes_found'])}")
        for cls in report["classes_found"]:
            print(f" - {cls['name']} ({cls['method_count']} methods, {cls['total_method_lines']} lines)")
        print(f"Imports: {', '.join(report['imports_found'])}")
        print(f"\nScores:")
        for k, v in report["scores"].items():
            print(f" {k}: {v}/10")
        print(f" OVERALL: {report['overall_score']}/10")
        if report["issues"]:
            print(f"\nIssues:")
            for issue in report["issues"]:
                print(f" - {issue}")
        # Save per-case report
        (stage_dir / "quality_report.json").write_text(
            json.dumps(report, indent=2), encoding="utf-8"
        )
        all_reports.append(report)
    # Summary across cases (only when more than one case ran).
    if len(all_reports) > 1:
        print(f"\n{'='*60}")
        print("SUMMARY")
        print(f"{'='*60}")
        for r in all_reports:
            print(f" {r['test_name']}: {r['overall_score']}/10 "
                  f"({r['effective_lines']} lines, {len(r['classes_found'])} classes)")
        avg = sum(r["overall_score"] for r in all_reports) / len(all_reports)
        print(f"\n Average: {avg:.1f}/10")
    # Save all reports in one aggregate file.
    (output_dir / "all_reports.json").write_text(
        json.dumps(all_reports, indent=2), encoding="utf-8"
    )
    print(f"\nAll outputs saved to: {output_dir}/")
# Script entry point.
if __name__ == "__main__":
    main()
================================================
FILE: scripts/test_code_agent_sandbox.py
================================================
#!/usr/bin/env python3
"""Test CodeAgent with Docker sandbox exec-fix loop.
Generates code with Phase 1-4 (architecture, exec-fix, review),
runs in Docker sandbox, verifies the exec-fix loop catches and fixes errors.
Usage:
python scripts/test_code_agent_sandbox.py [--model gpt-5.1] [--test-id 1]
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from researchclaw.config import DockerSandboxConfig, ExperimentConfig
from researchclaw.experiment.docker_sandbox import DockerSandbox
from researchclaw.llm.client import LLMClient, LLMConfig
from researchclaw.pipeline.code_agent import CodeAgent, CodeAgentConfig
from researchclaw.prompts import PromptManager
# ---------------------------------------------------------------------------
# Test case (simple — should run quickly in sandbox)
# ---------------------------------------------------------------------------
# Built-in test cases for the sandbox run. Each entry supplies a topic, a
# YAML experiment plan that is passed verbatim to CodeAgent.generate(), and
# the name of the primary metric. Budgets are small (<=120s, 3 epochs) so
# the generated code can finish inside the sandbox timeout.
# NOTE(review): the triple-quoted plans are runtime strings; their exact
# text (including layout) is what the agent receives.
TEST_CASES = {
    # Case 1: supervised classification — compares ViT patch-size/head-count variants.
    1: {
        "name": "ViT on CIFAR-10 (sandbox)",
        "topic": (
            "Comparing Vision Transformer (ViT) variants for image classification: "
            "investigate how patch size and number of attention heads affect "
            "classification accuracy on CIFAR-10"
        ),
        "exp_plan": """
objectives:
- Compare ViT-Tiny variants with different patch sizes (4, 16)
- Evaluate multi-head self-attention with different head counts (4, 8)
datasets:
- name: CIFAR-10
source: torchvision.datasets.CIFAR10
train_size: 50000
test_size: 10000
baselines:
- name: SimpleViT-P16
description: Standard ViT with patch_size=16, 4 heads, learnable pos encoding
proposed_methods:
- name: SmallPatch-ViT
implementation_spec:
class_name: SmallPatchViT
key_methods: [forward, _create_patches, _attention]
differentiator: Uses patch_size=4 for finer-grained spatial features
- name: ManyHead-ViT
implementation_spec:
class_name: ManyHeadViT
key_methods: [forward, _multi_head_attention]
differentiator: Uses 8 attention heads instead of 4
ablations:
- name: SinusoidalPos-ViT
description: Replace learnable positional encoding with sinusoidal
metrics:
- accuracy (higher is better)
- training_loss
compute_budget:
time_limit_sec: 120
epochs: 3
""",
        "metric": "accuracy",
    },
    # Case 2: uncertainty estimation — MC Dropout vs Deep Ensembles for OOD detection.
    2: {
        "name": "OOD Detection (sandbox)",
        "topic": (
            "Detecting distribution shift using uncertainty estimation: "
            "comparing Monte Carlo Dropout and Deep Ensembles "
            "for out-of-distribution detection on corrupted CIFAR-10"
        ),
        "exp_plan": """
objectives:
- Implement 2 uncertainty estimation methods for OOD detection
- Evaluate on CIFAR-10 vs Gaussian noise corruption as OOD
- Compare AUROC for separating in-distribution from OOD samples
datasets:
- name: CIFAR-10
source: torchvision.datasets.CIFAR10
role: in-distribution
- name: CIFAR-10-C
source: Generated via Gaussian noise corruption
role: out-of-distribution
baselines:
- name: MCDropout
description: Monte Carlo Dropout with 20 forward passes
implementation_spec:
class_name: MCDropoutDetector
key_methods: [predict_with_uncertainty, _mc_forward, compute_auroc]
proposed_methods:
- name: DeepEnsemble
implementation_spec:
class_name: DeepEnsembleDetector
key_methods: [train_ensemble, predict_with_uncertainty]
differentiator: Trains 3 independent models, uses prediction disagreement
ablations:
- name: MCDropout-5passes
description: MC Dropout with only 5 forward passes
metrics:
- auroc (higher is better)
compute_budget:
time_limit_sec: 120
epochs: 3
""",
        "metric": "auroc",
    },
}
def make_sandbox_factory(docker_cfg: DockerSandboxConfig):
    """Build a sandbox factory bound to a fixed Docker configuration.

    The returned callable has the ``(exp_config, workdir) -> sandbox``
    shape expected by CodeAgent; ``docker_cfg`` is captured by closure.
    """
    def _build(exp_config, workdir: Path):
        # exp_config is accepted for interface compatibility but unused:
        # the Docker configuration is fixed at factory-creation time.
        return DockerSandbox(docker_cfg, workdir)
    return _build
def main():
    """CLI entry point: generate code with the Docker-backed exec-fix loop.

    Verifies LLM connectivity and Docker availability, runs
    ``CodeAgent.generate()`` for one test case with a real sandbox factory,
    writes the generated files and a ``validation_log.json``, then performs a
    final end-to-end run of the generated project inside the sandbox and
    saves ``final_run_result.json``. Exits with status 1 on missing
    credentials, failed preflight, missing Docker/image, or unknown test ID.
    """
    parser = argparse.ArgumentParser(description="Test CodeAgent with Docker sandbox")
    parser.add_argument("--model", default="gpt-5.1", help="Model to use")
    parser.add_argument("--test-id", type=int, default=1, help="Test case ID")
    parser.add_argument("--output-dir", default="test_outputs_sandbox", help="Output dir")
    parser.add_argument("--exec-fix-iters", type=int, default=3, help="Max exec-fix iterations")
    parser.add_argument("--timeout", type=int, default=180, help="Sandbox timeout (sec)")
    args = parser.parse_args()
    # Setup LLM from environment credentials.
    base_url = os.environ.get("OPENAI_BASE_URL", "")
    api_key = os.environ.get("OPENAI_API_KEY", "")
    if not base_url or not api_key:
        print("ERROR: Set OPENAI_BASE_URL and OPENAI_API_KEY")
        sys.exit(1)
    llm_config = LLMConfig(
        base_url=base_url,
        api_key=api_key,
        primary_model=args.model,
        fallback_models=[],
        max_tokens=16384,
        temperature=0.7,
        timeout_sec=300,
    )
    llm = LLMClient(llm_config)
    print(f"Testing LLM connectivity ({args.model})... ", end="", flush=True)
    ok, msg = llm.preflight()
    if not ok:
        print(f"FAILED: {msg}")
        sys.exit(1)
    print("OK")
    # Docker sandbox setup: fail fast if the daemon or image is missing.
    docker_cfg = DockerSandboxConfig(
        image="researchclaw/experiment:latest",
        gpu_enabled=True,
        memory_limit_mb=16384,
        network_policy="setup_only",
    )
    if not DockerSandbox.check_docker_available():
        print("ERROR: Docker not available")
        sys.exit(1)
    if not DockerSandbox.ensure_image(docker_cfg.image):
        print(f"ERROR: Docker image {docker_cfg.image} not found")
        sys.exit(1)
    print(f"Docker sandbox ready: {docker_cfg.image}")
    # Select test case
    tc = TEST_CASES.get(args.test_id)
    if not tc:
        print(f"ERROR: Unknown test ID {args.test_id}")
        sys.exit(1)
    pm = PromptManager()
    output_dir = Path(args.output_dir)
    stage_dir = output_dir / f"test_{args.test_id}"
    stage_dir.mkdir(parents=True, exist_ok=True)
    # CodeAgent with the sandbox-backed exec-fix loop enabled.
    config = CodeAgentConfig(
        architecture_planning=True,
        exec_fix_max_iterations=args.exec_fix_iters,
        exec_fix_timeout_sec=args.timeout,
        tree_search_enabled=False,
        review_max_rounds=2,
    )
    sandbox_factory = make_sandbox_factory(docker_cfg)
    agent = CodeAgent(
        llm=llm,
        prompts=pm,
        config=config,
        stage_dir=stage_dir,
        sandbox_factory=sandbox_factory,
    )
    print(f"\n{'='*60}")
    print(f"Test {args.test_id}: {tc['name']}")
    print(f" exec_fix_max_iterations: {args.exec_fix_iters}")
    print(f" sandbox_timeout: {args.timeout}s")
    print(f"{'='*60}")
    t0 = time.time()
    result = agent.generate(
        topic=tc["topic"],
        exp_plan=tc["exp_plan"],
        metric=tc["metric"],
        pkg_hint=(
            "\nAVAILABLE PACKAGES (docker mode): Python stdlib, numpy, "
            "torch, torchvision, sklearn, scipy, pandas, matplotlib, "
            "tqdm, timm, einops, torchmetrics, gymnasium, networkx.\n"
            "GPU: NVIDIA RTX 6000 Ada (49GB VRAM). "
            "Use `device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')` "
            "for tensor operations.\n"
            "DATA PATH: CIFAR-10 is pre-cached at /opt/datasets/cifar-10-batches-py/. "
            "Use `torchvision.datasets.CIFAR10(root='/opt/datasets', download=False)`.\n"
        ),
        max_tokens=16384,
    )
    elapsed = time.time() - t0
    # Report generation statistics.
    print(f"\n--- Generation Report ---")
    print(f"Time: {elapsed:.1f}s")
    print(f"LLM calls: {result.total_llm_calls}")
    print(f"Sandbox runs: {result.total_sandbox_runs}")
    print(f"Review rounds: {result.review_rounds}")
    print(f"Best score: {result.best_score}")
    # Write generated files to disk (supports nested paths).
    for fname, code in result.files.items():
        fpath = stage_dir / fname
        fpath.parent.mkdir(parents=True, exist_ok=True)
        fpath.write_text(code, encoding="utf-8")
        lines = len(code.split("\n"))
        print(f" {fname}: {lines} lines")
    # Write arch spec, if produced.
    if result.architecture_spec:
        (stage_dir / "architecture_spec.yaml").write_text(
            result.architecture_spec, encoding="utf-8"
        )
    # Write validation log with run metadata.
    (stage_dir / "validation_log.json").write_text(
        json.dumps({
            "log": result.validation_log,
            "total_llm_calls": result.total_llm_calls,
            "total_sandbox_runs": result.total_sandbox_runs,
            "review_rounds": result.review_rounds,
            "best_score": result.best_score,
            "elapsed_sec": round(elapsed, 1),
        }, indent=2),
        encoding="utf-8",
    )
    # Final sandbox run for end-to-end verification of the generated project.
    print(f"\n--- Final sandbox verification ---")
    workdir = stage_dir / "_final_run"
    workdir.mkdir(parents=True, exist_ok=True)
    sandbox = DockerSandbox(docker_cfg, workdir)
    final_result = sandbox.run_project(
        stage_dir, entry_point="main.py", timeout_sec=args.timeout,
    )
    print(f"Return code: {final_result.returncode}")
    print(f"Elapsed: {final_result.elapsed_sec:.1f}s")
    print(f"Timed out: {final_result.timed_out}")
    if final_result.metrics:
        print(f"Metrics: {json.dumps(dict(final_result.metrics), indent=2)}")
    if final_result.returncode != 0:
        print(f"STDERR (last 500):\n{final_result.stderr[-500:]}")
    else:
        print("SUCCESS: Code runs to completion in Docker sandbox!")
        # stdout_lines is only defined on the success path; the JSON dump
        # below guards on returncode == 0 before referencing it.
        stdout_lines = final_result.stdout.strip().split("\n")
        print(f"STDOUT (last 10 lines):")
        for line in stdout_lines[-10:]:
            print(f" {line}")
    # Save final run results.
    (stage_dir / "final_run_result.json").write_text(
        json.dumps({
            "returncode": final_result.returncode,
            "elapsed_sec": final_result.elapsed_sec,
            "timed_out": final_result.timed_out,
            "metrics": dict(final_result.metrics) if final_result.metrics else {},
            "stdout_tail": "\n".join(stdout_lines[-20:]) if final_result.returncode == 0 else "",
            "stderr_tail": final_result.stderr[-1000:] if final_result.returncode != 0 else "",
        }, indent=2),
        encoding="utf-8",
    )
# Script entry point.
if __name__ == "__main__":
    main()
================================================
FILE: scripts/test_codegen_v2.py
================================================
#!/usr/bin/env python3
"""Enhanced code generation test — generates code and runs in Docker sandbox.
Tests the full code generation pipeline in isolation:
1. Load experiment plan (from previous run or built-in test case)
2. Generate code via CodeAgent
3. Validate generated code (AST, security, quality)
4. Run in Docker sandbox
5. Score results comprehensively
Usage:
# Run with built-in test case
python scripts/test_codegen_v2.py --test-id 1
# Run with real experiment plan from a previous run
python scripts/test_codegen_v2.py --from-run output/run20
# Run all built-in test cases
python scripts/test_codegen_v2.py --test-id 0
# Skip sandbox (only test generation quality)
python scripts/test_codegen_v2.py --test-id 1 --no-sandbox
"""
from __future__ import annotations
import argparse
import ast
import json
import os
import re
import sys
import time
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from researchclaw.llm.client import LLMClient, LLMConfig
from researchclaw.pipeline.code_agent import CodeAgent, CodeAgentConfig
from researchclaw.prompts import PromptManager
# ---------------------------------------------------------------------------
# Built-in test cases
# ---------------------------------------------------------------------------
# Built-in test cases covering three distinct paradigms (supervised KD,
# reinforcement learning, continuous-time graph dynamics). Each entry
# supplies a topic, a YAML experiment plan passed verbatim to
# CodeAgent.generate(), the primary metric name, and its direction.
# NOTE(review): the triple-quoted plans are runtime strings; their exact
# text is what the agent receives.
TEST_CASES = {
    # Case 1: knowledge distillation, ResNet-18 teacher -> compact ViT student.
    1: {
        "name": "KD for Compact ViTs (CIFAR-10)",
        "topic": (
            "Knowledge Distillation for Compact Vision Transformers: "
            "Attention-Guided Feature Alignment on CIFAR-10"
        ),
        "exp_plan": """
topic: "Knowledge Distillation for Compact Vision Transformers"
datasets:
- name: CIFAR-10
source: torchvision.datasets.CIFAR10
path: /opt/datasets/cifar10
baselines:
- name: TeacherResNet18
description: Pre-trained ResNet-18 teacher model (frozen)
implementation_spec:
class_name: TeacherResNet18
key_methods: [__init__, forward]
algorithm_steps:
- Load pre-trained ResNet-18 from torchvision
- Freeze all parameters
- Use as teacher for distillation
- name: StudentViT_Baseline
description: Compact ViT trained with standard cross-entropy (no KD)
implementation_spec:
class_name: StudentViTBaseline
key_methods: [__init__, forward, train_epoch, evaluate]
algorithm_steps:
- Compact ViT with patch_size=4, dim=128, depth=4, heads=4
- Train with cross-entropy loss only
- Standard SGD optimizer with cosine LR schedule
loss_function: "L = CrossEntropy(student_logits, labels)"
key_hyperparameters:
lr: 0.01
batch_size: 128
epochs: 20
proposed_methods:
- name: AttentionGuidedKD
description: Knowledge distillation with attention-guided feature alignment
aligns_hypothesis: H1
implementation_spec:
class_name: AttentionGuidedKDStudent
key_methods: [__init__, forward, compute_kd_loss, compute_attention_loss, train_epoch]
algorithm_steps:
- Same compact ViT architecture as baseline
- KD loss with temperature T=4
- Attention transfer loss between teacher and student attention maps
- Combined loss = alpha * KD_loss + beta * attention_loss + (1-alpha-beta) * CE_loss
loss_function: "L = 0.5*KLDiv(s/T, t/T)*T^2 + 0.3*MSE(student_attn, teacher_attn) + 0.2*CE(s, y)"
key_hyperparameters:
temperature: 4
alpha: 0.5
beta: 0.3
lr: 0.01
differentiator: Uses attention map alignment between teacher and student
ablations:
- name: KD_NoAttentionTransfer
based_on: AttentionGuidedKD
what_is_removed: Attention transfer loss (beta=0)
how_it_differs: Only uses KD loss + CE loss, no attention alignment
expected_effect: Lower accuracy due to missing attention guidance
- name: KD_ReducedCapacity
based_on: AttentionGuidedKD
what_is_removed: Half the model capacity (dim=64, depth=2, heads=2)
how_it_differs: Smaller ViT architecture, same training procedure
expected_effect: Lower accuracy due to reduced model capacity
metrics:
primary_metric:
name: primary_metric
direction: maximize
description: Top-1 accuracy on CIFAR-10 test set
compute_budget:
total_time_seconds: 300
conditions: [TeacherResNet18, StudentViT_Baseline, AttentionGuidedKD, KD_NoAttentionTransfer, KD_ReducedCapacity]
""",
        "metric": "primary_metric",
        "metric_direction": "maximize",
    },
    # Case 2: reinforcement learning — vanilla PPO vs curiosity-driven PPO.
    2: {
        "name": "PPO with Curiosity Reward (Gymnasium)",
        "topic": (
            "Agent-Centric Reinforcement Learning with Adaptive Reward "
            "Decomposition for CartPole and LunarLander"
        ),
        "exp_plan": """
topic: "Agent-Centric RL with Adaptive Reward Decomposition"
datasets:
- name: CartPole-v1
source: gymnasium
- name: LunarLander-v3
source: gymnasium
baselines:
- name: VanillaPPO
description: Standard PPO with clipped surrogate objective
implementation_spec:
class_name: VanillaPPO
key_methods: [__init__, select_action, update, train_episode]
algorithm_steps:
- Policy network (2-layer MLP, 64 hidden)
- Value network (separate 2-layer MLP)
- Clipped surrogate objective with epsilon=0.2
- GAE lambda=0.95 for advantage estimation
loss_function: "L_policy = -min(r*A, clip(r,1-eps,1+eps)*A); L_value = MSE(V, R)"
key_hyperparameters:
lr: 3e-4
gamma: 0.99
clip_eps: 0.2
gae_lambda: 0.95
differentiator: Standard PPO baseline
proposed_methods:
- name: CuriosityPPO
description: PPO with intrinsic curiosity module
implementation_spec:
class_name: CuriosityPPO
key_methods: [__init__, select_action, compute_intrinsic_reward, update, train_episode]
algorithm_steps:
- Same PPO base as VanillaPPO
- Forward dynamics model predicts next state from (state, action)
- Intrinsic reward = prediction error of forward model
- Total reward = extrinsic + eta * intrinsic
loss_function: "L = L_ppo + L_forward_model; r_total = r_ext + eta * ||f(s,a) - s'||^2"
key_hyperparameters:
eta: 0.1
forward_model_lr: 1e-3
differentiator: Adds intrinsic curiosity-driven exploration reward
ablations:
- name: PPO_NoCuriosity
based_on: CuriosityPPO
what_is_removed: Intrinsic reward (eta=0, forward model not used)
how_it_differs: Same architecture but intrinsic reward zeroed out
expected_effect: Should match VanillaPPO performance
- name: PPO_ReducedNetwork
based_on: VanillaPPO
what_is_removed: Half network capacity (32 hidden units)
how_it_differs: Smaller policy and value networks
expected_effect: Lower performance due to limited capacity
metrics:
primary_metric:
name: primary_metric
direction: maximize
description: Average episodic reward over last 10 episodes
compute_budget:
total_time_seconds: 300
conditions: [VanillaPPO, CuriosityPPO, PPO_NoCuriosity, PPO_ReducedNetwork]
""",
        "metric": "primary_metric",
        "metric_direction": "maximize",
    },
    # Case 3: graph dynamics — discrete-step GCN vs continuous Graph Neural ODE.
    # Note: the only case whose primary metric is minimized.
    3: {
        "name": "Graph Neural ODE (Synthetic)",
        "topic": (
            "Graph Neural Ordinary Differential Equations for Dynamic System "
            "Modeling on Synthetic Coupled Oscillator Networks"
        ),
        "exp_plan": """
topic: "Graph Neural ODE for Dynamic System Modeling"
datasets:
- name: SyntheticOscillators
source: Generated in-code
description: Coupled spring-mass system on a random graph
baselines:
- name: StaticGCN
description: Standard GCN applied at discrete time steps
implementation_spec:
class_name: StaticGCN
key_methods: [__init__, forward, predict_trajectory]
algorithm_steps:
- 2-layer GCN with message passing
- Discrete time step predictions
- MSE loss on next-step prediction
loss_function: "L = MSE(pred_next, true_next)"
key_hyperparameters:
hidden_dim: 64
num_layers: 2
lr: 1e-3
proposed_methods:
- name: GraphNeuralODE
description: Continuous-time dynamics via Neural ODE on graph
implementation_spec:
class_name: GraphNeuralODE
key_methods: [__init__, forward, ode_func, predict_trajectory]
algorithm_steps:
- GNN-based ODE function f(t, x, A) that defines dx/dt
- Neural ODE solver (torchdiffeq.odeint) for continuous trajectory
- MSE loss on trajectory prediction at observed time points
loss_function: "L = MSE(odeint(f, x0, t), x_true)"
key_hyperparameters:
hidden_dim: 64
solver: dopri5
lr: 1e-3
differentiator: Continuous-time dynamics via ODE solver
ablations:
- name: GraphODE_NoMessagePassing
based_on: GraphNeuralODE
what_is_removed: Graph structure (treats nodes independently)
how_it_differs: ODE function ignores adjacency, no message passing
expected_effect: Worse prediction on coupled systems
- name: GraphODE_EulerSolver
based_on: GraphNeuralODE
what_is_removed: Adaptive ODE solver (uses fixed-step Euler)
how_it_differs: Simple Euler integration instead of dopri5
expected_effect: Less accurate trajectories
metrics:
primary_metric:
name: primary_metric
direction: minimize
description: MSE between predicted and true trajectories
compute_budget:
total_time_seconds: 300
conditions: [StaticGCN, GraphNeuralODE, GraphODE_NoMessagePassing, GraphODE_EulerSolver]
""",
        "metric": "primary_metric",
        "metric_direction": "minimize",
    },
}
# ---------------------------------------------------------------------------
# Code quality analysis (comprehensive)
# ---------------------------------------------------------------------------
def analyze_code_quality(files: dict[str, str], test_case: dict) -> dict:
    """Comprehensive static quality analysis of generated Python sources.

    Parses each file with :mod:`ast` (no code is executed), gathers raw
    statistics (line counts, classes with methods/bases/stub detection,
    top-level functions, imports), then scores seven dimensions on a 0-10
    scale and averages them into ``overall_score``.

    Parameters
    ----------
    files:
        Mapping of file name -> source text.
    test_case:
        Test-case spec; only ``"name"`` is read here.

    Returns
    -------
    dict
        Report with raw statistics, ``scores``, ``overall_score`` and a
        list of human-readable ``issues``.
    """
    report = {
        "test_name": test_case["name"],
        "num_files": len(files),
        "file_names": list(files.keys()),
        "total_lines": 0,
        "effective_lines": 0,
        "classes_found": [],
        "functions_found": [],
        "imports_found": [],
        "issues": [],
        "scores": {},
    }
    for fname, code in files.items():
        lines = code.split("\n")
        report["total_lines"] += len(lines)
        # "Effective" lines exclude blanks, comments and docstring delimiters
        # (unlike the quick-test analyzer, import lines DO count here).
        effective = [
            l for l in lines
            if l.strip()
            and not l.strip().startswith("#")
            and not l.strip().startswith('"""')
            and not l.strip().startswith("'''")
        ]
        report["effective_lines"] += len(effective)
        # AST analysis
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    methods = [
                        n.name for n in node.body
                        if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
                    ]
                    method_lines = sum(
                        (n.end_lineno or n.lineno) - n.lineno + 1
                        for n in node.body
                        if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
                    )
                    # Flag stub methods: a body of only `pass` and/or
                    # docstrings leaves at most one "real" statement.
                    # FIX: test against ast.Constant only -- ast.Str was
                    # removed in Python 3.12 (string constants have been
                    # ast.Constant since 3.8), and scan async methods too,
                    # matching the `methods` list above.
                    empty_methods = []
                    for n in node.body:
                        if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)):
                            body_stmts = [
                                s for s in n.body
                                if not isinstance(s, (ast.Pass, ast.Expr))
                                or (isinstance(s, ast.Expr)
                                    and not isinstance(s.value, ast.Constant))
                            ]
                            if len(body_stmts) <= 1:
                                empty_methods.append(n.name)
                    report["classes_found"].append({
                        "name": node.name,
                        "file": fname,
                        "methods": methods,
                        "method_count": len(methods),
                        "total_method_lines": method_lines,
                        "bases": [ast.unparse(b) for b in node.bases],
                        "empty_methods": empty_methods,
                    })
                elif isinstance(node, ast.FunctionDef) and node.col_offset == 0:
                    # Only module-level (column 0) functions are recorded.
                    report["functions_found"].append({
                        "name": node.name,
                        "file": fname,
                        "lines": (node.end_lineno or node.lineno) - node.lineno + 1,
                    })
                elif isinstance(node, (ast.Import, ast.ImportFrom)):
                    if isinstance(node, ast.Import):
                        for alias in node.names:
                            report["imports_found"].append(alias.name.split(".")[0])
                    elif node.module:
                        report["imports_found"].append(node.module.split(".")[0])
        except SyntaxError as e:
            report["issues"].append(f"CRITICAL: SyntaxError in {fname}: {e}")
    report["imports_found"] = sorted(set(report["imports_found"]))
    # ---- Scoring (each dimension clamped to 0-10) ----
    # 1. Syntax validity (0 or 10)
    syntax_ok = not any("SyntaxError" in i for i in report["issues"])
    report["scores"]["syntax_valid"] = 10 if syntax_ok else 0
    # 2. File structure
    file_score = min(10, len(files) * 5)  # 2+ files = 10
    report["scores"]["file_structure"] = round(file_score, 1)
    # 3. Class coverage
    n_classes = len(report["classes_found"])
    class_score = min(10, n_classes * 2.5)  # 4+ classes = 10
    report["scores"]["class_coverage"] = round(class_score, 1)
    # 4. Code depth
    depth_score = min(10, report["effective_lines"] / 40)  # 400+ = 10
    report["scores"]["code_depth"] = round(depth_score, 1)
    # 5. Method richness
    if report["classes_found"]:
        avg_methods = sum(c["method_count"] for c in report["classes_found"]) / n_classes
        method_score = min(10, avg_methods * 2)  # 5+ methods = 10
    else:
        method_score = 0
    report["scores"]["method_richness"] = round(method_score, 1)
    # 6. Class distinctness (penalize near-empty or duplicated classes)
    empty_class_count = sum(
        1 for c in report["classes_found"]
        if c["total_method_lines"] < 5
    )
    identical_pairs = _check_identical_classes(files)
    distinctness = 10
    if empty_class_count > 0:
        distinctness -= empty_class_count * 3
        report["issues"].append(
            f"WARNING: {empty_class_count} classes have <5 method lines (too thin)"
        )
    if identical_pairs:
        distinctness -= len(identical_pairs) * 4
        for p in identical_pairs:
            report["issues"].append(f"WARNING: Identical classes: {p}")
    report["scores"]["class_distinctness"] = max(0, round(distinctness, 1))
    # 7. Import appropriateness
    has_torch = "torch" in report["imports_found"]
    has_numpy = "numpy" in report["imports_found"]
    import_score = 5  # base
    if has_torch:
        import_score += 3
    if has_numpy:
        import_score += 2
    report["scores"]["imports"] = min(10, import_score)
    # Overall score: unweighted mean of the seven dimensions.
    scores = report["scores"]
    report["overall_score"] = round(sum(scores.values()) / len(scores), 1)
    return report
def _check_identical_classes(files: dict[str, str]) -> list[str]:
"""Check for classes with identical method bodies."""
identical = []
class_bodies: dict[str, str] = {}
for fname, code in files.items():
try:
tree = ast.parse(code)
except SyntaxError:
continue
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
# Hash the method bodies
method_code = ""
for n in node.body:
if isinstance(n, ast.FunctionDef):
try:
method_code += ast.unparse(n) + "\n"
except Exception:
pass
if method_code:
key = hash(method_code)
if key in class_bodies:
identical.append(
f"{class_bodies[key]} == {node.name}"
)
else:
class_bodies[key] = node.name
return identical
# ---------------------------------------------------------------------------
# Sandbox execution
# ---------------------------------------------------------------------------
def run_in_sandbox(
files: dict[str, str],
output_dir: Path,
config_path: str | None = None,
timeout_sec: int = 300,
) -> dict:
"""Run generated code in subprocess (or Docker sandbox if available)."""
# Write files
code_dir = output_dir / "experiment"
code_dir.mkdir(parents=True, exist_ok=True)
for fname, code in files.items():
(code_dir / fname).write_text(code, encoding="utf-8")
# Try to run with subprocess as fallback
import subprocess
main_py = code_dir / "main.py"
if not main_py.exists():
return {"status": "failed", "reason": "no main.py"}
print(f" Running in subprocess (timeout={timeout_sec}s)...")
try:
proc = subprocess.run(
[sys.executable, str(main_py)],
cwd=str(code_dir),
capture_output=True,
text=True,
timeout=timeout_sec,
env={**os.environ, "PYTHONPATH": str(code_dir)},
)
stdout = proc.stdout
stderr = proc.stderr
returncode = proc.returncode
timed_out = False
except subprocess.TimeoutExpired:
stdout = ""
stderr = "TIMEOUT"
returncode = -1
timed_out = True
# Parse results
result = {
"status": "success" if returncode == 0 else "failed",
"returncode": returncode,
"timed_out": timed_out,
"stdout_lines": len(stdout.split("\n")) if stdout else 0,
"stderr_lines": len(stderr.split("\n")) if stderr else 0,
"conditions_found": [],
"metrics_found": {},
"has_metric_def": False,
"has_registered_conditions": False,
}
# Parse stdout for conditions and metrics
if stdout:
for line in stdout.split("\n"):
if line.startswith("METRIC_DEF:"):
result["has_metric_def"] = True
elif line.startswith("REGISTERED_CONDITIONS:"):
result["has_registered_conditions"] = True
conds = line.split(":", 1)[1].strip()
result["conditions_found"] = [c.strip() for c in conds.split(",")]
elif "condition=" in line:
m = re.match(r"condition=(\S+)\s+(\S+):\s+(\S+)", line)
if m:
cond, metric_name, value = m.groups()
if cond not in result["metrics_found"]:
result["metrics_found"][cond] = {}
try:
result["metrics_found"][cond][metric_name] = float(value)
except ValueError:
pass
# Score execution
exec_score = 0
if returncode == 0:
exec_score += 3 # runs
if result["has_metric_def"]:
exec_score += 1
if result["has_registered_conditions"]:
exec_score += 1
if result["conditions_found"]:
exec_score += min(3, len(result["conditions_found"])) # up to 3 for conditions
if result["metrics_found"]:
exec_score += 2 # produces metrics
result["exec_score"] = min(10, exec_score)
# Save stdout/stderr
(output_dir / "stdout.txt").write_text(stdout or "(empty)", encoding="utf-8")
(output_dir / "stderr.txt").write_text(stderr or "(empty)", encoding="utf-8")
return result
# ---------------------------------------------------------------------------
# Load experiment plan from previous run
# ---------------------------------------------------------------------------
def load_from_run(run_dir: str) -> dict:
    """Load experiment plan and config from a previous pipeline run."""
    root = Path(run_dir)
    if not root.exists():
        print(f"ERROR: Run directory not found: {run_dir}")
        sys.exit(1)

    # Locate exp_plan.yaml in the newest stage-09* directory that has one.
    plan_file = next(
        (
            d / "exp_plan.yaml"
            for d in sorted(root.glob("stage-09*"), reverse=True)
            if (d / "exp_plan.yaml").exists()
        ),
        None,
    )
    if plan_file is None:
        print(f"ERROR: No exp_plan.yaml found in {run_dir}/stage-09*/")
        sys.exit(1)
    plan_text = plan_file.read_text(encoding="utf-8")

    # Recover the research topic from stage-01/02 artifacts, newest first.
    topic = ""
    for candidate_name in ("topic_evaluation.json", "topic.json"):
        for stage_dir in sorted(root.glob("stage-0[12]*"), reverse=True):
            artifact = stage_dir / candidate_name
            if not artifact.exists():
                continue
            try:
                payload = json.loads(artifact.read_text(encoding="utf-8"))
                topic = payload.get("topic", "") or payload.get("research_topic", "")
                if topic:
                    break
            except Exception:
                continue
        if topic:
            break

    # Last resort: the plan itself may carry a topic field.
    if not topic:
        import yaml
        try:
            topic = yaml.safe_load(plan_text).get("topic", "Unknown Topic")
        except Exception:
            topic = "Unknown Topic"

    return {
        "name": f"From {root.name}",
        "topic": topic,
        "exp_plan": plan_text,
        "metric": "primary_metric",
        "metric_direction": "maximize",
    }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point.

    For each selected test case: configure a CodeAgent, generate code via
    the LLM, analyze its quality, optionally execute it in a subprocess
    sandbox, and persist a per-test quality report plus an overall summary.
    """
    parser = argparse.ArgumentParser(
        description="Test code generation quality with optional sandbox execution"
    )
    parser.add_argument("--model", default="gpt-5.1", help="Model to use")
    parser.add_argument("--test-id", type=int, default=0, help="Test case ID (0=all)")
    parser.add_argument("--from-run", default="", help="Load exp plan from run dir")
    parser.add_argument("--no-sandbox", action="store_true", help="Skip sandbox execution")
    parser.add_argument("--sandbox-timeout", type=int, default=300, help="Sandbox timeout (sec)")
    parser.add_argument("--output-dir", default="test_outputs_codegen", help="Output dir")
    parser.add_argument("--config", default="config_run20.yaml", help="Config file for LLM")
    args = parser.parse_args()
    # Setup LLM client
    # Try loading from config file first
    config_path = Path(args.config)
    if config_path.exists():
        import yaml
        with open(config_path) as f:
            cfg = yaml.safe_load(f)
        llm_cfg = cfg.get("llm", {})
        base_url = llm_cfg.get("base_url", "")
        # Inline key in the config wins; otherwise read the configured env var.
        api_key = llm_cfg.get("api_key", "") or os.environ.get(
            llm_cfg.get("api_key_env", "OPENAI_API_KEY"), ""
        )
    else:
        base_url = os.environ.get("OPENAI_BASE_URL", "")
        api_key = os.environ.get("OPENAI_API_KEY", "")
    if not base_url or not api_key:
        print("ERROR: Need LLM config. Provide --config or set env vars.")
        sys.exit(1)
    llm_config = LLMConfig(
        base_url=base_url,
        api_key=api_key,
        primary_model=args.model,
        fallback_models=["gpt-4.1", "gpt-4o"],
        max_tokens=16384,
        temperature=0.7,
        timeout_sec=300,
    )
    llm = LLMClient(llm_config)
    # Connectivity test
    print(f"Testing LLM ({args.model})...", end=" ", flush=True)
    ok, msg = llm.preflight()
    if not ok:
        print(f"FAILED: {msg}")
        sys.exit(1)
    print("OK")
    pm = PromptManager()
    # Select test cases: a prior run's plan, one case by id, or all cases.
    if args.from_run:
        cases = {99: load_from_run(args.from_run)}
    elif args.test_id > 0:
        if args.test_id not in TEST_CASES:
            print(f"ERROR: Unknown test ID {args.test_id}. Available: {list(TEST_CASES.keys())}")
            sys.exit(1)
        cases = {args.test_id: TEST_CASES[args.test_id]}
    else:
        cases = dict(TEST_CASES)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    all_reports = []
    for test_id, tc in cases.items():
        print(f"\n{'='*70}")
        print(f" Test {test_id}: {tc['name']}")
        print(f"{'='*70}")
        stage_dir = output_dir / f"test_{test_id}"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Configure CodeAgent
        agent_config = CodeAgentConfig(
            architecture_planning=True,
            exec_fix_max_iterations=0,  # no sandbox in generation phase
            tree_search_enabled=False,
            review_max_rounds=2,
        )
        agent = CodeAgent(
            llm=llm,
            prompts=pm,
            config=agent_config,
            stage_dir=stage_dir,
        )
        # Build pkg_hint
        pkg_hint = (
            "\nAVAILABLE PACKAGES (docker mode): Python stdlib, numpy, torch, "
            "torchvision, torchaudio, matplotlib, seaborn, scipy, tqdm, "
            "torchdiffeq, gymnasium, networkx, PyYAML, Pillow, transformers, "
            "datasets, accelerate, peft, timm, einops, torchmetrics.\n"
            "GPU: NVIDIA RTX 6000 Ada (49GB VRAM). "
            "Use `device = torch.device('cuda')` for tensor operations.\n"
        )
        metric_dir = tc.get("metric_direction", "maximize")
        pkg_hint += f"\nMETRIC DIRECTION: {metric_dir}\n"
        # Add compute budget
        pkg_hint += (
            "\n## Compute Budget Constraint\n"
            "- Total execution time limit: 300 seconds\n"
            "- Design experiments that complete within this budget\n"
            "- Implement a time guard: stop gracefully at 80% of budget\n"
        )
        # Generate
        t0 = time.time()
        result = agent.generate(
            topic=tc["topic"],
            exp_plan=tc["exp_plan"],
            metric=tc.get("metric", "primary_metric"),
            pkg_hint=pkg_hint,
            max_tokens=16384,
        )
        gen_elapsed = time.time() - t0
        print(f"\n Generation: {gen_elapsed:.1f}s, {result.total_llm_calls} LLM calls")
        print(f" Architecture spec: {len(result.architecture_spec)} chars")
        print(f" Review rounds: {result.review_rounds}")
        # Write files
        for fname, code in result.files.items():
            fpath = stage_dir / fname
            fpath.parent.mkdir(parents=True, exist_ok=True)
            fpath.write_text(code, encoding="utf-8")
            print(f" -> {fname}: {len(code.split(chr(10)))} lines")
        if result.architecture_spec:
            (stage_dir / "architecture_spec.yaml").write_text(
                result.architecture_spec, encoding="utf-8"
            )
        # Quality analysis
        report = analyze_code_quality(result.files, tc)
        report["generation_time_sec"] = round(gen_elapsed, 1)
        report["llm_calls"] = result.total_llm_calls
        # Sandbox execution (default "skipped" status survives --no-sandbox).
        exec_result = {"status": "skipped"}
        if not args.no_sandbox and result.files:
            exec_result = run_in_sandbox(
                result.files, stage_dir,
                timeout_sec=args.sandbox_timeout,
            )
        report["execution"] = exec_result
        print(f"\n Execution: {exec_result['status']}")
        if exec_result.get("returncode") is not None:
            print(f" Return code: {exec_result['returncode']}")
        if exec_result.get("conditions_found"):
            print(f" Conditions: {', '.join(exec_result['conditions_found'])}")
        if exec_result.get("metrics_found"):
            for cond, metrics in exec_result["metrics_found"].items():
                print(f" {cond}: {metrics}")
        # Print scores
        print(f"\n --- Scores ---")
        for k, v in report["scores"].items():
            print(f" {k}: {v}/10")
        if exec_result.get("exec_score") is not None:
            print(f" execution: {exec_result['exec_score']}/10")
        print(f" OVERALL: {report['overall_score']}/10")
        if report["issues"]:
            print(f"\n Issues:")
            for issue in report["issues"]:
                print(f" - {issue}")
        # Save report
        (stage_dir / "quality_report.json").write_text(
            json.dumps(report, indent=2, default=str), encoding="utf-8"
        )
        all_reports.append(report)
    # Summary (cross-test table only makes sense for more than one case)
    if len(all_reports) > 1:
        print(f"\n{'='*70}")
        print(" SUMMARY")
        print(f"{'='*70}")
        for r in all_reports:
            exec_info = ""
            if "execution" in r:
                exec_info = f" | exec: {r['execution'].get('status', '?')}"
            print(
                f" {r['test_name']}: {r['overall_score']}/10 "
                f"({r['effective_lines']} lines, "
                f"{len(r['classes_found'])} classes{exec_info})"
            )
        avg = sum(r["overall_score"] for r in all_reports) / len(all_reports)
        print(f"\n Average: {avg:.1f}/10")
    (output_dir / "summary.json").write_text(
        json.dumps(all_reports, indent=2, default=str), encoding="utf-8"
    )
    print(f"\nAll outputs saved to: {output_dir}/")


if __name__ == "__main__":
    main()
================================================
FILE: sentinel.sh
================================================
#!/usr/bin/env bash
# sentinel.sh — Watchdog for AutoResearchClaw pipeline process.
#
# Monitors the pipeline heartbeat file and auto-restarts on crash.
# Inspired by Sibyl's sentinel watchdog design.
#
# Usage:
#   ./sentinel.sh <run_dir> [--python <path>]
#
# The pipeline runner writes heartbeat.json after each stage. If the
# heartbeat goes stale (>5 min) and the PID is dead, sentinel restarts.
#
# Configuration via environment:
# SENTINEL_CHECK_INTERVAL — seconds between checks (default: 60)
# SENTINEL_STALE_THRESHOLD — seconds before heartbeat is stale (default: 300)
# SENTINEL_MAX_RETRIES — max restart attempts (default: 5)
# SENTINEL_COOLDOWN — seconds to wait after 3 consecutive failures (default: 360)
set -euo pipefail

# --- Arguments ---
# First positional argument is the run directory; abort with usage if absent.
RUN_DIR="${1:?Usage: sentinel.sh <run_dir> [--python <path>]}"
PYTHON_PATH="python"
shift
while [[ $# -gt 0 ]]; do
    case "$1" in
        --python)
            # Fail with a clear message instead of an opaque set -u
            # "unbound variable" error when the value is missing.
            PYTHON_PATH="${2:?--python requires a value}"
            shift 2
            ;;
        *)
            echo "Unknown argument: $1" >&2
            exit 1
            ;;
    esac
done

# --- Configuration (overridable via environment) ---
CHECK_INTERVAL="${SENTINEL_CHECK_INTERVAL:-60}"      # seconds between checks
STALE_THRESHOLD="${SENTINEL_STALE_THRESHOLD:-300}"   # heartbeat staleness cutoff
MAX_RETRIES="${SENTINEL_MAX_RETRIES:-5}"             # restart attempts before giving up
COOLDOWN="${SENTINEL_COOLDOWN:-360}"                 # pause after repeated failures

HEARTBEAT_FILE="${RUN_DIR}/heartbeat.json"
RECOVERY_LOG="${RUN_DIR}/sentinel_recovery.log"
FAILED_LOG="${RUN_DIR}/sentinel_failed.log"

retry_count=0
consecutive_failures=0
log() {
    # Timestamped message to stdout and appended to the recovery log.
    local stamp line
    stamp="$(date '+%Y-%m-%dT%H:%M:%S')"
    line="[sentinel ${stamp}] $1"
    printf '%s\n' "$line"
    printf '%s\n' "$line" >> "$RECOVERY_LOG"
}
# --- Check if heartbeat is stale ---
# Returns 0 (stale) when heartbeat.json is missing, unparseable, or older
# than STALE_THRESHOLD seconds; returns 1 when the heartbeat is fresh.
is_stale() {
    if [[ ! -f "$HEARTBEAT_FILE" ]]; then
        return 0 # No heartbeat = stale
    fi
    local now
    now=$(date +%s)
    # Extract timestamp from heartbeat.json as epoch seconds; any parse
    # failure yields 0, which makes the age check report "stale".
    local hb_ts
    hb_ts=$(python3 -c "
import json, sys
try:
    data = json.load(open('${HEARTBEAT_FILE}'))
    from datetime import datetime
    ts = datetime.fromisoformat(data['timestamp'])
    print(int(ts.timestamp()))
except Exception:
    print(0)
" 2>/dev/null || echo 0)
    local age=$(( now - hb_ts ))
    [[ $age -gt $STALE_THRESHOLD ]]
}
# --- Check if PID is alive ---
pid_alive() {
    # True when pipeline.pid exists, is non-empty, and names a live process.
    local pid_file="${RUN_DIR}/pipeline.pid"
    [[ -f "$pid_file" ]] || return 1
    local pid
    pid=$(cat "$pid_file" 2>/dev/null || echo "")
    [[ -n "$pid" ]] || return 1
    kill -0 "$pid" 2>/dev/null
}
# --- Check for active subprocesses ---
has_active_children() {
    # True when the recorded pipeline PID still has running child processes.
    local pid_file="${RUN_DIR}/pipeline.pid"
    [[ -f "$pid_file" ]] || return 1
    local pid
    pid=$(cat "$pid_file" 2>/dev/null || echo "")
    [[ -n "$pid" ]] || return 1
    pgrep -P "$pid" > /dev/null 2>&1
}
# --- Restart pipeline ---
# Relaunches the pipeline in resume mode as a background job, records the
# new PID in pipeline.pid, and bumps the global retry_count.
restart_pipeline() {
    log "Attempting pipeline restart (attempt $((retry_count + 1))/${MAX_RETRIES})"
    # Background the run so the sentinel loop keeps monitoring it.
    $PYTHON_PATH -m researchclaw run --resume --output "$RUN_DIR" &
    local new_pid=$!
    echo "$new_pid" > "${RUN_DIR}/pipeline.pid"
    log "Pipeline restarted with PID ${new_pid}"
    retry_count=$((retry_count + 1))
}
# --- Main loop ---
# Poll every CHECK_INTERVAL seconds; restart the pipeline only when the PID
# is dead AND the heartbeat is stale AND no child processes remain.
log "Sentinel started for ${RUN_DIR}"
log "Check interval: ${CHECK_INTERVAL}s, Stale threshold: ${STALE_THRESHOLD}s"
log "Max retries: ${MAX_RETRIES}, Cooldown: ${COOLDOWN}s"
while true; do
    sleep "$CHECK_INTERVAL"
    # If PID is alive, reset failure counter
    if pid_alive; then
        consecutive_failures=0
        continue
    fi
    # PID is dead — check if heartbeat is stale
    if ! is_stale; then
        # Heartbeat is fresh but PID is gone — might have just exited normally
        continue
    fi
    # Don't interrupt active subprocesses
    if has_active_children; then
        log "Active subprocesses detected — skipping restart"
        continue
    fi
    # Check retry limit
    if [[ $retry_count -ge $MAX_RETRIES ]]; then
        log "Max retries (${MAX_RETRIES}) reached — sentinel giving up"
        echo "Sentinel failed after ${MAX_RETRIES} retries at $(date)" >> "$FAILED_LOG"
        exit 1
    fi
    # Cooldown after consecutive failures
    consecutive_failures=$((consecutive_failures + 1))
    if [[ $consecutive_failures -ge 3 ]]; then
        log "3 consecutive failures — cooling down for ${COOLDOWN}s"
        sleep "$COOLDOWN"
        consecutive_failures=0
    fi
    restart_pipeline
done
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/conftest.py
================================================
# conftest.py — shared pytest fixtures for researchclaw tests
================================================
FILE: tests/e2e_docker_sandbox.py
================================================
#!/usr/bin/env python3
"""End-to-end verification for Docker sandbox.
Run after building the image:
docker build -t researchclaw/experiment:latest researchclaw/docker/
python tests/e2e_docker_sandbox.py
"""
from __future__ import annotations
import json
import sys
import tempfile
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from researchclaw.config import DockerSandboxConfig, ExperimentConfig
from researchclaw.experiment.docker_sandbox import DockerSandbox
from researchclaw.experiment.factory import create_sandbox
# ANSI-colored status tags for terminal output.
PASS = "\033[92mPASS\033[0m"
FAIL = "\033[91mFAIL\033[0m"
SKIP = "\033[93mSKIP\033[0m"
# (name, ok, detail) tuples accumulated by check() for the final summary.
results: list[tuple[str, bool, str]] = []
def check(name: str, ok: bool, detail: str = "") -> None:
    """Record one verification outcome and echo a PASS/FAIL line.

    Appends ``(name, ok, detail)`` to the module-level ``results`` list
    so the final summary can list failures.
    """
    results.append((name, ok, detail))
    status = PASS if ok else FAIL
    line = f" [{status}] {name}"
    if detail:
        line = line + f" — {detail}"
    print(line)
def main() -> None:
    """Run the Docker-sandbox end-to-end checks; exits 1 if any check fails."""
    print("=" * 60)
    print("Docker Sandbox End-to-End Verification")
    print("=" * 60)
    # ── Preflight ──────────────────────────────────────────────
    print("\n--- Preflight ---")
    docker_ok = DockerSandbox.check_docker_available()
    check("Docker daemon reachable", docker_ok)
    if not docker_ok:
        print("\nDocker is not available. Cannot proceed.")
        sys.exit(1)
    image_ok = DockerSandbox.ensure_image("researchclaw/experiment:latest")
    check("Image exists locally", image_ok)
    if not image_ok:
        print("\nImage not found. Build it first:")
        print(" docker build -t researchclaw/experiment:latest researchclaw/docker/")
        sys.exit(1)
    # ── Test 1: Basic execution + metrics ──────────────────────
    print("\n--- Test 1: Basic execution + metrics ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        cfg = DockerSandboxConfig(gpu_enabled=False, network_policy="none")
        sandbox = DockerSandbox(cfg, Path(tmp) / "work")
        code = (
            "import numpy as np\n"
            "x = np.random.randn(100)\n"
            "print(f'primary_metric: {float(np.mean(x**2)):.4f}')\n"
            "print(f'std: {float(np.std(x)):.4f}')\n"
            "print('Done.')\n"
        )
        r = sandbox.run(code, timeout_sec=60)
        check("returncode == 0", r.returncode == 0, f"rc={r.returncode}")
        check("metrics parsed", "primary_metric" in r.metrics, str(r.metrics))
        check("stdout non-empty", bool(r.stdout.strip()), repr(r.stdout[:100]))
        check("timed_out is False", r.timed_out is False)
        check("elapsed_sec > 0", r.elapsed_sec > 0, f"{r.elapsed_sec:.2f}s")
    # ── Test 2: Multi-file project ─────────────────────────────
    print("\n--- Test 2: Multi-file project ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        cfg = DockerSandboxConfig(gpu_enabled=False, network_policy="none")
        sandbox = DockerSandbox(cfg, Path(tmp) / "work")
        project = Path(tmp) / "project"
        project.mkdir()
        (project / "utils.py").write_text(
            "def add(a, b): return a + b\n", encoding="utf-8"
        )
        (project / "main.py").write_text(
            "from utils import add\n"
            "result = add(3, 4)\n"
            "print(f'primary_metric: {result}')\n",
            encoding="utf-8",
        )
        r = sandbox.run_project(project, timeout_sec=60)
        check("project returncode == 0", r.returncode == 0, f"rc={r.returncode}")
        check("project metric correct", r.metrics.get("primary_metric") == 7.0,
              str(r.metrics))
    # ── Test 3: results.json ───────────────────────────────────
    print("\n--- Test 3: results.json from volume ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        cfg = DockerSandboxConfig(gpu_enabled=False, network_policy="none")
        sandbox = DockerSandbox(cfg, Path(tmp) / "work")
        code = (
            "import json\n"
            "results = {'accuracy': 0.92, 'f1': 0.88}\n"
            "with open('results.json', 'w') as f:\n"
            " json.dump(results, f)\n"
            "print('primary_metric: 0.92')\n"
        )
        r = sandbox.run(code, timeout_sec=60)
        check("results.json metric merged", "f1" in r.metrics,
              str(r.metrics))
    # ── Test 4: Network isolation ──────────────────────────────
    print("\n--- Test 4: Network isolation ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        cfg = DockerSandboxConfig(gpu_enabled=False, network_policy="none")
        sandbox = DockerSandbox(cfg, Path(tmp) / "work")
        code = (
            "import urllib.request\n"
            "try:\n"
            " urllib.request.urlopen('http://example.com', timeout=5)\n"
            " print('NETWORK_ACCESS: yes')\n"
            "except Exception as e:\n"
            " print('NETWORK_ACCESS: no')\n"
            " print(f'primary_metric: 1.0')\n"
        )
        r = sandbox.run(code, timeout_sec=30)
        network_blocked = "NETWORK_ACCESS: no" in r.stdout
        check("Network blocked (--network=none)", network_blocked,
              r.stdout.strip()[:200])
    # ── Test 5: GPU visibility ─────────────────────────────────
    print("\n--- Test 5: GPU visibility ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        cfg = DockerSandboxConfig(gpu_enabled=True, network_policy="none")
        sandbox = DockerSandbox(cfg, Path(tmp) / "work")
        code = (
            "import torch\n"
            "gpu_available = torch.cuda.is_available()\n"
            "if gpu_available:\n"
            " print(f'GPU: {torch.cuda.get_device_name(0)}')\n"
            " print('primary_metric: 1.0')\n"
            "else:\n"
            " print('GPU: none')\n"
            " print('primary_metric: 0.0')\n"
        )
        r = sandbox.run(code, timeout_sec=60)
        gpu_visible = "primary_metric" in r.metrics and r.metrics["primary_metric"] == 1.0
        if gpu_visible:
            check("GPU visible in container", True, r.stdout.strip()[:200])
        else:
            # Not a hard failure — might not have NVIDIA runtime
            print(f" [{SKIP}] GPU not visible (NVIDIA Container Toolkit may not be installed)")
            print(f" stdout: {r.stdout.strip()[:200]}")
            print(f" stderr: {r.stderr.strip()[:200]}")
    # ── Test 6: Memory limit ──────────────────────────────────
    print("\n--- Test 6: Memory limit enforcement ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        # Set a very low memory limit to trigger OOM
        cfg = DockerSandboxConfig(
            gpu_enabled=False, network_policy="none", memory_limit_mb=64
        )
        sandbox = DockerSandbox(cfg, Path(tmp) / "work")
        code = (
            "import numpy as np\n"
            "# Allocate ~200MB to exceed 64MB limit\n"
            "x = np.ones((25_000_000,), dtype=np.float64)\n"
            "print(f'primary_metric: {x.sum()}')\n"
        )
        r = sandbox.run(code, timeout_sec=30)
        oom = r.returncode != 0
        check("OOM kills container (64MB limit, 200MB alloc)", oom,
              f"rc={r.returncode}, stderr={r.stderr.strip()[:200]}")
    # ── Test 7: Factory integration ────────────────────────────
    print("\n--- Test 7: Factory integration ---")
    with tempfile.TemporaryDirectory(prefix="rc_e2e_") as tmp:
        config = ExperimentConfig(mode="docker", docker=DockerSandboxConfig(gpu_enabled=False))
        sandbox = create_sandbox(config, Path(tmp) / "work")
        check("Factory returns DockerSandbox", isinstance(sandbox, DockerSandbox))
        r = sandbox.run("print('primary_metric: 42.0')", timeout_sec=30)
        check("Factory sandbox executes", r.returncode == 0 and r.metrics.get("primary_metric") == 42.0,
              str(r.metrics))
    # ── Summary ────────────────────────────────────────────────
    print("\n" + "=" * 60)
    passed = sum(1 for _, ok, _ in results if ok)
    failed = sum(1 for _, ok, _ in results if not ok)
    print(f"Results: {passed} passed, {failed} failed")
    if failed:
        print("\nFailed tests:")
        for name, ok, detail in results:
            if not ok:
                print(f" - {name}: {detail}")
        sys.exit(1)
    else:
        print("All tests passed!")


if __name__ == "__main__":
    main()
================================================
FILE: tests/e2e_real_llm.py
================================================
#!/usr/bin/env python3
"""Real E2E test: run all 22 stages with actual LLM API calls.
Usage:
.venv_arc/bin/python3 tests/e2e_real_llm.py
"""
from __future__ import annotations
import json
import sys
import time
from pathlib import Path
import yaml
# Ensure project root is on path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from researchclaw.config import RCConfig
from researchclaw.adapters import AdapterBundle
from researchclaw.llm.client import LLMClient
from researchclaw.pipeline.stages import Stage, STAGE_SEQUENCE
from researchclaw.pipeline.executor import execute_stage, StageResult
from researchclaw.pipeline.runner import execute_pipeline
def main() -> None:
    """Run all 22 pipeline stages against the real LLM API and report results.

    Exits 0 only when every stage completes; otherwise exits 1.
    """
    # --- Load config ---
    config_path = Path("config.arc.yaml")
    if not config_path.exists():
        print("ERROR: config.arc.yaml not found")
        sys.exit(1)
    with open(config_path) as f:
        raw = yaml.safe_load(f)
    # Override for test
    raw["research"]["topic"] = (
        "Efficient Attention Mechanisms for Long-Context Language Models"
    )
    raw["experiment"]["mode"] = "sandbox"
    raw["experiment"]["time_budget_sec"] = 60
    raw["experiment"]["max_iterations"] = 3
    config = RCConfig.from_dict(raw, check_paths=False)
    adapters = AdapterBundle()
    # --- Create run directory ---
    run_dir = Path("artifacts/e2e-real-llm-run")
    run_dir.mkdir(parents=True, exist_ok=True)
    run_id = f"e2e-real-{int(time.time())}"
    print(f"=" * 70)
    print(f"ResearchClaw E2E Test — Real LLM API")
    print(f"Topic: {config.research.topic}")
    print(f"Run ID: {run_id}")
    print(f"Output: {run_dir}")
    print(f"=" * 70)
    # --- Run full pipeline ---
    start = time.time()
    results = execute_pipeline(
        run_dir=run_dir,
        run_id=run_id,
        config=config,
        adapters=adapters,
        auto_approve_gates=True,  # Auto-approve all gates for E2E test
        kb_root=run_dir / "kb",
    )
    total_time = time.time() - start
    # --- Report ---
    print(f"\n{'=' * 70}")
    print(f"RESULTS: {len(results)}/22 stages executed in {total_time:.1f}s")
    print(f"{'=' * 70}")
    passed = 0
    failed = 0
    for r in results:
        status_icon = "✅" if r.status.value == "done" else "❌"
        print(
            f" {status_icon} Stage {int(r.stage):02d} {r.stage.name}: {r.status.value} | artifacts: {r.artifacts}"
        )
        if r.status.value == "done":
            passed += 1
        else:
            failed += 1
    print(f"\n{'=' * 70}")
    print(f"SUMMARY: {passed} passed, {failed} failed, {total_time:.1f}s total")
    print(f"{'=' * 70}")
    # --- Validate key artifacts ---
    checks = [
        ("Stage 1 goal.md", "stage-01/goal.md"),
        ("Stage 10 experiment.py", "stage-10/experiment.py"),
        ("Stage 12 runs/", "stage-12/runs"),
        ("Stage 14 experiment_summary.json", "stage-14/experiment_summary.json"),
        ("Stage 17 paper_draft.md", "stage-17/paper_draft.md"),
        ("Stage 22 export files", "stage-22"),
    ]
    print("\nArtifact Checks:")
    for label, path in checks:
        full = run_dir / path
        exists = full.exists()
        if full.is_file():
            size = full.stat().st_size
            print(f" {'✅' if exists else '❌'} {label}: {size} bytes")
        elif full.is_dir():
            count = len(list(full.iterdir())) if exists else 0
            print(f" {'✅' if exists else '❌'} {label}: {count} items")
        else:
            print(f" {'❌'} {label}: NOT FOUND")
    # --- Check experiment_summary.json has real data ---
    summary_path = run_dir / "stage-14" / "experiment_summary.json"
    if summary_path.exists():
        summary = json.loads(summary_path.read_text())
        has_metrics = bool(summary.get("metrics_summary"))
        print(
            f"\n 📊 Experiment summary has real metrics: {'YES' if has_metrics else 'NO'}"
        )
        if has_metrics:
            for k, v in summary["metrics_summary"].items():
                print(f" - {k}: {v}")
    # --- Check paper draft has real data (not placeholder) ---
    draft_path = run_dir / "stage-17" / "paper_draft.md"
    if draft_path.exists():
        draft = draft_path.read_text()
        has_placeholder = "no quantitative results yet" in draft.lower()
        has_template = draft.count("Template") > 3
        print(
            f" 📝 Paper draft: {len(draft)} chars, placeholder={has_placeholder}, template={has_template}"
        )
    # --- Check validation report ---
    val_report = run_dir / "stage-10" / "validation_report.md"
    if val_report.exists():
        print(f" 🔍 Code validation report: {val_report.stat().st_size} bytes")
        print(f" {val_report.read_text()[:200]}")
    # Final verdict: success only when every one of the 22 stages completed.
    if passed == 22 and failed == 0:
        print(f"\n🎉 ALL 22 STAGES PASSED!")
        sys.exit(0)
    else:
        print(f"\n⚠️ {failed} stages did not pass.")
        sys.exit(1)


if __name__ == "__main__":
    main()
================================================
FILE: tests/test_anthropic.py
================================================
"""测试 Anthropic Messages 兼容 API 是否可用。"""
from __future__ import annotations
import os
from typing import Any
import httpx
import pytest
pytestmark = pytest.mark.skipif(
"ANTHROPIC_API_KEY" not in os.environ,
reason="ANTHROPIC_API_KEY not set",
)
BASE_URL = os.environ.get("ANTHROPIC_BASE_URL", "https://api.anthropic.com")
API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
MODEL = os.environ.get("ANTHROPIC_MODEL", "claude-haiku-4-5-20251001")
def _create_message() -> dict[str, Any]:
    """POST a minimal message to the Messages endpoint and return the JSON body.

    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    endpoint = f"{BASE_URL.rstrip('/')}/v1/messages"
    request_headers = {
        "content-type": "application/json",
        "anthropic-version": "2023-06-01",
        "x-api-key": API_KEY,
    }
    body = {
        "model": MODEL,
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    }
    with httpx.Client(timeout=30.0) as http:
        resp = http.post(endpoint, headers=request_headers, json=body)
        resp.raise_for_status()
        return resp.json()
def test_anthropic_api() -> None:
    """Smoke test: send one message and sanity-check the response envelope."""
    message = _create_message()
    usage = message.get("usage", {})
    content = message.get("content", [])
    # Collect only the text blocks from the content array (other types ignored).
    text_blocks = [block.get("text", "") for block in content if block.get("type") == "text"]
    print(f"Status: stop_reason={message.get('stop_reason')}")
    print(f"Model: {message.get('model')}")
    print(f"Usage: input={usage.get('input_tokens')}, output={usage.get('output_tokens')}")
    print(f"Response: {' '.join(text_blocks)}")
    assert message.get("type") == "message"
    assert len(content) > 0
    print("\n✅ API 可用!")


if __name__ == "__main__":
    test_anthropic_api()
================================================
FILE: tests/test_assessor.py
================================================
"""Tests for researchclaw.assessor — Paper Quality Assessor (Agent D3).
20+ tests covering rubrics, scorer, venue_recommender, and comparator.
"""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock
import pytest
from researchclaw.assessor.rubrics import RUBRICS, Rubric
from researchclaw.assessor.scorer import PaperScorer
from researchclaw.assessor.venue_recommender import VenueRecommender
from researchclaw.assessor.comparator import HistoryComparator
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _sample_paper() -> str:
return (
"# Novel Graph Attention Networks\n\n"
"## Abstract\nWe propose a new method for graph-based learning.\n\n"
"## Experiments\nWe compare against baseline on CIFAR-10.\n"
"Results are shown in table 1 and figure 2.\n"
"Our method achieves 95.2% accuracy.\n"
) * 5 # ~500 words
def _sample_scores(overall: float = 7.5) -> dict[str, Any]:
return {
"scores": {
"novelty": 7.0,
"rigor": 8.0,
"clarity": 7.0,
"impact": 7.5,
"experiments": 8.0,
},
"overall": overall,
}
class MockLLM:
    """Minimal mock LLM client that returns a canned response."""

    def __init__(self, response: str = "SCORE: 7\nREASON: Solid contribution"):
        # Canned text handed back by every chat_async call.
        self.response = response

    async def chat_async(self, prompt: str) -> str:
        """Ignore the prompt and return the configured canned response."""
        return self.response
class FailingLLM:
    """Mock LLM whose chat call always raises, to exercise fallback paths."""

    async def chat_async(self, prompt: str) -> str:
        # Simulate a provider outage regardless of the prompt.
        raise RuntimeError("API error")
# ===================================================================
# Rubric tests
# ===================================================================
class TestRubrics:
    """Unit tests for the RUBRICS registry and the Rubric dataclass."""

    def test_all_five_dimensions_present(self):
        # The assessor scores exactly these five dimensions.
        assert set(RUBRICS.keys()) == {
            "novelty", "rigor", "clarity", "impact", "experiments"
        }

    def test_rubric_is_frozen(self):
        # Attribute assignment must fail — Rubric is expected to be immutable.
        r = RUBRICS["novelty"]
        with pytest.raises(AttributeError):
            r.name = "changed"  # type: ignore[misc]

    def test_rubric_has_criteria_and_scale(self):
        # Every registered rubric must carry non-empty criteria and scale text.
        for dim, rubric in RUBRICS.items():
            assert rubric.criteria, f"{dim} missing criteria"
            assert rubric.scale, f"{dim} missing scale"

    def test_default_weight(self):
        r = Rubric(name="test", criteria="test criteria", scale="1-10")
        assert r.weight == 1.0
# ===================================================================
# PaperScorer tests
# ===================================================================
class TestPaperScorer:
def test_score_without_llm(self):
scorer = PaperScorer()
result = asyncio.run(scorer.score(_sample_paper()))
assert "overall" in result
assert "scores" in result
assert isinstance(result["overall"], float)
assert len(result["dimensions_evaluated"]) == 5
def test_score_with_mock_llm(self):
llm = MockLLM("SCORE: 8\nREASON: Excellent work")
scorer = PaperScorer(llm_client=llm)
result = asyncio.run(scorer.score(_sample_paper()))
assert result["overall"] == 8.0
for dim in result["scores"]:
assert result["scores"][dim] == 8.0
def test_score_with_failing_llm_falls_back(self):
scorer = PaperScorer(llm_client=FailingLLM())
result = asyncio.run(scorer.score(_sample_paper()))
# Should still return valid scores via heuristic
assert "overall" in result
assert result["overall"] > 0
def test_score_subset_dimensions(self):
scorer = PaperScorer(dimensions=("novelty", "clarity"))
result = asyncio.run(scorer.score(_sample_paper()))
assert len(result["dimensions_evaluated"]) == 2
def test_parse_score_valid(self):
score, reason = PaperScorer._parse_score_response(
"SCORE: 9\nREASON: Breakthrough paper", "novelty"
)
assert score == 9.0
assert reason == "Breakthrough paper"
def test_parse_score_clamped(self):
score, _ = PaperScorer._parse_score_response("SCORE: 15", "test")
assert score == 10.0
score, _ = PaperScorer._parse_score_response("SCORE: 0", "test")
assert score == 1.0
def test_parse_score_missing(self):
score, reason = PaperScorer._parse_score_response("No format here", "test")
assert score == 5.0 # default
assert reason == "No detail provided"
def test_heuristic_clarity_long_paper(self):
long_paper = "word " * 4000
score, detail = PaperScorer._heuristic_score(long_paper, RUBRICS["clarity"])
assert score == 6.0
assert "4000" in detail
def test_heuristic_clarity_short_paper(self):
short_paper = "word " * 500
score, _ = PaperScorer._heuristic_score(short_paper, RUBRICS["clarity"])
assert score == 3.0
def test_heuristic_experiments_with_table_and_figure(self):
paper = "Results in table 1 and figure 3 show improvements."
score, _ = PaperScorer._heuristic_score(paper, RUBRICS["experiments"])
assert score == 7.0 # 4.0 + 1.5 + 1.5
def test_heuristic_experiments_no_evidence(self):
    """Without tables or figures only the base experiments score remains."""
    text = "We discuss theoretical implications."
    score, _ = PaperScorer._heuristic_score(text, RUBRICS["experiments"])
    assert score == 4.0
def test_heuristic_default_dimension(self):
    """Dimensions without a dedicated heuristic get the neutral default."""
    score, reason = PaperScorer._heuristic_score("Some paper content", RUBRICS["novelty"])
    assert score == 5.0
    assert "default" in reason.lower()
# ===================================================================
# VenueRecommender tests
# ===================================================================
class TestVenueRecommender:
    """Tests for venue recommendation thresholds, suggestions, formatting."""

    def test_recommend_high_score(self):
        """A 9.0 overall should unlock at least one tier-1 venue."""
        recs = VenueRecommender().recommend(_sample_scores(overall=9.0))
        assert any(r["tier"] == "tier_1" for r in recs)

    def test_recommend_low_score(self):
        """A 2.0 overall is below every venue's threshold."""
        recs = VenueRecommender().recommend(_sample_scores(overall=2.0))
        assert len(recs) == 0

    def test_recommend_medium_score_no_tier1(self):
        """A middling 5.0 overall must not reach any tier-1 venue."""
        recs = VenueRecommender().recommend(_sample_scores(overall=5.0))
        assert not any(r["tier"] == "tier_1" for r in recs)

    def test_recommend_filter_by_domain(self):
        """Domain filtering keeps only venues matching the requested domain."""
        recs = VenueRecommender().recommend(_sample_scores(overall=9.0), domains=["cv"])
        for rec in recs:
            venue_domains = rec["venue_domains"]
            assert "cv" in venue_domains or "deep-learning" in venue_domains

    def test_get_suggestion_weak_dimension(self):
        """A clearly weak dimension should be called out for strengthening."""
        payload = {"scores": {"novelty": 3, "clarity": 8}, "overall": 5.5}
        tip = VenueRecommender._get_suggestion("ICML", payload)
        assert "novelty" in tip.lower()
        assert "Strengthen" in tip

    def test_get_suggestion_moderate(self):
        """A borderline paper receives 'improving' advice."""
        payload = {"scores": {"novelty": 6, "clarity": 8}, "overall": 7.0}
        tip = VenueRecommender._get_suggestion("ICML", payload)
        assert "improving" in tip.lower()

    def test_get_suggestion_strong(self):
        """A clearly strong paper receives an encouraging suggestion."""
        payload = {"scores": {"novelty": 8, "clarity": 9}, "overall": 8.5}
        tip = VenueRecommender._get_suggestion("ICML", payload)
        assert "strong" in tip.lower()

    def test_get_suggestion_no_scores(self):
        """Without per-dimension scores the fallback asks for evaluation."""
        tip = VenueRecommender._get_suggestion("ICML", {"overall": 5.0})
        assert "Evaluate" in tip

    def test_format_recommendations_empty(self):
        """Formatting an empty list yields the 'no venues' message."""
        assert "No suitable venues" in VenueRecommender().format_recommendations([])

    def test_format_recommendations_with_data(self):
        """Formatting real recommendations includes the section header."""
        recommender = VenueRecommender()
        recs = recommender.recommend(_sample_scores(overall=9.0))
        assert "Venue Recommendations" in recommender.format_recommendations(recs)
# ===================================================================
# HistoryComparator tests
# ===================================================================
class TestHistoryComparator:
    """Tests for run-history recording, persistence, and comparison.

    Every comparator is pointed at a per-test ``tmp_path`` so that tests
    neither read nor write the user's real history directory. Two tests
    previously used ``HistoryComparator()`` with its default directory,
    which could pick up real persisted runs (flaky) and leave state behind.
    """

    def test_record_and_get_history(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(7.5))
        history = comp.get_history()
        assert len(history) == 1
        assert history[0]["run_id"] == "run-1"

    def test_record_persists_to_disk(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(7.5))
        # Reload from disk: a fresh instance must see the record.
        comp2 = HistoryComparator(history_dir=tmp_path)
        assert len(comp2.get_history()) == 1

    def test_compare_no_history(self, tmp_path: Path):
        # BUG FIX: previously used HistoryComparator() with the default
        # history dir, so real persisted runs could flip the result away
        # from "no_history". An empty tmp_path is hermetic.
        comp = HistoryComparator(history_dir=tmp_path)
        result = comp.compare(_sample_scores(8.0))
        assert result["comparison"] == "no_history"

    def test_compare_with_previous(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(6.0))
        result = comp.compare(_sample_scores(8.0), previous_run_id="run-1")
        assert result["comparison"] == "success"
        assert result["delta"] == 2.0
        assert result["trend"] == "improved"

    def test_compare_stable_trend(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(7.5))
        result = comp.compare(_sample_scores(7.5))
        assert result["trend"] == "stable"

    def test_compare_declined_trend(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(9.0))
        result = comp.compare(_sample_scores(7.0))
        assert result["trend"] == "declined"

    def test_compare_not_found(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(7.0))
        result = comp.compare(_sample_scores(8.0), previous_run_id="nonexistent")
        assert result["comparison"] == "not_found"

    def test_get_best_run(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        comp.record("run-1", "topic A", _sample_scores(6.0))
        comp.record("run-2", "topic B", _sample_scores(9.0))
        comp.record("run-3", "topic C", _sample_scores(7.5))
        best = comp.get_best_run()
        assert best is not None
        assert best["run_id"] == "run-2"

    def test_get_best_run_empty(self, tmp_path: Path):
        # BUG FIX: an isolated empty dir guarantees there is truly no run,
        # instead of relying on the default history dir being empty.
        assert HistoryComparator(history_dir=tmp_path).get_best_run() is None

    def test_dimension_deltas(self, tmp_path: Path):
        comp = HistoryComparator(history_dir=tmp_path)
        scores_old = {
            "scores": {"novelty": 5.0, "clarity": 6.0},
            "overall": 5.5,
        }
        scores_new = {
            "scores": {"novelty": 7.0, "clarity": 8.0},
            "overall": 7.5,
        }
        comp.record("run-1", "topic A", scores_old)
        result = comp.compare(scores_new, previous_run_id="run-1")
        assert result["dimension_deltas"]["novelty"] == 2.0
        assert result["dimension_deltas"]["clarity"] == 2.0
================================================
FILE: tests/test_benchmark_agent.py
================================================
"""Tests for the BenchmarkAgent multi-agent system."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import pytest
import yaml
# ---------------------------------------------------------------------------
# Fake LLM client (same pattern as test_code_agent.py)
# ---------------------------------------------------------------------------
@dataclass
class FakeLLMResponse:
content: str = ""
model: str = "fake"
prompt_tokens: int = 10
completion_tokens: int = 20
total_tokens: int = 30
finish_reason: str = "stop"
truncated: bool = False
raw: dict = field(default_factory=dict)
class FakeLLM:
"""Fake LLM that returns preconfigured responses."""
def __init__(self, responses: list[str] | None = None) -> None:
self._responses = list(responses or [])
self._idx = 0
self.calls: list[dict[str, Any]] = []
def chat(self, messages, **kwargs) -> FakeLLMResponse:
self.calls.append({"messages": messages, **kwargs})
if self._idx < len(self._responses):
content = self._responses[self._idx]
self._idx += 1
else:
content = '{"benchmarks": [], "baselines": []}'
return FakeLLMResponse(content=content)
# ---------------------------------------------------------------------------
# Knowledge base tests
# ---------------------------------------------------------------------------
class TestBenchmarkKnowledge:
    """Sanity checks for the bundled benchmark_knowledge.yaml file."""

    @staticmethod
    def _load() -> dict:
        """Parse the knowledge YAML shipped next to the surveyor module."""
        from researchclaw.agents.benchmark_agent.surveyor import _KNOWLEDGE_PATH
        return yaml.safe_load(_KNOWLEDGE_PATH.read_text(encoding="utf-8"))

    def test_knowledge_file_exists(self) -> None:
        from researchclaw.agents.benchmark_agent.surveyor import _KNOWLEDGE_PATH
        assert _KNOWLEDGE_PATH.exists(), f"Knowledge file missing: {_KNOWLEDGE_PATH}"

    def test_knowledge_loads(self) -> None:
        data = self._load()
        assert isinstance(data, dict)
        assert "domains" in data

    def test_knowledge_has_domains(self) -> None:
        domains = self._load()["domains"]
        assert len(domains) >= 10, f"Expected 10+ domains, got {len(domains)}"

    def test_each_domain_has_benchmarks_and_baselines(self) -> None:
        for did, info in self._load()["domains"].items():
            assert "keywords" in info, f"Domain {did} missing keywords"
            assert "standard_benchmarks" in info, f"Domain {did} missing benchmarks"
            assert "common_baselines" in info, f"Domain {did} missing baselines"
            assert len(info["standard_benchmarks"]) > 0, f"Domain {did} has 0 benchmarks"
            assert len(info["common_baselines"]) > 0, f"Domain {did} has 0 baselines"

    def test_benchmark_entries_have_required_fields(self) -> None:
        for did, info in self._load()["domains"].items():
            for b in info["standard_benchmarks"]:
                assert "name" in b, f"Benchmark in {did} missing name"
                assert "tier" in b, f"Benchmark {b.get('name')} in {did} missing tier"
                assert b["tier"] in (1, 2, 3), f"Invalid tier for {b.get('name')}"

    def test_baseline_entries_have_required_fields(self) -> None:
        for did, info in self._load()["domains"].items():
            for bl in info["common_baselines"]:
                assert "name" in bl, f"Baseline in {did} missing name"
                assert "source" in bl, f"Baseline {bl.get('name')} in {did} missing source"
                assert "paper" in bl, f"Baseline {bl.get('name')} in {did} missing paper"
# ---------------------------------------------------------------------------
# Surveyor tests
# ---------------------------------------------------------------------------
class TestSurveyor:
    """Tests for SurveyorAgent domain matching and local knowledge search."""

    @staticmethod
    def _agent(llm=None):
        """Build a SurveyorAgent with HF search disabled (hermetic tests)."""
        from researchclaw.agents.benchmark_agent.surveyor import SurveyorAgent
        return SurveyorAgent(llm if llm is not None else FakeLLM(),
                             enable_hf_search=False)

    def test_domain_matching_image_classification(self) -> None:
        matched = self._agent()._match_domains(
            "Image Classification with Contrastive Learning"
        )
        assert "image_classification" in matched

    def test_domain_matching_rl(self) -> None:
        matched = self._agent()._match_domains(
            "Reinforcement Learning for Continuous Control"
        )
        assert "reinforcement_learning" in matched

    def test_domain_matching_knowledge_distillation(self) -> None:
        matched = self._agent()._match_domains(
            "Knowledge Distillation with Feature Alignment"
        )
        assert "knowledge_distillation" in matched

    def test_domain_matching_multiple(self) -> None:
        # A topic spanning two areas should match at least two domains.
        matched = self._agent()._match_domains(
            "Self-Supervised Contrastive Learning for Image Classification"
        )
        assert len(matched) >= 2

    def test_local_candidates_returns_benchmarks(self) -> None:
        candidates = self._agent()._get_local_candidates(["image_classification"])
        assert len(candidates["benchmarks"]) > 0
        assert len(candidates["baselines"]) > 0

    def test_execute_returns_benchmarks(self) -> None:
        outcome = self._agent().execute({
            "topic": "Image Classification with Data Augmentation",
            "hypothesis": "Novel augmentation improves accuracy",
        })
        assert outcome.success
        assert len(outcome.data["benchmarks"]) > 0

    def test_execute_with_unknown_topic_uses_llm_fallback(self) -> None:
        # A topic with no knowledge-base match forces the LLM fallback path.
        scripted = FakeLLM([json.dumps({
            "benchmarks": [{"name": "CustomDS", "tier": 2}],
            "baselines": [{"name": "CustomBL", "source": "custom", "paper": "X"}],
            "rationale": "test",
        })])
        outcome = self._agent(scripted).execute({
            "topic": "Completely Novel Alien Technology Classification",
            "hypothesis": "",
        })
        assert outcome.success
        assert outcome.data["llm_fallback_used"]

    def test_extract_search_keywords(self) -> None:
        from researchclaw.agents.benchmark_agent.surveyor import SurveyorAgent
        keywords = SurveyorAgent._extract_search_keywords(
            "Novel Approach for Image Classification using Contrastive Learning"
        )
        assert len(keywords) >= 1
        # Stop-words like "novel"/"using" must not survive extraction.
        for kw in keywords:
            assert "novel" not in kw.lower()
            assert "using" not in kw.lower()

    def test_execute_empty_topic_fails(self) -> None:
        outcome = self._agent().execute({"topic": ""})
        assert not outcome.success
# ---------------------------------------------------------------------------
# Selector tests
# ---------------------------------------------------------------------------
class TestSelector:
    """Tests for SelectorAgent filtering and ranking logic."""

    @pytest.fixture()
    def benchmarks(self) -> list[dict]:
        """Candidate benchmarks spanning tiers 1-3 and both origins."""
        return [
            {"name": "CIFAR-10", "tier": 1, "size_mb": 170, "origin": "knowledge_base",
             "metrics": ["accuracy"]},
            {"name": "CIFAR-100", "tier": 1, "size_mb": 170, "origin": "knowledge_base",
             "metrics": ["accuracy"]},
            {"name": "Tiny-ImageNet", "tier": 2, "size_mb": 237, "origin": "knowledge_base",
             "metrics": ["top1_accuracy"]},
            {"name": "ImageNet-1K", "tier": 3, "size_mb": 168000, "origin": "knowledge_base",
             "metrics": ["top1_accuracy"]},
            {"name": "hf/custom-ds", "tier": 2, "size_mb": 500, "origin": "huggingface_hub",
             "downloads": 1000},
        ]

    @pytest.fixture()
    def baselines(self) -> list[dict]:
        """Two candidate baselines, one with an extra pip dependency."""
        return [
            {"name": "ResNet-18", "origin": "knowledge_base", "pip": [],
             "paper": "He et al."},
            {"name": "ViT-B/16", "origin": "knowledge_base", "pip": ["timm"],
             "paper": "Dosovitskiy et al."},
        ]

    def test_filter_excludes_tier3(self, benchmarks: list[dict]) -> None:
        from researchclaw.agents.benchmark_agent.selector import SelectorAgent
        kept = SelectorAgent(FakeLLM(), tier_limit=2)._filter_benchmarks(benchmarks)
        kept_names = {b["name"] for b in kept}
        assert "ImageNet-1K" not in kept_names
        assert "CIFAR-10" in kept_names

    def test_filter_network_none_only_tier1(self, benchmarks: list[dict]) -> None:
        from researchclaw.agents.benchmark_agent.selector import SelectorAgent
        kept = SelectorAgent(FakeLLM(), network_policy="none")._filter_benchmarks(benchmarks)
        # With networking disabled, only tier-1 benchmarks may survive.
        assert all(b["tier"] == 1 for b in kept)

    def test_ranking_prefers_tier1(self, benchmarks: list[dict]) -> None:
        from researchclaw.agents.benchmark_agent.selector import SelectorAgent
        agent = SelectorAgent(FakeLLM())
        ranked = agent._rank_benchmarks(agent._filter_benchmarks(benchmarks))
        assert ranked[0]["tier"] == 1  # best-ranked entry must be tier 1

    def test_ranking_prefers_knowledge_base(self, benchmarks: list[dict]) -> None:
        from researchclaw.agents.benchmark_agent.selector import SelectorAgent
        agent = SelectorAgent(FakeLLM())
        ranked = agent._rank_benchmarks(agent._filter_benchmarks(benchmarks))
        # Knowledge-base entries should precede HF entries of the same tier.
        kb_positions = [i for i, b in enumerate(ranked) if b["origin"] == "knowledge_base"]
        hf_positions = [i for i, b in enumerate(ranked) if b["origin"] == "huggingface_hub"]
        if kb_positions and hf_positions:
            assert min(kb_positions) < min(hf_positions)

    def test_execute_selects_minimum(self, benchmarks: list[dict],
                                     baselines: list[dict]) -> None:
        from researchclaw.agents.benchmark_agent.selector import SelectorAgent
        scripted = FakeLLM([json.dumps({
            "primary_benchmark": "CIFAR-10",
            "secondary_benchmarks": ["CIFAR-100"],
            "selected_baselines": ["ResNet-18", "ViT-B/16"],
            "rationale": "Standard benchmarks",
            "experiment_notes": "",
        })])
        outcome = SelectorAgent(scripted, min_benchmarks=1, min_baselines=2).execute({
            "topic": "Image Classification",
            "survey": {"benchmarks": benchmarks, "baselines": baselines},
        })
        assert outcome.success
        assert len(outcome.data["selected_benchmarks"]) >= 1
        assert len(outcome.data["selected_baselines"]) >= 2
# ---------------------------------------------------------------------------
# Acquirer tests
# ---------------------------------------------------------------------------
class TestAcquirer:
    """Tests for AcquirerAgent code generation."""

    @staticmethod
    def _agent(llm=None):
        """Build an AcquirerAgent, defaulting to an unscripted FakeLLM."""
        from researchclaw.agents.benchmark_agent.acquirer import AcquirerAgent
        return AcquirerAgent(llm if llm is not None else FakeLLM())

    def test_generate_setup_script_tier1_only(self) -> None:
        script = self._agent()._generate_setup_script(
            [{"name": "CIFAR-10", "tier": 1, "api": "torchvision..."}], []
        )
        # Tier 1 datasets don't need setup scripts
        assert script == ""

    def test_generate_setup_script_tier2(self) -> None:
        script = self._agent()._generate_setup_script(
            [{"name": "IMDB", "tier": 2,
              "api": "datasets.load_dataset('imdb', cache_dir='/workspace/data/hf')"}],
            [],
        )
        assert "download_datasets" in script
        assert "load_dataset" in script

    def test_generate_requirements_filters_builtin(self) -> None:
        reqs = self._agent()._generate_requirements(["torch", "numpy", "xgboost", "timm"])
        # torch/numpy/timm are treated as already available; xgboost is not.
        for preinstalled in ("torch", "numpy", "timm"):
            assert preinstalled not in reqs
        assert "xgboost" in reqs

    def test_strip_fences(self) -> None:
        from researchclaw.agents.benchmark_agent.acquirer import AcquirerAgent
        fenced = "```python\nimport torch\n```"
        assert AcquirerAgent._strip_fences(fenced) == "import torch"

    def test_execute_generates_code(self) -> None:
        # Two scripted replies: data-loader code, then baseline code.
        scripted = FakeLLM([
            "import torchvision\ndef get_datasets(): pass",
            "import torch.nn as nn\ndef get_baselines(): pass",
        ])
        outcome = self._agent(scripted).execute({
            "topic": "Image Classification",
            "selection": {
                "selected_benchmarks": [
                    {"name": "CIFAR-10", "tier": 1, "role": "primary",
                     "api": "torchvision.datasets.CIFAR10(...)"},
                ],
                "selected_baselines": [
                    {"name": "ResNet-18", "source": "torchvision.models.resnet18()",
                     "paper": "He et al.", "pip": []},
                ],
                "required_pip": [],
            },
        })
        assert outcome.success
        assert outcome.data["data_loader_code"]
# ---------------------------------------------------------------------------
# Validator tests
# ---------------------------------------------------------------------------
class TestValidator:
    """Tests for ValidatorAgent code validation."""

    @staticmethod
    def _agent(llm=None):
        """Build a ValidatorAgent, defaulting to an unscripted FakeLLM."""
        from researchclaw.agents.benchmark_agent.validator import ValidatorAgent
        return ValidatorAgent(llm if llm is not None else FakeLLM())

    def test_syntax_check_valid(self) -> None:
        assert self._agent()._check_syntax("import torch\nx = 1 + 2", "test") == []

    def test_syntax_check_invalid(self) -> None:
        problems = self._agent()._check_syntax("def foo(\n x = ", "test")
        assert len(problems) > 0
        assert "SyntaxError" in problems[0]

    def test_import_check_builtin_ok(self) -> None:
        # Well-known packages should not be flagged.
        assert self._agent()._check_imports("import torch\nimport numpy", "test", []) == []

    def test_import_check_unknown(self) -> None:
        flagged = self._agent()._check_imports("import some_obscure_lib", "test", [])
        assert len(flagged) > 0

    def test_import_check_with_requirements(self) -> None:
        # A package listed in requirements must not be flagged.
        assert self._agent()._check_imports("import xgboost", "test", ["xgboost"]) == []

    def test_execute_passes_valid_code(self) -> None:
        scripted = FakeLLM([json.dumps({
            "passed": True,
            "issues": [],
            "suggestions": [],
            "severity": "none",
        })])
        outcome = self._agent(scripted).execute({
            "acquisition": {
                "data_loader_code": "import torch\ndef get_datasets(): pass",
                "baseline_code": "import torch.nn as nn\ndef get_baselines(): pass",
                "setup_code": "",
                "requirements": "",
                "benchmark_names": ["CIFAR-10"],
                "baseline_names": ["ResNet-18"],
            },
        })
        assert outcome.success
        assert outcome.data["passed"]

    def test_execute_fails_syntax_error(self) -> None:
        # Broken data-loader code must fail validation with explicit errors.
        outcome = self._agent().execute({
            "acquisition": {
                "data_loader_code": "def foo(\n x = ",
                "baseline_code": "",
                "setup_code": "",
                "requirements": "",
                "benchmark_names": [],
                "baseline_names": [],
            },
        })
        assert not outcome.data["passed"]
        assert len(outcome.data["errors"]) > 0
# ---------------------------------------------------------------------------
# Orchestrator tests
# ---------------------------------------------------------------------------
class TestOrchestrator:
    """Test BenchmarkOrchestrator end-to-end."""

    def test_orchestrate_produces_plan(self, tmp_path: Path) -> None:
        # Happy path: survey -> select -> acquire -> validate, fully scripted.
        # NOTE(review): the response list order must match the orchestrator's
        # internal LLM call sequence exactly — keep it in sync with the pipeline.
        from researchclaw.agents.benchmark_agent.orchestrator import (
            BenchmarkAgentConfig,
            BenchmarkOrchestrator,
        )
        responses = [
            # Selector LLM response
            json.dumps({
                "primary_benchmark": "CIFAR-10",
                "secondary_benchmarks": ["CIFAR-100"],
                "selected_baselines": ["ResNet-18", "ViT-B/16"],
                "rationale": "Standard CV benchmarks",
                "experiment_notes": "Use standard augmentation",
            }),
            # Acquirer: data_loader_code
            "import torchvision\ndef get_datasets(data_root='/workspace/data'):\n return {}",
            # Acquirer: baseline_code
            "import torch.nn as nn\ndef get_baselines(num_classes=10):\n return {}",
            # Validator: LLM review
            json.dumps({
                "passed": True,
                "issues": [],
                "suggestions": ["Add transforms"],
                "severity": "none",
            }),
        ]
        # HF search disabled so the survey stage stays offline/deterministic.
        cfg = BenchmarkAgentConfig(enable_hf_search=False)
        orchestrator = BenchmarkOrchestrator(
            FakeLLM(responses),
            config=cfg,
            stage_dir=tmp_path / "benchmark_agent",
        )
        plan = orchestrator.orchestrate({
            "topic": "Image Classification with Data Augmentation",
            "hypothesis": "Novel augmentation improves accuracy",
        })
        # The plan must carry selections, pass validation, and record metrics.
        assert len(plan.selected_benchmarks) >= 1
        assert len(plan.selected_baselines) >= 1
        assert plan.validation_passed
        assert plan.total_llm_calls > 0
        assert plan.elapsed_sec > 0

    def test_orchestrate_saves_artifacts(self, tmp_path: Path) -> None:
        # Each pipeline stage should persist its JSON artifact under stage_dir.
        from researchclaw.agents.benchmark_agent.orchestrator import (
            BenchmarkAgentConfig,
            BenchmarkOrchestrator,
        )
        responses = [
            json.dumps({
                "primary_benchmark": "CIFAR-10",
                "secondary_benchmarks": [],
                "selected_baselines": ["ResNet-18"],
                "rationale": "test",
                "experiment_notes": "",
            }),
            "def get_datasets(): pass",
            "def get_baselines(): pass",
            json.dumps({"passed": True, "issues": [], "suggestions": [], "severity": "none"}),
        ]
        stage_dir = tmp_path / "benchmark_agent"
        cfg = BenchmarkAgentConfig(enable_hf_search=False)
        orchestrator = BenchmarkOrchestrator(
            FakeLLM(responses),
            config=cfg,
            stage_dir=stage_dir,
        )
        orchestrator.orchestrate({
            "topic": "Image Classification",
            "hypothesis": "",
        })
        assert (stage_dir / "survey_results.json").exists()
        assert (stage_dir / "selection_results.json").exists()
        assert (stage_dir / "benchmark_plan.json").exists()

    def test_plan_to_prompt_block(self) -> None:
        # to_prompt_block() should surface selected names and generated code.
        from researchclaw.agents.benchmark_agent.orchestrator import BenchmarkPlan
        plan = BenchmarkPlan(
            selected_benchmarks=[
                {"name": "CIFAR-10", "role": "primary", "metrics": ["accuracy"],
                 "api": "torchvision.datasets.CIFAR10(...)"},
            ],
            selected_baselines=[
                {"name": "ResNet-18", "source": "torchvision.models.resnet18()",
                 "paper": "He et al."},
            ],
            data_loader_code="def get_datasets(): pass",
            baseline_code="def get_baselines(): pass",
        )
        block = plan.to_prompt_block()
        assert "CIFAR-10" in block
        assert "ResNet-18" in block
        assert "get_datasets" in block
        assert "get_baselines" in block

    def test_plan_to_dict_serializable(self) -> None:
        # to_dict() output must round-trip through json.dumps.
        from researchclaw.agents.benchmark_agent.orchestrator import BenchmarkPlan
        plan = BenchmarkPlan(
            selected_benchmarks=[{"name": "test"}],
            data_loader_code="code",
        )
        d = plan.to_dict()
        # Should be JSON-serializable
        json_str = json.dumps(d)
        assert "test" in json_str
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
class TestConfig:
    """Tests for the benchmark-agent section of researchclaw.config."""

    def test_default_config_has_benchmark_agent(self) -> None:
        from researchclaw.config import ExperimentConfig
        experiment = ExperimentConfig()
        assert hasattr(experiment, "benchmark_agent")
        assert experiment.benchmark_agent.enabled is True

    def test_parse_benchmark_agent_config(self) -> None:
        from researchclaw.config import _parse_benchmark_agent_config
        parsed = _parse_benchmark_agent_config({
            "enabled": False,
            "tier_limit": 1,
            "min_baselines": 3,
        })
        assert parsed.enabled is False
        assert parsed.tier_limit == 1
        assert parsed.min_baselines == 3

    def test_parse_benchmark_agent_config_empty(self) -> None:
        from researchclaw.config import _parse_benchmark_agent_config
        # An empty mapping must yield all defaults.
        defaults = _parse_benchmark_agent_config({})
        assert defaults.enabled is True
        assert defaults.tier_limit == 2
# ---------------------------------------------------------------------------
# Base agent tests
# ---------------------------------------------------------------------------
class TestBaseAgent:
    """Tests for BaseAgent._parse_json tolerant JSON extraction."""

    @staticmethod
    def _parse(text: str):
        """Shortcut for the static JSON extraction helper under test."""
        from researchclaw.agents.base import BaseAgent
        return BaseAgent._parse_json(text)

    def test_parse_json_direct(self) -> None:
        assert self._parse('{"key": "value"}') == {"key": "value"}

    def test_parse_json_fenced(self) -> None:
        # JSON inside a ```json fence with surrounding prose.
        assert self._parse('Some text\n```json\n{"key": 1}\n```\nMore text') == {"key": 1}

    def test_parse_json_embedded(self) -> None:
        # Bare JSON embedded mid-sentence.
        assert self._parse('Here is the result: {"a": 2} end') == {"a": 2}

    def test_parse_json_invalid(self) -> None:
        assert self._parse("no json here at all") is None
# ---------------------------------------------------------------------------
# Required baselines injection (Improvement E)
# ---------------------------------------------------------------------------
class TestRequiredBaselines:
    """Test that required baselines are injected from the knowledge base."""

    @staticmethod
    def _agent():
        from researchclaw.agents.benchmark_agent.selector import SelectorAgent
        return SelectorAgent(FakeLLM(), min_baselines=1)

    def test_inject_required_baselines_image_classification(self) -> None:
        chosen: list[dict[str, Any]] = [
            {"name": "EfficientNet-B0", "origin": "knowledge_base"},
        ]
        added = self._agent()._inject_required_baselines(
            "image classification on CIFAR-10",
            chosen,
            [],
        )
        # ResNet-50 and ViT-B/16 are required for image_classification.
        added_names = {b["name"] for b in added}
        assert "ResNet-50" in added_names
        assert "ViT-B/16" in added_names
        # The pre-existing baseline must not be duplicated.
        assert sum(1 for b in chosen if b["name"] == "EfficientNet-B0") == 1

    def test_inject_required_baselines_no_duplicates(self) -> None:
        chosen: list[dict[str, Any]] = [
            {"name": "ResNet-50", "origin": "knowledge_base"},
            {"name": "ViT-B/16", "origin": "llm_suggestion"},
        ]
        added = self._agent()._inject_required_baselines(
            "image classification on CIFAR-10",
            chosen,
            [],
        )
        # Both required baselines are already present -> nothing is injected.
        assert len(added) == 0
================================================
FILE: tests/test_calendar.py
================================================
"""Tests for researchclaw.calendar — Conference Deadline Calendar (Agent D4).
15+ tests covering deadlines, planner, and reminder modules.
"""
from __future__ import annotations
from datetime import date, timedelta
from pathlib import Path
import pytest
import yaml
from researchclaw.calendar.deadlines import Conference, ConferenceCalendar
from researchclaw.calendar.planner import SubmissionPlanner
from researchclaw.calendar.reminder import Reminder, ReminderCalculator
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_conference(
    name: str = "TestConf",
    full_name: str = "Test Conference",
    domains: tuple[str, ...] = ("ml",),
    tier: int = 1,
    abstract_deadline: date | None = None,
    paper_deadline: date | None = None,
    **kwargs,
) -> Conference:
    """Build a Conference with sensible test defaults.

    Extra keyword arguments (e.g. ``url``) are forwarded unchanged to the
    Conference constructor.
    """
    return Conference(
        name=name,
        full_name=full_name,
        domains=domains,
        tier=tier,
        abstract_deadline=abstract_deadline,
        paper_deadline=paper_deadline,
        **kwargs,
    )
def _future(days: int) -> date:
return date.today() + timedelta(days=days)
def _past(days: int) -> date:
return date.today() - timedelta(days=days)
# ===================================================================
# Conference dataclass tests
# ===================================================================
class TestConference:
    """Tests for the Conference dataclass and its deadline properties."""

    def test_from_dict_minimal(self):
        conf = Conference.from_dict({"name": "NeurIPS"})
        assert conf.name == "NeurIPS"
        assert conf.tier == 3  # default
        assert conf.domains == ()

    def test_from_dict_full(self):
        conf = Conference.from_dict({
            "name": "ICML",
            "full_name": "International Conference on Machine Learning",
            "domains": ["ml", "ai"],
            "tier": 1,
            "url": "https://icml.cc",
            "abstract_deadline": "2026-06-01",
            "paper_deadline": "2026-06-08",
        })
        assert conf.name == "ICML"
        assert conf.full_name == "International Conference on Machine Learning"
        assert conf.domains == ("ml", "ai")
        assert conf.tier == 1
        # ISO date strings are parsed into date objects.
        assert conf.abstract_deadline == date(2026, 6, 1)
        assert conf.paper_deadline == date(2026, 6, 8)

    def test_from_dict_date_passthrough(self):
        """date objects in YAML are already date instances."""
        conf = Conference.from_dict({
            "name": "X",
            "abstract_deadline": date(2026, 12, 1),
        })
        assert conf.abstract_deadline == date(2026, 12, 1)

    def test_next_deadline_returns_earliest_future(self):
        conf = _make_conference(
            abstract_deadline=_future(10),
            paper_deadline=_future(20),
        )
        assert conf.next_deadline == _future(10)

    def test_next_deadline_skips_past(self):
        conf = _make_conference(
            abstract_deadline=_past(5),
            paper_deadline=_future(15),
        )
        assert conf.next_deadline == _future(15)

    def test_next_deadline_none_when_all_past(self):
        conf = _make_conference(
            abstract_deadline=_past(10),
            paper_deadline=_past(5),
        )
        assert conf.next_deadline is None

    def test_days_until_deadline(self):
        assert _make_conference(paper_deadline=_future(30)).days_until_deadline == 30

    def test_days_until_deadline_none(self):
        # No deadlines configured at all.
        assert _make_conference().days_until_deadline is None
# ===================================================================
# ConferenceCalendar tests
# ===================================================================
class TestConferenceCalendar:
    """Tests for ConferenceCalendar loading, filtering, and formatting."""

    def test_load_from_yaml(self, tmp_path: Path):
        # Round-trip: dump two conferences to YAML, load them back.
        data = {
            "conferences": [
                {
                    "name": "TestConf",
                    "domains": ["ml"],
                    "tier": 1,
                    "paper_deadline": (_future(30)).isoformat(),
                },
                {
                    "name": "TestConf2",
                    "domains": ["cv"],
                    "tier": 2,
                    "paper_deadline": (_future(60)).isoformat(),
                },
            ]
        }
        yaml_path = tmp_path / "conferences.yaml"
        yaml_path.write_text(yaml.dump(data), encoding="utf-8")
        cal = ConferenceCalendar.load(yaml_path)
        assert len(cal.conferences) == 2
        assert cal.conferences[0].name == "TestConf"

    def test_load_skips_invalid_entries(self, tmp_path: Path):
        # Entries without a "name" field are silently dropped on load.
        data = {
            "conferences": [
                {"name": "Valid", "tier": 1},
                {"invalid": "no name field"},
            ]
        }
        yaml_path = tmp_path / "conf.yaml"
        yaml_path.write_text(yaml.dump(data), encoding="utf-8")
        cal = ConferenceCalendar.load(yaml_path)
        assert len(cal.conferences) == 1

    def test_get_upcoming_filters_by_days(self):
        # A 200-day-out deadline falls outside the 90-day window.
        confs = [
            _make_conference(name="Soon", paper_deadline=_future(10)),
            _make_conference(name="Far", paper_deadline=_future(200)),
        ]
        cal = ConferenceCalendar(confs)
        upcoming = cal.get_upcoming(days=90)
        assert len(upcoming) == 1
        assert upcoming[0].name == "Soon"

    def test_get_upcoming_filters_by_domain(self):
        confs = [
            _make_conference(name="ML", domains=("ml",), paper_deadline=_future(10)),
            _make_conference(name="CV", domains=("cv",), paper_deadline=_future(10)),
        ]
        cal = ConferenceCalendar(confs)
        result = cal.get_upcoming(domains=["ml"], days=90)
        assert len(result) == 1
        assert result[0].name == "ML"

    def test_get_upcoming_filters_by_tier(self):
        confs = [
            _make_conference(name="T1", tier=1, paper_deadline=_future(10)),
            _make_conference(name="T3", tier=3, paper_deadline=_future(10)),
        ]
        cal = ConferenceCalendar(confs)
        result = cal.get_upcoming(tier=1, days=90)
        assert len(result) == 1
        assert result[0].name == "T1"

    def test_get_by_name_case_insensitive(self):
        confs = [_make_conference(name="NeurIPS")]
        cal = ConferenceCalendar(confs)
        assert cal.get_by_name("neurips") is not None
        assert cal.get_by_name("NEURIPS") is not None
        assert cal.get_by_name("nonexistent") is None

    def test_get_by_domain(self):
        confs = [
            _make_conference(name="A", domains=("ml", "ai")),
            _make_conference(name="B", domains=("cv",)),
        ]
        cal = ConferenceCalendar(confs)
        assert len(cal.get_by_domain("ml")) == 1
        assert len(cal.get_by_domain("cv")) == 1
        assert len(cal.get_by_domain("nlp")) == 0

    def test_format_upcoming_no_deadlines(self):
        cal = ConferenceCalendar([])
        output = cal.format_upcoming()
        assert "No upcoming deadlines" in output

    def test_format_upcoming_with_deadlines(self):
        # The formatted listing shows name, countdown, and URL.
        confs = [_make_conference(
            name="ICML", paper_deadline=_future(15), url="https://icml.cc"
        )]
        cal = ConferenceCalendar(confs)
        output = cal.format_upcoming(days=90)
        assert "ICML" in output
        assert "15 days left" in output
        assert "https://icml.cc" in output

    def test_load_builtin(self):
        """Built-in conferences.yaml should load without error."""
        cal = ConferenceCalendar.load_builtin()
        assert isinstance(cal.conferences, list)
# ===================================================================
# SubmissionPlanner tests
# ===================================================================
class TestSubmissionPlanner:
    """Tests for SubmissionPlanner plan generation and formatting."""

    def test_plan_basic(self):
        """plan() returns venue, total days, and one milestone per stage."""
        conf = _make_conference(name="TestConf", paper_deadline=_future(100))
        cal = ConferenceCalendar([conf])
        planner = SubmissionPlanner(cal)
        plan = planner.plan("TestConf", start_date=date.today())
        assert plan["venue"] == "TestConf"
        assert plan["total_days"] == 100
        assert len(plan["milestones"]) == 8  # 8 stages in STAGE_PROPORTIONS

    def test_plan_unknown_venue(self):
        """Planning an unknown venue yields an error payload, not an exception."""
        cal = ConferenceCalendar([])
        planner = SubmissionPlanner(cal)
        result = planner.plan("NonExistent")
        assert "error" in result

    def test_plan_past_deadline(self):
        """Planning against an already-passed deadline reports 'passed'."""
        conf = _make_conference(name="Past", paper_deadline=_past(5))
        cal = ConferenceCalendar([conf])
        planner = SubmissionPlanner(cal)
        result = planner.plan("Past", start_date=date.today())
        assert "error" in result
        assert "passed" in result["error"]

    def test_format_plan(self):
        """format_plan renders a titled plan with a milestones section."""
        conf = _make_conference(name="ICML", paper_deadline=_future(60))
        cal = ConferenceCalendar([conf])
        planner = SubmissionPlanner(cal)
        output = planner.format_plan("ICML", start_date=date.today())
        assert "Submission Plan for ICML" in output
        assert "Milestones:" in output

    def test_format_plan_error(self):
        """format_plan surfaces plan errors with an 'Error:' prefix."""
        cal = ConferenceCalendar([])
        planner = SubmissionPlanner(cal)
        output = planner.format_plan("None")
        assert "Error:" in output
# ===================================================================
# ReminderCalculator tests
# ===================================================================
class TestReminderCalculator:
    """Tests for ReminderCalculator firing rules, urgency, and formatting."""

    def test_check_fires_on_matching_day(self):
        """A reminder fires when days-until exactly matches a configured day."""
        deadline = date.today() + timedelta(days=7)
        conf = _make_conference(name="Conf", paper_deadline=deadline)
        calc = ReminderCalculator(reminder_days=(7,))
        reminders = calc.check([conf])
        assert len(reminders) == 1
        assert reminders[0].days_until == 7

    def test_check_no_fire_on_non_matching_day(self):
        """No reminder fires when days-until is not in reminder_days."""
        deadline = date.today() + timedelta(days=8)
        conf = _make_conference(name="Conf", paper_deadline=deadline)
        calc = ReminderCalculator(reminder_days=(7,))
        reminders = calc.check([conf])
        assert len(reminders) == 0

    def test_check_skips_past_deadlines(self):
        """Past deadlines never generate reminders."""
        conf = _make_conference(name="Conf", paper_deadline=_past(3))
        calc = ReminderCalculator(reminder_days=(3,))
        assert len(calc.check([conf])) == 0

    def test_urgency_critical(self):
        """1-3 days out classifies as 'critical'."""
        assert ReminderCalculator._classify_urgency(1) == "critical"
        assert ReminderCalculator._classify_urgency(3) == "critical"

    def test_urgency_warning(self):
        """7-14 days out classifies as 'warning'."""
        assert ReminderCalculator._classify_urgency(7) == "warning"
        assert ReminderCalculator._classify_urgency(14) == "warning"

    def test_urgency_info(self):
        """30 days out classifies as 'info'."""
        assert ReminderCalculator._classify_urgency(30) == "info"

    def test_get_active_reminders(self):
        """Only conferences within the largest reminder window are active."""
        confs = [
            _make_conference(name="Soon", paper_deadline=_future(5)),
            _make_conference(name="Far", paper_deadline=_future(100)),
        ]
        calc = ReminderCalculator(reminder_days=(30, 14, 7, 3, 1))
        active = calc.get_active_reminders(confs)
        assert len(active) == 1
        assert active[0].conference_name == "Soon"

    def test_format_reminders_empty(self):
        """Formatting an empty list produces a 'No upcoming' message."""
        calc = ReminderCalculator()
        assert "No upcoming" in calc.format_reminders([])

    def test_format_reminders_with_data(self):
        """Critical reminders are rendered with the name and a '!!!' marker."""
        r = Reminder(
            conference_name="ICML",
            deadline_type="paper",
            deadline_date=_future(3),
            days_until=3,
            urgency="critical",
        )
        calc = ReminderCalculator()
        output = calc.format_reminders([r])
        assert "ICML" in output
        assert "!!!" in output

    def test_reminder_frozen(self):
        """Reminder is a frozen dataclass: attribute writes raise."""
        r = Reminder("X", "paper", date.today(), 5, "info")
        with pytest.raises(AttributeError):
            r.days_until = 10  # type: ignore[misc]
================================================
FILE: tests/test_cli.py
================================================
"""Tests for CLI setup helpers."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from researchclaw import cli
def test_install_opencode_uses_which_resolved_npm_path():
    """_install_opencode must pass the shutil.which()-resolved npm path as argv[0].

    A Windows-style 'npm.cmd' path is used to verify the resolved path is
    forwarded verbatim to subprocess.run rather than the bare command name.
    """
    mock_result = MagicMock()
    mock_result.returncode = 0
    with patch(
        "researchclaw.cli.shutil.which",
        return_value=r"C:\Program Files\nodejs\npm.cmd",
    ), patch("researchclaw.cli.subprocess.run", return_value=mock_result) as run_mock:
        assert cli._install_opencode() is True
        run_mock.assert_called_once()
        # argv[0] of the subprocess call must be the resolved absolute path.
        assert run_mock.call_args.args[0][0] == r"C:\Program Files\nodejs\npm.cmd"
def test_install_opencode_returns_false_when_npm_missing():
    """When npm cannot be resolved on PATH, installation reports failure."""
    with patch("researchclaw.cli.shutil.which", return_value=None):
        assert cli._install_opencode() is False
def test_is_opencode_installed_uses_which_resolved_path():
    """_is_opencode_installed must invoke the which()-resolved opencode path."""
    mock_result = MagicMock()
    mock_result.returncode = 0
    with patch(
        "researchclaw.cli.shutil.which",
        return_value=r"C:\Users\tester\AppData\Roaming\npm\opencode.cmd",
    ), patch("researchclaw.cli.subprocess.run", return_value=mock_result) as run_mock:
        assert cli._is_opencode_installed() is True
        run_mock.assert_called_once()
        # argv[0] must be the resolved .cmd path, not the bare command name.
        assert run_mock.call_args.args[0][0].endswith("opencode.cmd")
================================================
FILE: tests/test_code_agent.py
================================================
"""Tests for the advanced multi-phase code generation agent (F-02)."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import pytest
from researchclaw.llm.client import LLMResponse
from researchclaw.pipeline.code_agent import (
CodeAgent,
CodeAgentConfig,
CodeAgentResult,
SolutionNode,
_SimpleResult,
)
from researchclaw.prompts import PromptManager
# ---------------------------------------------------------------------------
# Test fixtures
# ---------------------------------------------------------------------------
class FakeLLM:
    """Stub LLM client that replays a scripted sequence of responses.

    Every chat() invocation is recorded in ``calls`` so tests can inspect
    the prompts that were sent.  Once the script is exhausted, the final
    scripted response is repeated; with no script at all, a canned code
    block is returned.
    """

    # Canned reply used when no responses were scripted.
    _DEFAULT_REPLY = '```filename:main.py\nprint("hello")\n```'

    def __init__(self, responses: list[str] | None = None):
        # One {"messages": ..., **kwargs} dict per chat() call, in order.
        self.calls: list[dict[str, Any]] = []
        self._scripted = list(responses or [])
        self._cursor = 0

    def chat(self, messages: list[dict], **kwargs: Any) -> LLMResponse:
        """Record the call and return the next scripted reply as an LLMResponse."""
        self.calls.append({"messages": messages, **kwargs})
        if self._scripted:
            # Clamp so the last scripted reply repeats once exhausted.
            text = self._scripted[min(self._cursor, len(self._scripted) - 1)]
        else:
            text = self._DEFAULT_REPLY
        self._cursor += 1
        return LLMResponse(content=text, model="fake-model")
@dataclass
class FakeSandboxResult:
    """Minimal stand-in for a sandbox execution result."""

    returncode: int = 0  # 0 = the sandboxed run succeeded
    stdout: str = "primary_metric: 0.95"  # default mimics a metric line
    stderr: str = ""
    elapsed_sec: float = 1.0
    metrics: dict[str, object] = field(default_factory=dict)
    timed_out: bool = False
class FakeSandbox:
    """Stub sandbox that replays a scripted sequence of run results.

    Each run_project() call is recorded in ``runs``; once the scripted
    results are exhausted, the last one is repeated for further calls.
    """

    def __init__(self, results: list[FakeSandboxResult] | None = None):
        # Project directories passed to run_project(), in call order.
        self.runs: list[Path] = []
        self._scripted = list(results or [FakeSandboxResult()])
        self._cursor = 0

    def run_project(
        self, project_dir: Path, *, entry_point: str = "main.py",
        timeout_sec: int = 300,
    ) -> FakeSandboxResult:
        """Record the call and return the next scripted result."""
        self.runs.append(project_dir)
        # Clamp the index so the final result repeats after the script ends.
        position = min(self._cursor, len(self._scripted) - 1)
        self._cursor += 1
        return self._scripted[position]
@pytest.fixture()
def stage_dir(tmp_path: Path) -> Path:
    """Provide a fresh per-test stage directory under pytest's tmp_path."""
    d = tmp_path / "stage-10"
    d.mkdir()
    return d
@pytest.fixture()
def pm() -> PromptManager:
    """Provide a default PromptManager instance."""
    return PromptManager()
# ---------------------------------------------------------------------------
# CodeAgentConfig tests
# ---------------------------------------------------------------------------
class TestCodeAgentConfig:
    """Tests for CodeAgentConfig defaults and overrides."""

    def test_default_values(self) -> None:
        """Defaults: agent enabled, planning on, tree search off."""
        cfg = CodeAgentConfig()
        assert cfg.enabled is True
        assert cfg.architecture_planning is True
        assert cfg.exec_fix_max_iterations == 3
        assert cfg.tree_search_enabled is False
        assert cfg.review_max_rounds == 2

    def test_custom_values(self) -> None:
        """Constructor keyword overrides are stored verbatim."""
        cfg = CodeAgentConfig(
            enabled=False,
            exec_fix_max_iterations=5,
            tree_search_enabled=True,
            tree_search_candidates=5,
        )
        assert cfg.enabled is False
        assert cfg.exec_fix_max_iterations == 5
        assert cfg.tree_search_enabled is True
        assert cfg.tree_search_candidates == 5
# ---------------------------------------------------------------------------
# Phase 1: Architecture Planning
# ---------------------------------------------------------------------------
class TestPhase1Architecture:
    """Tests for the optional architecture-planning phase of CodeAgent."""

    def test_architecture_planning_produces_spec(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """With planning enabled, the result carries the architecture spec."""
        arch_yaml = (
            "```yaml\nfiles:\n - name: main.py\n purpose: entry point\n"
            " - name: models.py\n purpose: models\n```"
        )
        code = '```filename:main.py\nprint("metric: 1.0")\n```'
        # reviewer approves immediately
        review = '{"verdict": "APPROVE", "score": 8, "critical_issues": []}'
        llm = FakeLLM(responses=[arch_yaml, code, review])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(architecture_planning=True),
            stage_dir=stage_dir,
        )
        result = agent.generate(
            topic="test topic", exp_plan="objectives: test",
            metric="accuracy", pkg_hint="numpy, torch",
        )
        assert result.architecture_spec
        assert "main.py" in result.architecture_spec
        assert result.files
        assert result.total_llm_calls >= 2  # arch + codegen + review

    def test_architecture_planning_disabled(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """With planning disabled, no arch spec and no planning prompt is sent."""
        code = '```filename:main.py\nprint("metric: 1.0")\n```'
        review = '{"verdict": "APPROVE", "score": 9, "critical_issues": []}'
        llm = FakeLLM(responses=[code, review])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(architecture_planning=False),
            stage_dir=stage_dir,
        )
        result = agent.generate(
            topic="test", exp_plan="plan", metric="m", pkg_hint="",
        )
        assert result.architecture_spec == ""
        assert result.files
        # First call should be code_generation, not the architecture planning prompt
        first_call_user = llm.calls[0]["messages"][0]["content"]
        # The architecture planning prompt has "Design the architecture" phrasing
        assert "design the architecture for an experiment" not in first_call_user.lower()
# ---------------------------------------------------------------------------
# Phase 2: Execution-in-the-Loop
# ---------------------------------------------------------------------------
class TestPhase2ExecFix:
    """Tests for the execution-in-the-loop fix phase of CodeAgent."""

    def test_exec_fix_loop_fixes_crashing_code(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """A crashing first run triggers a fix iteration that then succeeds."""
        # Initial code crashes, then fix succeeds
        initial_code = '```filename:main.py\nraise RuntimeError("bug")\n```'
        fixed_code = '```filename:main.py\nprint("metric: 1.0")\n```'
        review = '{"verdict": "APPROVE", "score": 8, "critical_issues": []}'
        llm = FakeLLM(responses=[
            initial_code,  # phase 2: initial generation (no arch)
            fixed_code,    # phase 2: exec-fix iteration
            review,        # phase 4: review
        ])
        sandbox_results = [
            FakeSandboxResult(returncode=1, stderr="RuntimeError: bug"),
            FakeSandboxResult(returncode=0, stdout="metric: 1.0"),
        ]
        fake_sandbox = FakeSandbox(results=sandbox_results)
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                exec_fix_max_iterations=3,
            ),
            stage_dir=stage_dir,
            sandbox_factory=lambda cfg, wd: fake_sandbox,
            experiment_config=None,
        )
        result = agent.generate(
            topic="test", exp_plan="plan", metric="metric", pkg_hint="",
        )
        assert result.files
        assert result.total_sandbox_runs >= 1

    def test_exec_fix_skipped_without_sandbox(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """With no sandbox_factory, the exec-fix phase does zero sandbox runs."""
        code = '```filename:main.py\nprint("m: 1")\n```'
        review = '{"verdict": "APPROVE", "score": 9, "critical_issues": []}'
        llm = FakeLLM(responses=[code, review])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(architecture_planning=False),
            stage_dir=stage_dir,
            sandbox_factory=None,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="m", pkg_hint="",
        )
        assert result.total_sandbox_runs == 0
        assert result.files

    def test_exec_fix_max_iterations_respected(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """Persistently crashing code stops after exec_fix_max_iterations runs."""
        code = '```filename:main.py\nraise RuntimeError("persistent")\n```'
        review = '{"verdict": "APPROVE", "score": 5, "critical_issues": []}'
        llm = FakeLLM(responses=[code, code, code, code, review])
        always_crash = FakeSandbox(
            results=[FakeSandboxResult(returncode=1, stderr="RuntimeError")]
        )
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                exec_fix_max_iterations=2,
            ),
            stage_dir=stage_dir,
            sandbox_factory=lambda cfg, wd: always_crash,
            experiment_config=None,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="m", pkg_hint="",
        )
        # Should have exactly 2 sandbox runs (max iterations)
        assert result.total_sandbox_runs == 2
# ---------------------------------------------------------------------------
# Phase 3: Solution Tree Search
# ---------------------------------------------------------------------------
class TestPhase3TreeSearch:
    """Tests for the solution tree-search phase of CodeAgent."""

    def test_tree_search_generates_multiple_candidates(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """Tree search explores at least tree_search_candidates nodes."""
        code_a = '```filename:main.py\nprint("metric: 0.5")\n```'
        code_b = '```filename:main.py\nprint("metric: 0.9")\n```'
        review = '{"verdict": "APPROVE", "score": 9, "critical_issues": []}'
        llm = FakeLLM(responses=[code_a, code_b, review])
        sandbox = FakeSandbox(results=[
            FakeSandboxResult(returncode=0, stdout="metric: 0.5",
                              metrics={"metric": 0.5}),
            FakeSandboxResult(returncode=0, stdout="metric: 0.9",
                              metrics={"metric": 0.9}),
        ])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                tree_search_enabled=True,
                tree_search_candidates=2,
                tree_search_max_depth=1,
            ),
            stage_dir=stage_dir,
            sandbox_factory=lambda cfg, wd: sandbox,
            experiment_config=None,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="metric", pkg_hint="",
        )
        assert result.tree_nodes_explored >= 2
        assert result.files

    def test_tree_search_fixes_crashing_candidates(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """Crashing candidates are retried with fixes up to max depth."""
        crash_code = '```filename:main.py\nraise ValueError("x")\n```'
        fixed_code = '```filename:main.py\nprint("metric: 1.0")\n```'
        review = '{"verdict": "APPROVE", "score": 8, "critical_issues": []}'
        llm = FakeLLM(responses=[
            crash_code,  # candidate 0
            crash_code,  # candidate 1
            fixed_code,  # fix for candidate 0
            fixed_code,  # fix for candidate 1
            review,      # review
        ])
        results_seq = [
            FakeSandboxResult(returncode=1, stderr="ValueError: x"),
            FakeSandboxResult(returncode=1, stderr="ValueError: x"),
            FakeSandboxResult(returncode=0, stdout="metric: 1.0"),
            FakeSandboxResult(returncode=0, stdout="metric: 1.0"),
        ]
        sandbox = FakeSandbox(results=results_seq)
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                tree_search_enabled=True,
                tree_search_candidates=2,
                tree_search_max_depth=2,
            ),
            stage_dir=stage_dir,
            sandbox_factory=lambda cfg, wd: sandbox,
            experiment_config=None,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="metric", pkg_hint="",
        )
        assert result.tree_nodes_explored >= 2
# ---------------------------------------------------------------------------
# Phase 4: Multi-Agent Review
# ---------------------------------------------------------------------------
class TestPhase4Review:
    """Tests for the multi-agent review phase of CodeAgent."""

    def test_review_approves_on_first_round(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """An immediate APPROVE verdict ends review after one round."""
        code = '```filename:main.py\nprint("m: 1")\n```'
        review = '{"verdict": "APPROVE", "score": 9, "critical_issues": []}'
        llm = FakeLLM(responses=[code, review])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                review_max_rounds=2,
            ),
            stage_dir=stage_dir,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="m", pkg_hint="",
        )
        assert result.review_rounds == 1

    def test_review_triggers_fix_on_critical_issues(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """A REVISE verdict with critical issues triggers a fix + re-review."""
        code = '```filename:main.py\nprint("m: 1")\n```'
        review1 = json.dumps({
            "verdict": "REVISE",
            "score": 3,
            "critical_issues": ["Missing seed handling", "Wrong metric name"],
            "suggestions": [],
        })
        fixed = '```filename:main.py\nimport random\nrandom.seed(42)\nprint("m: 1")\n```'
        review2 = '{"verdict": "APPROVE", "score": 8, "critical_issues": []}'
        llm = FakeLLM(responses=[code, review1, fixed, review2])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                review_max_rounds=3,
                hard_validation=False,  # Test focuses on review, not validation
            ),
            stage_dir=stage_dir,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="m", pkg_hint="",
        )
        assert result.review_rounds == 2
        assert result.total_llm_calls == 4  # codegen + review1 + fix + review2

    def test_review_disabled(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """review_max_rounds=0 skips review entirely (only the codegen call)."""
        code = '```filename:main.py\nprint("m: 1")\n```'
        llm = FakeLLM(responses=[code])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=False,
                review_max_rounds=0,
                hard_validation=False,  # Test focuses on review, not validation
            ),
            stage_dir=stage_dir,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="m", pkg_hint="",
        )
        assert result.review_rounds == 0
        assert result.total_llm_calls == 1  # only codegen
# ---------------------------------------------------------------------------
# Full pipeline tests
# ---------------------------------------------------------------------------
class TestFullPipeline:
    """End-to-end tests covering all CodeAgent phases together."""

    def test_all_phases_end_to_end(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """Arch planning + exec loop + review all run and populate the result."""
        arch = "```yaml\nfiles:\n - name: main.py\n```"
        code = '```filename:main.py\nprint("acc: 0.9")\n```'
        review = '{"verdict": "APPROVE", "score": 9, "critical_issues": []}'
        llm = FakeLLM(responses=[arch, code, review])
        sandbox = FakeSandbox(results=[
            FakeSandboxResult(returncode=0, stdout="acc: 0.9"),
        ])
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(
                architecture_planning=True,
                exec_fix_max_iterations=2,
                review_max_rounds=1,
            ),
            stage_dir=stage_dir,
            sandbox_factory=lambda cfg, wd: sandbox,
            experiment_config=None,
        )
        result = agent.generate(
            topic="image classification", exp_plan="test plan",
            metric="accuracy", pkg_hint="torch",
        )
        assert result.architecture_spec
        assert "main.py" in result.files
        assert result.total_llm_calls >= 3  # arch + code + review
        assert result.total_sandbox_runs >= 1
        assert result.review_rounds == 1
        assert result.validation_log

    def test_agent_writes_attempt_directories(
        self, stage_dir: Path, pm: PromptManager,
    ) -> None:
        """Generated files are persisted under agent_runs/attempt_001."""
        code = '```filename:main.py\nprint("x: 1")\n```'
        review = '{"verdict": "APPROVE", "score": 9, "critical_issues": []}'
        llm = FakeLLM(responses=[code, review])
        sandbox = FakeSandbox()
        agent = CodeAgent(
            llm=llm, prompts=pm,
            config=CodeAgentConfig(architecture_planning=False),
            stage_dir=stage_dir,
            sandbox_factory=lambda cfg, wd: sandbox,
            experiment_config=None,
        )
        result = agent.generate(
            topic="t", exp_plan="p", metric="x", pkg_hint="",
        )
        attempt_dir = stage_dir / "agent_runs" / "attempt_001"
        assert attempt_dir.exists()
        assert (attempt_dir / "main.py").exists()
# ---------------------------------------------------------------------------
# SolutionNode and scoring
# ---------------------------------------------------------------------------
class TestSolutionNodeScoring:
    """Tests for CodeAgent._score_node heuristics on SolutionNode."""

    def test_score_running_node(self) -> None:
        """A clean run with output and the target metric scores >= 2.0."""
        node = SolutionNode(
            node_id="test",
            files={"main.py": "x"},
            runs_ok=True,
            stdout="lots of output " * 20,
            metrics={"metric": 0.95},
        )
        score = CodeAgent._score_node(node, "metric")
        assert score >= 2.0  # runs_ok(1.0) + output(0.3) + metrics(0.5) + key(0.5)

    def test_score_crashing_node(self) -> None:
        """A crashed node is clamped to a zero score."""
        node = SolutionNode(
            node_id="test",
            files={"main.py": "x"},
            runs_ok=False,
            stderr="Error: something broke",
        )
        score = CodeAgent._score_node(node, "metric")
        assert score == 0.0  # no runs_ok, error penalty, max(0)

    def test_score_partial_output(self) -> None:
        """A run with short output and no metrics scores only the base 1.0."""
        node = SolutionNode(
            node_id="test",
            files={"main.py": "x"},
            runs_ok=True,
            stdout="short",
            metrics={},
        )
        score = CodeAgent._score_node(node, "metric")
        assert score == 1.0  # only runs_ok
# ---------------------------------------------------------------------------
# Helper methods
# ---------------------------------------------------------------------------
class TestHelpers:
    """Tests for CodeAgent static helpers and the _SimpleResult default."""

    def test_format_files(self) -> None:
        """_format_files renders each file as a ```filename:...``` block."""
        files = {"main.py": "print(1)", "utils.py": "x = 2"}
        formatted = CodeAgent._format_files(files)
        assert "```filename:main.py" in formatted
        assert "```filename:utils.py" in formatted
        assert "print(1)" in formatted

    def test_parse_json_direct(self) -> None:
        """_parse_json handles a plain JSON string."""
        result = CodeAgent._parse_json('{"score": 5}')
        assert result == {"score": 5}

    def test_parse_json_fenced(self) -> None:
        """_parse_json extracts JSON from a ```json fenced block."""
        text = 'Some text\n```json\n{"verdict": "APPROVE"}\n```\nmore text'
        result = CodeAgent._parse_json(text)
        assert result == {"verdict": "APPROVE"}

    def test_parse_json_embedded(self) -> None:
        """_parse_json finds a JSON object embedded in surrounding prose."""
        text = 'The review is: {"score": 7, "verdict": "REVISE"} end'
        result = CodeAgent._parse_json(text)
        assert result is not None
        assert result["score"] == 7

    def test_parse_json_invalid(self) -> None:
        """_parse_json returns None when no JSON can be recovered."""
        result = CodeAgent._parse_json("not json at all")
        assert result is None

    def test_simple_result_defaults(self) -> None:
        """_SimpleResult defaults to a failed (returncode 1), empty run."""
        r = _SimpleResult()
        assert r.returncode == 1
        assert r.stdout == ""
        assert r.timed_out is False
# ---------------------------------------------------------------------------
# Config integration test
# ---------------------------------------------------------------------------
class TestConfigIntegration:
    """Tests that CodeAgentConfig is wired into the experiment config tree."""

    def test_code_agent_config_in_experiment_config(self) -> None:
        """ExperimentConfig exposes an enabled-by-default code_agent section."""
        from researchclaw.config import CodeAgentConfig, ExperimentConfig
        exp = ExperimentConfig()
        assert hasattr(exp, "code_agent")
        assert isinstance(exp.code_agent, CodeAgentConfig)
        assert exp.code_agent.enabled is True

    def test_code_agent_config_from_dict(self, tmp_path: Path) -> None:
        """RCConfig.from_dict propagates experiment.code_agent overrides."""
        from researchclaw.config import RCConfig
        data = {
            "project": {"name": "test", "mode": "docs-first"},
            "research": {
                "topic": "test",
                "domains": ["ml"],
                "daily_paper_count": 1,
                "quality_threshold": 7.0,
            },
            "runtime": {"timezone": "UTC"},
            "notifications": {
                "channel": "local",
                "on_stage_start": True,
                "on_stage_fail": False,
                "on_gate_required": True,
            },
            "knowledge_base": {
                "backend": "markdown",
                "root": str(tmp_path / "kb"),
            },
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "TEST",
                "api_key": "test-key",
                "primary_model": "test",
                "fallback_models": [],
            },
            "experiment": {
                "mode": "sandbox",
                "code_agent": {
                    "enabled": False,
                    "tree_search_enabled": True,
                    "tree_search_candidates": 5,
                },
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        assert cfg.experiment.code_agent.enabled is False
        assert cfg.experiment.code_agent.tree_search_enabled is True
        assert cfg.experiment.code_agent.tree_search_candidates == 5
# ---------------------------------------------------------------------------
# Prompts integration test
# ---------------------------------------------------------------------------
class TestPromptsIntegration:
    """Tests that the code-agent sub-prompts exist and interpolate arguments."""

    def test_architecture_planning_prompt_exists(self, pm: PromptManager) -> None:
        """architecture_planning prompt mentions 'architect' and echoes inputs."""
        sp = pm.sub_prompt(
            "architecture_planning",
            topic="image classification",
            exp_plan="test plan",
            metric="accuracy",
        )
        assert "architect" in sp.system.lower()
        assert "accuracy" in sp.user
        assert "image classification" in sp.user

    def test_code_exec_fix_prompt_exists(self, pm: PromptManager) -> None:
        """code_exec_fix prompt is debugging-oriented and includes the stderr."""
        sp = pm.sub_prompt(
            "code_exec_fix",
            stderr="ImportError: no module named foo",
            stdout_tail="loading data...",
            returncode="1",
            files_context="```filename:main.py\nimport foo\n```",
        )
        assert "debug" in sp.system.lower() or "fix" in sp.system.lower()
        assert "ImportError" in sp.user

    def test_code_reviewer_prompt_exists(self, pm: PromptManager) -> None:
        """code_reviewer prompt asks for review with APPROVE/REVISE verdicts."""
        sp = pm.sub_prompt(
            "code_reviewer",
            topic="RL",
            exp_plan="test plan",
            metric="reward",
            files_context="```filename:main.py\nprint('hi')\n```",
        )
        assert "review" in sp.system.lower()
        assert "reward" in sp.user
        assert "APPROVE" in sp.user or "REVISE" in sp.user
================================================
FILE: tests/test_code_searcher.py
================================================
"""Tests for the Code Searcher agent."""
from __future__ import annotations
import json
import time
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from researchclaw.agents.code_searcher.agent import CodeSearchAgent, CodeSearchResult
from researchclaw.agents.code_searcher.cache import SearchCache
from researchclaw.agents.code_searcher.github_client import (
CodeSnippet,
GitHubClient,
RepoAnalysis,
RepoInfo,
)
from researchclaw.agents.code_searcher.pattern_extractor import (
CodePatterns,
extract_patterns,
_heuristic_extract,
)
from researchclaw.agents.code_searcher.query_gen import (
generate_search_queries,
_heuristic_generate,
_extract_key_phrases,
)
from researchclaw.domains.detector import DomainProfile, get_profile
# ---------------------------------------------------------------------------
# Query Generation tests
# ---------------------------------------------------------------------------
class TestQueryGeneration:
    """Tests for GitHub search query generation (heuristic and wrapper paths)."""

    def test_heuristic_generates_queries(self):
        """Heuristic generation yields 3-5 queries that mention the libraries."""
        queries = _heuristic_generate(
            topic="finite element method for Poisson equation",
            domain_name="PDE Solvers",
            libraries=["numpy", "scipy", "fenics"],
            needs=["FEM assembly", "mesh generation"],
        )
        assert len(queries) >= 3
        assert len(queries) <= 5
        # Should include library names
        any_lib = any("numpy" in q or "scipy" in q or "fenics" in q for q in queries)
        assert any_lib

    def test_heuristic_no_duplicates(self):
        """Generated queries contain no case-insensitive duplicates."""
        queries = _heuristic_generate(
            topic="simple test",
            domain_name="Test",
            libraries=["numpy"],
            needs=[],
        )
        # No exact duplicates
        assert len(queries) == len(set(q.lower().strip() for q in queries))

    def test_extract_key_phrases(self):
        """Filler words like 'novel' and 'using' are stripped from titles."""
        result = _extract_key_phrases("A Novel Approach for Image Classification Using Deep Learning")
        # Should remove filler words
        assert "novel" not in result.lower()
        assert "using" not in result.lower()

    def test_generate_without_llm(self):
        """With llm=None the wrapper falls back to heuristic query generation."""
        queries = generate_search_queries(
            topic="molecular dynamics simulation",
            domain_name="Computational Physics",
            core_libraries=["jax", "numpy"],
            llm=None,
        )
        assert isinstance(queries, list)
        assert len(queries) >= 2
# ---------------------------------------------------------------------------
# Pattern Extractor tests
# ---------------------------------------------------------------------------
class TestPatternExtractor:
    """Tests for heuristic code-pattern extraction and CodePatterns helpers."""

    def test_heuristic_extract_imports(self):
        """Import statements in snippets become api_patterns entries."""
        snippets = [
            "import numpy as np\nimport scipy.sparse as sp\n\ndef solve():\n pass",
            "from pyscf import gto, scf\nmol = gto.M(atom='H 0 0 0')",
        ]
        patterns = _heuristic_extract(snippets)
        assert len(patterns.api_patterns) > 0
        assert any("numpy" in p for p in patterns.api_patterns)

    def test_heuristic_extract_functions(self):
        """Class/function definitions populate the file_structure mapping."""
        snippets = [
            "class Solver:\n pass\ndef solve_pde():\n pass\ndef analyze():\n pass",
        ]
        patterns = _heuristic_extract(snippets)
        assert len(patterns.file_structure) > 0

    def test_empty_snippets(self):
        """Extraction over no snippets yields a content-free CodePatterns."""
        patterns = extract_patterns([], topic="test", domain_name="test")
        assert not patterns.has_content

    def test_code_patterns_to_prompt(self):
        """to_prompt_context includes APIs, file names, and eval patterns."""
        patterns = CodePatterns(
            api_patterns=["import numpy as np\nresult = np.linalg.solve(A, b)"],
            file_structure={"solver.py": "Main solver implementation"},
            evaluation_patterns=["error = np.linalg.norm(x - x_exact)"],
        )
        ctx = patterns.to_prompt_context()
        assert "numpy" in ctx
        assert "solver.py" in ctx
        assert "error" in ctx

    def test_code_patterns_has_content(self):
        """has_content is False only when every pattern field is empty."""
        empty = CodePatterns()
        assert not empty.has_content
        with_data = CodePatterns(api_patterns=["import x"])
        assert with_data.has_content
# ---------------------------------------------------------------------------
# Search Cache tests
# ---------------------------------------------------------------------------
class TestSearchCache:
    """Tests for SearchCache storage, expiry, clearing, stats, and hashing."""

    def test_put_and_get(self, tmp_path):
        """A stored payload round-trips through put()/get()."""
        cache = SearchCache(cache_dir=tmp_path, ttl_days=30)
        data = {"api_patterns": ["import numpy"], "repos": []}
        cache.put("ml_vision", "image classification", data)
        result = cache.get("ml_vision", "image classification")
        assert result is not None
        assert result["api_patterns"] == ["import numpy"]

    def test_cache_miss(self, tmp_path):
        """Looking up a never-stored key returns None."""
        cache = SearchCache(cache_dir=tmp_path)
        result = cache.get("unknown", "unknown topic")
        assert result is None

    def test_cache_expiry(self, tmp_path):
        """Entries older than ttl_days are treated as misses."""
        cache = SearchCache(cache_dir=tmp_path, ttl_days=0)  # immediate expiry
        data = {"test": True}
        cache.put("test", "topic", data)
        # Manually set old timestamp
        cache_path = tmp_path / "test"
        for f in cache_path.glob("*.json"):
            content = json.loads(f.read_text())
            content["_cached_at"] = time.time() - 86400  # 1 day ago
            f.write_text(json.dumps(content))
        result = cache.get("test", "topic")
        assert result is None  # expired

    def test_clear_domain(self, tmp_path):
        """clear(domain) removes only that domain's entries."""
        cache = SearchCache(cache_dir=tmp_path)
        cache.put("ml_vision", "topic1", {"data": 1})
        cache.put("ml_vision", "topic2", {"data": 2})
        cache.put("physics", "topic3", {"data": 3})
        count = cache.clear("ml_vision")
        assert count == 2
        assert cache.get("ml_vision", "topic1") is None
        assert cache.get("physics", "topic3") is not None

    def test_clear_all(self, tmp_path):
        """clear() with no argument removes every entry."""
        cache = SearchCache(cache_dir=tmp_path)
        cache.put("a", "t1", {"x": 1})
        cache.put("b", "t2", {"x": 2})
        count = cache.clear()
        assert count == 2

    def test_stats(self, tmp_path):
        """stats() reports the overall total and per-domain counts."""
        cache = SearchCache(cache_dir=tmp_path)
        cache.put("ml_vision", "t1", {"x": 1})
        cache.put("ml_vision", "t2", {"x": 2})
        cache.put("physics", "t3", {"x": 3})
        stats = cache.stats()
        assert stats["total"] == 3
        assert stats.get("ml_vision", 0) == 2

    def test_topic_hash_deterministic(self):
        """_topic_hash is stable for identical input."""
        h1 = SearchCache._topic_hash("test topic")
        h2 = SearchCache._topic_hash("test topic")
        assert h1 == h2

    def test_topic_hash_case_insensitive(self):
        """_topic_hash normalizes case before hashing."""
        h1 = SearchCache._topic_hash("Test Topic")
        h2 = SearchCache._topic_hash("test topic")
        assert h1 == h2
# ---------------------------------------------------------------------------
# GitHubClient tests (mocked)
# ---------------------------------------------------------------------------
class TestGitHubClient:
    """Tests for GitHubClient token detection and request-header construction."""

    def test_has_token_false(self):
        """An empty token (with env cleared) means has_token is False."""
        with patch.dict("os.environ", {}, clear=True):
            client = GitHubClient(token="")
            # env is cleared above, so there is no fallback token to pick up
            assert not client.has_token

    def test_has_token_true(self):
        """A non-empty token string means has_token is True."""
        client = GitHubClient(token="ghp_test123")
        assert client.has_token

    def test_headers_with_token(self):
        """With a token, headers carry a Bearer Authorization entry."""
        client = GitHubClient(token="ghp_test123")
        headers = client._headers()
        assert "Authorization" in headers
        assert "Bearer" in headers["Authorization"]

    def test_headers_without_token(self):
        """Without a token, no Authorization header is sent."""
        client = GitHubClient(token="")
        headers = client._headers()
        assert "Authorization" not in headers
# ---------------------------------------------------------------------------
# RepoInfo / CodeSnippet data class tests
# ---------------------------------------------------------------------------
class TestDataClasses:
    """Default values of the RepoInfo / CodeSnippet / RepoAnalysis records."""

    def test_repo_info_defaults(self):
        info = RepoInfo(full_name="owner/repo")
        assert info.stars == 0
        assert info.default_branch == "main"

    def test_code_snippet(self):
        snip = CodeSnippet(
            repo_full_name="owner/repo",
            file_path="src/main.py",
        )
        assert snip.content == ""

    def test_repo_analysis(self):
        report = RepoAnalysis(
            repo=RepoInfo(full_name="test/repo"),
            readme="# Test Repo",
            requirements=["numpy", "scipy"],
        )
        assert len(report.requirements) == 2
# ---------------------------------------------------------------------------
# CodeSearchResult tests
# ---------------------------------------------------------------------------
class TestCodeSearchResult:
    """Prompt-context rendering and cache serialization of CodeSearchResult."""

    def test_empty_result(self):
        empty = CodeSearchResult()
        assert empty.to_prompt_context() == ""
        assert not empty.from_cache

    def test_result_with_patterns(self):
        res = CodeSearchResult(
            patterns=CodePatterns(
                api_patterns=["import numpy as np"],
                file_structure={"main.py": "Entry point"},
            ),
        )
        assert "numpy" in res.to_prompt_context()

    def test_cache_roundtrip(self):
        # Serialize to a cache dict and rebuild; the restored copy must be
        # flagged as cached and must preserve every field.
        original = CodeSearchResult(
            patterns=CodePatterns(
                api_patterns=["import numpy"],
                file_structure={"main.py": "Entry"},
                evaluation_patterns=["error = norm(diff)"],
            ),
            repos_found=[
                RepoInfo(full_name="test/repo", stars=100, html_url="https://example.com"),
            ],
            queries_used=["test query"],
        )
        restored = CodeSearchResult.from_cache_dict(original.to_cache_dict())
        assert restored.from_cache
        assert restored.patterns.api_patterns == ["import numpy"]
        assert len(restored.repos_found) == 1
        assert restored.queries_used == ["test query"]
# ---------------------------------------------------------------------------
# CodeSearchAgent tests (mocked GitHub)
# ---------------------------------------------------------------------------
class TestCodeSearchAgent:
    """CodeSearchAgent behaviour with a fully mocked GitHub client."""

    def _mock_github(self):
        """Build a MagicMock GitHubClient with canned search results."""
        gh = MagicMock(spec=GitHubClient)
        gh.search_repos.return_value = [
            RepoInfo(
                full_name="user/physics-sim",
                description="Physics simulation framework",
                stars=500,
                html_url="https://github.com/user/physics-sim",
            ),
        ]
        gh.search_code.return_value = [
            CodeSnippet(
                repo_full_name="user/physics-sim",
                file_path="main.py",
                score=10.0,
            ),
        ]
        gh.get_readme.return_value = "# Physics Simulation\nA framework for physics sims."
        gh.get_repo_tree.return_value = ["main.py", "solver.py", "requirements.txt"]
        gh.get_file_content.return_value = "import numpy as np\ndef solve(): pass"
        gh.request_count = 5
        return gh

    def test_search_uses_cache(self, tmp_path):
        """A pre-populated cache entry short-circuits any GitHub traffic."""
        store = SearchCache(cache_dir=tmp_path)
        store.put("physics_simulation", "N-body sim", {
            "api_patterns": ["cached pattern"],
            "file_structure": {},
            "evaluation_patterns": [],
            "library_versions": {},
            "repos": [],
            "queries": ["cached query"],
        })
        agent = CodeSearchAgent(cache=store)
        profile = DomainProfile(
            domain_id="physics_simulation",
            display_name="Physics",
            core_libraries=["numpy"],
        )
        hit = agent.search("N-body sim", profile)
        assert hit.from_cache
        assert hit.patterns.api_patterns == ["cached pattern"]

    def test_search_with_mock_github(self, tmp_path):
        """A cache miss falls through to the (mocked) GitHub client."""
        gh = self._mock_github()
        agent = CodeSearchAgent(cache=SearchCache(cache_dir=tmp_path))
        agent._github = gh
        profile = DomainProfile(
            domain_id="physics_simulation",
            display_name="Computational Physics",
            core_libraries=["numpy", "scipy"],
            github_search_terms=["physics simulation python"],
        )
        outcome = agent.search("molecular dynamics simulation", profile)
        assert not outcome.from_cache
        assert len(outcome.queries_used) >= 2
        gh.search_repos.assert_called_once()

    def test_search_graceful_failure(self, tmp_path):
        """If GitHub fails, should still return empty result without crashing."""
        gh = MagicMock(spec=GitHubClient)
        gh.search_repos.side_effect = Exception("Network error")
        gh.search_code.side_effect = Exception("Network error")
        gh.request_count = 0
        agent = CodeSearchAgent(cache=SearchCache(cache_dir=tmp_path))
        agent._github = gh
        profile = DomainProfile(
            domain_id="test",
            display_name="Test",
            core_libraries=["numpy"],
        )
        outcome = agent.search("test topic", profile)
        # No exception may escape; a (possibly empty) result is returned.
        assert isinstance(outcome, CodeSearchResult)
================================================
FILE: tests/test_collaboration.py
================================================
"""Tests for the collaboration system (15+ tests).
Covers:
- ResearchRepository (publish, search, list)
- ArtifactPublisher (extraction from run dirs)
- ArtifactSubscriber (queries)
- Deduplication (content_hash, deduplicate_artifacts)
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.collaboration.repository import ResearchRepository
from researchclaw.collaboration.publisher import ArtifactPublisher
from researchclaw.collaboration.subscriber import ArtifactSubscriber
from researchclaw.collaboration.dedup import content_hash, deduplicate_artifacts
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture
def repo(tmp_path: Path) -> ResearchRepository:
    """Fresh, empty research repository rooted under the test's tmp dir."""
    return ResearchRepository(repo_dir=tmp_path / "shared_repo")
@pytest.fixture
def populated_repo(repo: ResearchRepository) -> ResearchRepository:
    """Repository pre-loaded with artifacts from two fake runs."""
    repo.publish(
        run_id="run-001",
        artifacts={
            "literature_summary": {"papers": ["Paper A on transformer", "Paper B on vision"]},
            "experiment_results": {"accuracy": 0.95, "model": "ResNet50"},
        },
    )
    repo.publish(
        run_id="run-002",
        artifacts={
            "literature_summary": {"papers": ["Paper C on nlp transformer"]},
            "code_template": "import torch\nmodel = ResNet()\n# pytorch training",
        },
    )
    return repo
@pytest.fixture
def run_dir(tmp_path: Path) -> Path:
    """Create a fake pipeline run directory with stage outputs."""
    root = tmp_path / "run-test"
    root.mkdir()
    # (stage dir, file name, file content) — one entry per pipeline stage
    # the publisher knows how to extract from.
    stage_files = [
        ("stage-07-literature_synthesis", "synthesis.json",
         json.dumps({"papers": [{"title": "Test Paper", "year": 2024}]})),
        ("stage-10-code_generation", "main.py", "print('hello')"),
        ("stage-14-result_analysis", "experiment_summary.json",
         json.dumps({"accuracy": 0.92})),
        ("stage-18-peer_review", "review.md", "Good paper overall."),
    ]
    for stage, fname, content in stage_files:
        stage_dir = root / stage
        stage_dir.mkdir()
        (stage_dir / fname).write_text(content, encoding="utf-8")
    return root
# ── Repository Tests ─────────────────────────────────────────────────
class TestResearchRepository:
    """publish / search / list / import behaviour of ResearchRepository."""

    def test_publish(self, repo: ResearchRepository) -> None:
        published = repo.publish(
            run_id="run-001",
            artifacts={"literature_summary": {"papers": ["P1"]}},
        )
        assert published == 1

    def test_publish_creates_dirs(self, repo: ResearchRepository) -> None:
        repo.publish(
            run_id="run-new",
            artifacts={"code_template": "print('hi')"},
        )
        assert (repo.repo_dir / "run-new").is_dir()

    def test_publish_unknown_type_skipped(self, repo: ResearchRepository) -> None:
        # Unrecognised artifact types are dropped rather than stored.
        published = repo.publish(
            run_id="run-bad",
            artifacts={"unknown_type": "data"},
        )
        assert published == 0

    def test_search_by_query(self, populated_repo: ResearchRepository) -> None:
        hits = populated_repo.search("transformer")
        assert len(hits) >= 2

    def test_search_by_type(self, populated_repo: ResearchRepository) -> None:
        hits = populated_repo.search("paper", artifact_type="literature_summary")
        assert len(hits) >= 1

    def test_search_no_results(self, populated_repo: ResearchRepository) -> None:
        assert len(populated_repo.search("quantum_nonexistent_xyz")) == 0

    def test_search_empty_repo(self, repo: ResearchRepository) -> None:
        assert repo.search("anything") == []

    def test_list_runs(self, populated_repo: ResearchRepository) -> None:
        run_ids = populated_repo.list_runs()
        assert "run-001" in run_ids
        assert "run-002" in run_ids

    def test_list_runs_empty(self, repo: ResearchRepository) -> None:
        assert repo.list_runs() == []

    def test_get_run_artifacts(self, populated_repo: ResearchRepository) -> None:
        found = populated_repo.get_run_artifacts("run-001")
        assert "literature_summary" in found
        assert "experiment_results" in found

    def test_get_run_artifacts_missing(self, populated_repo: ResearchRepository) -> None:
        assert populated_repo.get_run_artifacts("run-999") == {}

    def test_import_literature(self, populated_repo: ResearchRepository) -> None:
        papers = populated_repo.import_literature("run-001")
        assert isinstance(papers, list)
        assert len(papers) >= 1

    def test_import_literature_missing_run(self, populated_repo: ResearchRepository) -> None:
        assert populated_repo.import_literature("run-999") == []

    def test_import_code_template(self, populated_repo: ResearchRepository) -> None:
        template = populated_repo.import_code_template("run-002", "pytorch")
        assert template is not None
        assert "torch" in template

    def test_import_code_template_no_match(self, populated_repo: ResearchRepository) -> None:
        assert populated_repo.import_code_template("run-002", "tensorflow_xyz") is None
# ── Publisher Tests ──────────────────────────────────────────────────
class TestArtifactPublisher:
    """ArtifactPublisher extraction of artifacts from pipeline run dirs."""

    def test_publish_from_run_dir(self, run_dir: Path, tmp_path: Path) -> None:
        target = ResearchRepository(repo_dir=tmp_path / "pub_repo")
        published = ArtifactPublisher(target).publish_from_run_dir("test-run", run_dir)
        assert published >= 1

    def test_publish_empty_dir(self, tmp_path: Path) -> None:
        bare = tmp_path / "empty_run"
        bare.mkdir()
        target = ResearchRepository(repo_dir=tmp_path / "pub_repo2")
        assert ArtifactPublisher(target).publish_from_run_dir("empty", bare) == 0

    def test_publish_nonexistent_dir(self, tmp_path: Path) -> None:
        # A missing run directory yields zero artifacts, not an error.
        target = ResearchRepository(repo_dir=tmp_path / "pub_repo3")
        assert ArtifactPublisher(target).publish_from_run_dir("missing", tmp_path / "nope") == 0
# ── Subscriber Tests ─────────────────────────────────────────────────
class TestArtifactSubscriber:
    """ArtifactSubscriber query helpers over a shared repository."""

    def test_find_relevant_literature(self, populated_repo: ResearchRepository) -> None:
        hits = ArtifactSubscriber(populated_repo).find_relevant_literature("transformer")
        assert len(hits) >= 1

    def test_find_similar_experiments(self, populated_repo: ResearchRepository) -> None:
        hits = ArtifactSubscriber(populated_repo).find_similar_experiments("resnet")
        assert len(hits) >= 1

    def test_find_code_templates(self, populated_repo: ResearchRepository) -> None:
        hits = ArtifactSubscriber(populated_repo).find_code_templates("pytorch")
        assert len(hits) >= 1

    def test_import_best_practices(self, populated_repo: ResearchRepository) -> None:
        summary = ArtifactSubscriber(populated_repo).import_best_practices("transformer")
        assert isinstance(summary, str)

    def test_import_best_practices_empty(self, repo: ResearchRepository) -> None:
        # An empty repository yields an empty best-practices string.
        assert ArtifactSubscriber(repo).import_best_practices("nonexistent") == ""
# ── Dedup Tests ──────────────────────────────────────────────────────
class TestDedup:
    """content_hash determinism and artifact deduplication."""

    def test_content_hash_deterministic(self) -> None:
        # Dict key order must not affect the digest.
        assert content_hash({"a": 1, "b": 2}) == content_hash({"b": 2, "a": 1})

    def test_content_hash_different(self) -> None:
        assert content_hash({"a": 1}) != content_hash({"a": 2})

    def test_deduplicate_artifacts(self) -> None:
        items = [
            {"content": {"x": 1}, "tags": ["a"]},
            {"content": {"x": 1}, "tags": ["b"]},  # same content, different tags
            {"content": {"y": 2}, "tags": ["c"]},
        ]
        assert len(deduplicate_artifacts(items)) == 2

    def test_deduplicate_empty(self) -> None:
        assert deduplicate_artifacts([]) == []
================================================
FILE: tests/test_compiler.py
================================================
"""Tests for researchclaw.templates.compiler — BUG-197 and general compilation.
BUG-197: pdflatex stdout containing broken UTF-8 (from U+202F error messages)
caused UnicodeDecodeError that killed the compilation pipeline, preventing
bibtex from running and leaving all citations as [?].
"""
from __future__ import annotations
import re
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.templates.compiler import (
CompileResult,
_is_fatal_error,
_sanitize_tex_unicode,
fix_common_latex_errors,
)
# ---------------------------------------------------------------------------
# _is_fatal_error
# ---------------------------------------------------------------------------
class TestIsFatalError:
    """Classification of pdflatex log lines by _is_fatal_error."""

    def test_unicode_char_not_set_up_is_nonfatal(self):
        """BUG-197: Unicode character errors should be non-fatal.

        The error line captured by _parse_log is a single line:
        ``! LaTeX Error: Unicode character X (U+202F)``
        (the "not set up" text is on a continuation line).
        """
        line = "! LaTeX Error: Unicode character \u202f (U+202F)"
        assert not _is_fatal_error(line)

    def test_unicode_char_various_codepoints_nonfatal(self):
        # Several space-like codepoints that commonly appear in LLM output.
        for cp in ["U+00A0", "U+2009", "U+2007", "U+3000"]:
            line = f"! LaTeX Error: Unicode character X ({cp})"
            assert not _is_fatal_error(line), f"Expected non-fatal for {cp}"

    def test_undefined_control_sequence_is_fatal(self):
        assert _is_fatal_error("! Undefined control sequence.")

    def test_missing_dollar_is_fatal(self):
        assert _is_fatal_error("! Missing $ inserted.")

    def test_overfull_hbox_is_nonfatal(self):
        assert not _is_fatal_error("! Overfull \\hbox (12.3pt too wide)")

    def test_float_lost_is_nonfatal(self):
        assert not _is_fatal_error("! Float(s) lost.")

    def test_unavailable_in_encoding_is_nonfatal(self):
        assert not _is_fatal_error(
            "! Package inputenc Error: Unicode character unavailable in encoding OT1."
        )

    def test_emergency_stop_is_fatal(self):
        assert _is_fatal_error(
            "! ==> Fatal error occurred, no output PDF file produced!"
        )

    def test_non_bang_file_not_found_is_fatal(self):
        # Fatal even though the line does not start with "!".
        assert _is_fatal_error("File `missing.sty' not found.")
# ---------------------------------------------------------------------------
# _sanitize_tex_unicode
# ---------------------------------------------------------------------------
class TestSanitizeTexUnicode:
    """In-place Unicode scrubbing of .tex files by _sanitize_tex_unicode."""

    @staticmethod
    def _sanitize(tmp_path: Path, content: str) -> str:
        """Write content to a temp .tex file, sanitize it, return the result."""
        tex = tmp_path / "test.tex"
        tex.write_text(content, encoding="utf-8")
        _sanitize_tex_unicode(tex)
        return tex.read_text(encoding="utf-8")

    def test_replaces_narrow_no_break_space(self, tmp_path: Path):
        """BUG-197: U+202F should be replaced with ASCII space."""
        assert self._sanitize(tmp_path, "Hello\u202fWorld\n") == "Hello World\n"

    def test_replaces_no_break_space(self, tmp_path: Path):
        """U+00A0 should be replaced with ASCII space."""
        assert self._sanitize(tmp_path, "Hello\u00a0World\n") == "Hello World\n"

    def test_removes_zero_width_space(self, tmp_path: Path):
        """U+200B should be removed entirely."""
        assert self._sanitize(tmp_path, "Hello\u200bWorld\n") == "HelloWorld\n"

    def test_removes_bom(self, tmp_path: Path):
        """U+FEFF BOM should be removed."""
        assert self._sanitize(tmp_path, "\ufeffHello\n") == "Hello\n"

    def test_preserves_normal_text(self, tmp_path: Path):
        """Normal ASCII + standard Unicode should be untouched."""
        content = "Hello World, \\section{Intro} $x^2$\n"
        assert self._sanitize(tmp_path, content) == content

    def test_handles_multiple_types(self, tmp_path: Path):
        """Multiple types of problematic chars in one file."""
        assert self._sanitize(tmp_path, "A\u202fB\u00a0C\u200bD\u200eE\n") == "A B CDE\n"

    def test_nonexistent_file(self, tmp_path: Path):
        """Should not crash on nonexistent file."""
        _sanitize_tex_unicode(tmp_path / "nonexistent.tex")

    def test_cyrillic_transliterated_to_latin(self, tmp_path: Path):
        """BUG-201: Cyrillic author names should be transliterated."""
        result = self._sanitize(tmp_path, "А. И. Колесников\n")
        assert "А" not in result  # no Cyrillic left
        assert "И" not in result
        assert "A. I. Kolesnikov" in result
# ---------------------------------------------------------------------------
# _sanitize_bib_file — Cyrillic transliteration
# ---------------------------------------------------------------------------
class TestSanitizeBibFile:
    """_sanitize_bib_file fixes for bibliography files."""

    def test_cyrillic_author_transliterated(self, tmp_path: Path):
        """BUG-201: Cyrillic in bib author names should be transliterated."""
        from researchclaw.templates.compiler import _sanitize_bib_file

        bib_path = tmp_path / "references.bib"
        bib_path.write_text(
            '@article{dehghani2023scaling,\n'
            ' author = {А. И. Колесников and J. Doe},\n'
            ' title = {Scaling Vision},\n'
            '}\n',
            encoding="utf-8",
        )
        _sanitize_bib_file(bib_path)
        sanitized = bib_path.read_text(encoding="utf-8")
        assert "А" not in sanitized  # Cyrillic removed
        assert "A. I. Kolesnikov" in sanitized
        assert "J. Doe" in sanitized  # Latin unchanged
# ---------------------------------------------------------------------------
# fix_common_latex_errors — Unicode handler
# ---------------------------------------------------------------------------
class TestFixUnicodeErrors:
    """fix_common_latex_errors handling of Unicode character issues."""

    def test_unicode_u202f_replaced_with_space(self):
        """BUG-197: U+202F in text should be replaced with space."""
        source = "Hello\u202fWorld"
        log_lines = ["! LaTeX Error: Unicode character \u202f (U+202F)"]
        fixed, applied = fix_common_latex_errors(source, log_lines)
        assert "\u202f" not in fixed
        assert "Hello World" in fixed
        assert any("U+202F" in f for f in applied)

    def test_unicode_u200b_removed(self):
        """U+200B (zero-width space, category Cf) should be removed."""
        source = "Hello\u200bWorld"
        log_lines = ["! LaTeX Error: Unicode character \u200b (U+200B)"]
        fixed, _ = fix_common_latex_errors(source, log_lines)
        assert "\u200b" not in fixed
        assert "HelloWorld" in fixed

    def test_no_unicode_error_no_change(self):
        """Text without the offending char should not be modified."""
        source = "Hello World"
        log_lines = ["! LaTeX Error: Unicode character \u202f (U+202F)"]
        fixed, applied = fix_common_latex_errors(source, log_lines)
        assert fixed == source
        # No fix should be recorded, since the char never appeared in the text.
        assert not any("U+202F" in f for f in applied)
# ---------------------------------------------------------------------------
# _run_pdflatex — bytes mode decoding
# ---------------------------------------------------------------------------
class TestRunPdflatexByteMode:
    """_run_pdflatex must tolerate broken UTF-8 in captured stdout."""

    @staticmethod
    def _fake_proc(stdout: bytes, returncode: int) -> MagicMock:
        """Build a stub subprocess.run() result with bytes stdout/stderr."""
        proc = MagicMock()
        proc.stdout = stdout
        proc.stderr = b""
        proc.returncode = returncode
        return proc

    @patch("researchclaw.templates.compiler.subprocess.run")
    def test_broken_utf8_in_stdout_does_not_crash(self, mock_run):
        """BUG-197: Broken UTF-8 bytes should be decoded with replacement."""
        from researchclaw.templates.compiler import _run_pdflatex

        # \xe2\x80 is a truncated multi-byte sequence — invalid UTF-8.
        mock_run.return_value = self._fake_proc(b"Normal output \xe2\x80 broken", 1)
        log_text, success = _run_pdflatex(Path("/tmp"), "test.tex", timeout=60)
        assert log_text is not None
        assert "Normal output" in log_text
        assert not success

    @patch("researchclaw.templates.compiler.subprocess.run")
    def test_valid_utf8_works(self, mock_run):
        """Normal UTF-8 output should work fine."""
        from researchclaw.templates.compiler import _run_pdflatex

        mock_run.return_value = self._fake_proc(b"Output written on test.pdf (1 page)", 0)
        log_text, success = _run_pdflatex(Path("/tmp"), "test.tex", timeout=60)
        assert log_text is not None
        assert "Output written" in log_text
        assert success
# ---------------------------------------------------------------------------
# _run_bibtex — bytes mode decoding + logging
# ---------------------------------------------------------------------------
class TestRunBibtex:
    """_run_bibtex error handling, missing binary, and byte decoding."""

    @staticmethod
    def _fake_proc(stdout: bytes, returncode: int) -> MagicMock:
        """Build a stub subprocess.run() result with bytes stdout/stderr."""
        proc = MagicMock()
        proc.stdout = stdout
        proc.stderr = b""
        proc.returncode = returncode
        return proc

    @patch("researchclaw.templates.compiler.shutil.which", return_value="/usr/bin/bibtex")
    @patch("researchclaw.templates.compiler.subprocess.run")
    def test_bibtex_failure_logged(self, mock_run, mock_which, tmp_path):
        """Failed bibtex should log warning and return False."""
        from researchclaw.templates.compiler import _run_bibtex

        mock_run.return_value = self._fake_proc(b"I couldn't open file name.aux", 1)
        assert _run_bibtex(tmp_path, "paper", timeout=60) is False

    @patch("researchclaw.templates.compiler.shutil.which", return_value="/usr/bin/bibtex")
    @patch("researchclaw.templates.compiler.subprocess.run")
    def test_bibtex_success_with_bbl(self, mock_run, mock_which, tmp_path):
        """Successful bibtex with .bbl creation should return True."""
        from researchclaw.templates.compiler import _run_bibtex

        # A .bbl file must exist for the success check to pass.
        (tmp_path / "paper.bbl").write_text("\\begin{thebibliography}{}")
        mock_run.return_value = self._fake_proc(b"Database file #1: references.bib", 0)
        assert _run_bibtex(tmp_path, "paper", timeout=60) is True

    @patch("researchclaw.templates.compiler.shutil.which", return_value=None)
    def test_bibtex_not_found(self, mock_which, tmp_path):
        """Missing bibtex binary should return False."""
        from researchclaw.templates.compiler import _run_bibtex

        assert _run_bibtex(tmp_path, "paper", timeout=60) is False

    @patch("researchclaw.templates.compiler.shutil.which", return_value="/usr/bin/bibtex")
    @patch("researchclaw.templates.compiler.subprocess.run")
    def test_bibtex_broken_utf8(self, mock_run, mock_which, tmp_path):
        """BUG-197: Broken UTF-8 in bibtex output should not crash."""
        from researchclaw.templates.compiler import _run_bibtex

        (tmp_path / "paper.bbl").write_text("\\begin{thebibliography}{}")
        mock_run.return_value = self._fake_proc(b"Database file \xe2\x80 broken", 0)
        # Must not raise UnicodeDecodeError.
        assert _run_bibtex(tmp_path, "paper", timeout=60) is True
================================================
FILE: tests/test_convergence_evaluator.py
================================================
"""Tests for the convergence study evaluator."""
from __future__ import annotations
import math
import pytest
from researchclaw.experiment.evaluators.convergence import (
ConvergenceReport,
ConvergenceResult,
analyze_convergence,
compute_convergence_order,
)
# ---------------------------------------------------------------------------
# compute_convergence_order tests
# ---------------------------------------------------------------------------
class TestComputeConvergenceOrder:
    """Order-of-accuracy estimation from (h, error) samples."""

    def test_second_order(self):
        """h, h/2, h/4, h/8 with error ~ h^2."""
        step_sizes = [0.1, 0.05, 0.025, 0.0125]
        errs = [h ** 2 for h in step_sizes]
        order, r2 = compute_convergence_order(step_sizes, errs)
        assert abs(order - 2.0) < 0.1
        assert r2 > 0.99

    def test_fourth_order(self):
        """Error ~ h^4."""
        step_sizes = [0.1, 0.05, 0.025, 0.0125]
        errs = [h ** 4 for h in step_sizes]
        order, r2 = compute_convergence_order(step_sizes, errs)
        assert abs(order - 4.0) < 0.1
        assert r2 > 0.99

    def test_first_order(self):
        """Error ~ h."""
        step_sizes = [0.1, 0.05, 0.025, 0.0125]
        order, _ = compute_convergence_order(step_sizes, list(step_sizes))
        assert abs(order - 1.0) < 0.1

    def test_too_few_points(self):
        # A single sample cannot define a slope.
        order, r2 = compute_convergence_order([0.1], [0.01])
        assert order == 0.0
        assert r2 == 0.0

    def test_empty_input(self):
        order, _ = compute_convergence_order([], [])
        assert order == 0.0

    def test_filters_invalid(self):
        # Zero and negative step sizes must be filtered before fitting.
        step_sizes = [0.1, 0.0, 0.025, -0.01]
        errs = [0.01, 0.0, 0.001, 0.0001]
        order, _ = compute_convergence_order(step_sizes, errs)
        assert order > 0  # fit still succeeds on the remaining valid points
# ---------------------------------------------------------------------------
# analyze_convergence tests
# ---------------------------------------------------------------------------
class TestAnalyzeConvergence:
    """Full convergence reports produced by analyze_convergence."""

    def test_single_method(self):
        runs = {
            "euler": [
                {"h": 0.1, "error": 0.1},
                {"h": 0.05, "error": 0.05},
                {"h": 0.025, "error": 0.025},
            ]
        }
        report = analyze_convergence(runs)
        assert len(report.methods) == 1
        euler = report.methods[0]
        assert euler.method == "euler"
        assert abs(euler.convergence_order - 1.0) < 0.2
        assert report.best_method == "euler"

    def test_multiple_methods(self):
        runs = {
            "euler": [
                {"h": 0.1, "error": 0.1},
                {"h": 0.05, "error": 0.05},
                {"h": 0.025, "error": 0.025},
            ],
            "rk4": [
                {"h": 0.1, "error": 1e-4},
                {"h": 0.05, "error": 6.25e-6},
                {"h": 0.025, "error": 3.9e-7},
            ],
        }
        report = analyze_convergence(runs)
        assert len(report.methods) == 2
        # RK4 should come out with the higher estimated order.
        by_name = {entry.method: entry.convergence_order for entry in report.methods}
        assert by_name["rk4"] > by_name["euler"]
        assert report.best_method == "rk4"

    def test_expected_orders(self):
        runs = {
            "euler": [
                {"h": 0.1, "error": 0.1},
                {"h": 0.05, "error": 0.05},
                {"h": 0.025, "error": 0.025},
            ],
        }
        report = analyze_convergence(runs, expected_orders={"euler": 1.0})
        assert report.methods[0].expected_order == 1.0
        assert report.methods[0].order_matches_expected is True

    def test_non_converging(self):
        runs = {
            "bad_method": [
                {"h": 0.1, "error": 0.5},
                {"h": 0.05, "error": 0.6},  # error grows as h shrinks
                {"h": 0.025, "error": 0.7},
            ],
        }
        report = analyze_convergence(runs)
        # Negative or very low order indicates no convergence.
        assert not report.methods[0].is_converging

    def test_summary_string(self):
        runs = {
            "method_a": [
                {"h": 0.1, "error": 0.01},
                {"h": 0.05, "error": 0.0025},
            ],
        }
        report = analyze_convergence(runs)
        assert report.summary  # should not be empty
        assert "method_a" in report.summary

    def test_l2_error_key(self):
        """Should handle l2_error as the error key."""
        runs = {
            "fem": [
                {"h": 0.1, "l2_error": 0.01},
                {"h": 0.05, "l2_error": 0.0025},
                {"h": 0.025, "l2_error": 0.000625},
            ],
        }
        report = analyze_convergence(runs)
        assert abs(report.methods[0].convergence_order - 2.0) < 0.2

    def test_empty_data(self):
        report = analyze_convergence({})
        assert len(report.methods) == 0
        assert report.best_method == ""
================================================
FILE: tests/test_copilot.py
================================================
"""Tests for researchclaw.copilot — Interactive Co-Pilot Mode (Agent D2).
30+ tests covering modes, feedback, branching, and controller.
"""
from __future__ import annotations
import json
import shutil
import time
from datetime import date, timedelta
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from researchclaw.copilot.modes import ResearchMode
from researchclaw.copilot.feedback import (
FEEDBACK_ACTIONS,
Feedback,
FeedbackHandler,
)
from researchclaw.copilot.branching import BranchManager
from researchclaw.copilot.controller import CoPilotController
from researchclaw.config import CoPilotConfig
# ===================================================================
# ResearchMode tests
# ===================================================================
class TestResearchMode:
    """Enum values and value-lookup behaviour of ResearchMode."""

    def test_all_modes(self):
        assert ResearchMode.CO_PILOT.value == "co-pilot"
        assert ResearchMode.AUTO_PILOT.value == "auto-pilot"
        assert ResearchMode.ZERO_TOUCH.value == "zero-touch"

    def test_from_value(self):
        # Constructing from the string value round-trips to the member.
        assert ResearchMode("co-pilot") == ResearchMode.CO_PILOT
        assert ResearchMode("auto-pilot") == ResearchMode.AUTO_PILOT
        assert ResearchMode("zero-touch") == ResearchMode.ZERO_TOUCH

    def test_invalid_mode_raises(self):
        with pytest.raises(ValueError):
            ResearchMode("invalid")

    def test_mode_count(self):
        assert len(ResearchMode) == 3
# ===================================================================
# Feedback tests
# ===================================================================
class TestFeedback:
    """Feedback record: allowed actions, immutability, and defaults."""

    def test_feedback_actions_defined(self):
        expected_actions = {"approve", "modify", "retry", "skip", "discuss", "branch", "rollback"}
        assert FEEDBACK_ACTIONS == expected_actions

    def test_feedback_frozen(self):
        # Feedback instances are immutable: assignment must fail.
        fb = Feedback(action="approve", stage=5)
        with pytest.raises(AttributeError):
            fb.action = "retry"  # type: ignore[misc]

    def test_feedback_defaults(self):
        fb = Feedback(action="approve", stage=1)
        assert fb.message == ""
        assert fb.modifications is None
        assert fb.branch_name == ""
        assert fb.rollback_to is None

    def test_feedback_with_modifications(self):
        fb = Feedback(
            action="modify",
            stage=5,
            message="Update hypothesis",
            modifications={"hypothesis": "new hypothesis"},
        )
        assert fb.modifications == {"hypothesis": "new hypothesis"}
# ===================================================================
# FeedbackHandler tests
# ===================================================================
class TestFeedbackHandler:
def test_write_feedback_request(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
request_path = handler.write_feedback_request(
stage=5,
stage_name="LITERATURE_SCREEN",
summary="10 papers screened",
)
assert request_path.exists()
data = json.loads(request_path.read_text(encoding="utf-8"))
assert data["stage"] == 5
assert data["stage_name"] == "LITERATURE_SCREEN"
assert data["status"] == "waiting"
assert isinstance(data["options"], list)
def test_read_feedback_response_valid(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
response = {
"action": "approve",
"stage": 5,
"message": "Looks good",
}
resp_path = tmp_path / "copilot_feedback_response.json"
resp_path.write_text(json.dumps(response), encoding="utf-8")
fb = handler.read_feedback_response()
assert fb is not None
assert fb.action == "approve"
assert fb.stage == 5
assert fb.message == "Looks good"
def test_read_feedback_response_invalid_action(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
response = {"action": "invalid_action", "stage": 5}
resp_path = tmp_path / "copilot_feedback_response.json"
resp_path.write_text(json.dumps(response), encoding="utf-8")
fb = handler.read_feedback_response()
assert fb is None
def test_read_feedback_response_missing(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
assert handler.read_feedback_response() is None
def test_read_feedback_response_malformed(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
resp_path = tmp_path / "copilot_feedback_response.json"
resp_path.write_text("{invalid json", encoding="utf-8")
assert handler.read_feedback_response() is None
def test_read_feedback_response_with_rollback(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
response = {
"action": "rollback",
"stage": 15,
"rollback_to": 8,
}
resp_path = tmp_path / "copilot_feedback_response.json"
resp_path.write_text(json.dumps(response), encoding="utf-8")
fb = handler.read_feedback_response()
assert fb is not None
assert fb.action == "rollback"
assert fb.rollback_to == 8
def test_read_feedback_response_branch(self, tmp_path: Path):
handler = FeedbackHandler(tmp_path)
response = {
"action": "branch",
"stage": 9,
"branch_name": "alt_experiment",
}
resp_path = tmp_path / "copilot_feedback_response.json"
resp_path.write_text(json.dumps(response), encoding="utf-8")
fb = handler.read_feedback_response()
assert fb is not None
assert fb.branch_name == "alt_experiment"
def test_clear_request(self, tmp_path: Path):
    """clear_request removes a previously written request file."""
    handler = FeedbackHandler(tmp_path)
    handler.write_feedback_request(1, "TOPIC_INIT", "Done")
    handler.clear_request()
    request_file = tmp_path / "copilot_feedback_request.json"
    assert not request_file.exists()
def test_clear_request_no_file(self, tmp_path: Path):
    """Clearing when no request file exists must be a silent no-op."""
    FeedbackHandler(tmp_path).clear_request()
def test_wait_for_feedback_timeout(self, tmp_path: Path):
    """A zero-second timeout returns None when no response ever arrives."""
    handler = FeedbackHandler(tmp_path)
    outcome = handler.wait_for_feedback(
        stage=1, timeout_sec=0, poll_interval_sec=0.01
    )
    assert outcome is None
def test_wait_for_feedback_finds_response(self, tmp_path: Path):
    """wait_for_feedback picks up a response written while it is polling."""
    import threading

    handler = FeedbackHandler(tmp_path)
    response_file = tmp_path / "copilot_feedback_response.json"
    payload = json.dumps({"action": "approve", "stage": 5})

    def delayed_write():
        """Simulate delayed response writing."""
        # Give the poll loop a moment to start before the response appears.
        time.sleep(0.05)
        response_file.write_text(payload, encoding="utf-8")

    writer = threading.Thread(target=delayed_write)
    writer.start()
    feedback = handler.wait_for_feedback(stage=5, timeout_sec=2, poll_interval_sec=0.02)
    writer.join()
    assert feedback is not None
    assert feedback.action == "approve"
# ===================================================================
# BranchManager tests
# ===================================================================
class TestBranchManager:
    """Branch lifecycle: create, list, switch, delete, compare, count."""

    def test_create_branch(self, tmp_path: Path):
        # Seed two stage directories with one artifact each.
        (tmp_path / "stage-01").mkdir()
        (tmp_path / "stage-01" / "output.json").write_text("{}")
        (tmp_path / "stage-02").mkdir()
        (tmp_path / "stage-02" / "result.txt").write_text("ok")
        manager = BranchManager(tmp_path, max_branches=3)
        branch_root = Path(manager.create_branch("exp_alt", from_stage=2))
        assert branch_root.exists()
        # Copied artifacts and the metadata file must all be present.
        for relative in (
            "stage-01/output.json",
            "stage-02/result.txt",
            "branch_meta.json",
        ):
            assert (branch_root / relative).exists()
        meta = json.loads(
            (branch_root / "branch_meta.json").read_text(encoding="utf-8")
        )
        assert meta["name"] == "exp_alt"
        assert meta["from_stage"] == 2

    def test_create_branch_max_reached(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=1)
        manager.create_branch("b1", from_stage=1)
        with pytest.raises(ValueError, match="Maximum branches"):
            manager.create_branch("b2", from_stage=1)

    def test_create_branch_duplicate_name(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=5)
        manager.create_branch("dup", from_stage=1)
        with pytest.raises(ValueError, match="already exists"):
            manager.create_branch("dup", from_stage=1)

    def test_list_branches_empty(self, tmp_path: Path):
        assert BranchManager(tmp_path).list_branches() == []

    def test_list_branches(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=5)
        for name, stage in (("alpha", 1), ("beta", 2)):
            manager.create_branch(name, from_stage=stage)
        listed = manager.list_branches()
        assert len(listed) == 2
        assert {entry["name"] for entry in listed} == {"alpha", "beta"}

    def test_switch_branch(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=3)
        manager.create_branch("test_branch", from_stage=1)
        assert manager.switch_branch("test_branch").exists()

    def test_switch_branch_nonexistent(self, tmp_path: Path):
        with pytest.raises(ValueError, match="does not exist"):
            BranchManager(tmp_path).switch_branch("nonexistent")

    def test_delete_branch(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=3)
        manager.create_branch("doomed", from_stage=1)
        assert len(manager.list_branches()) == 1
        manager.delete_branch("doomed")
        assert len(manager.list_branches()) == 0

    def test_delete_branch_nonexistent(self, tmp_path: Path):
        with pytest.raises(ValueError, match="does not exist"):
            BranchManager(tmp_path).delete_branch("ghost")

    def test_compare_branches(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=5)
        (tmp_path / "stage-01").mkdir()
        (tmp_path / "stage-02").mkdir()
        manager.create_branch("a", from_stage=2)
        manager.create_branch("b", from_stage=1)
        comparison = manager.compare_branches("a", "b")
        assert comparison["branch_a"] == "a"
        assert comparison["stages_a"] == 2
        assert comparison["stages_b"] == 1

    def test_compare_branches_nonexistent(self, tmp_path: Path):
        manager = BranchManager(tmp_path, max_branches=3)
        manager.create_branch("real", from_stage=1)
        assert "error" in manager.compare_branches("real", "fake")

    def test_count_stages(self, tmp_path: Path):
        for name in ("stage-01", "stage-02", "other_dir"):
            (tmp_path / name).mkdir()
        # Only stage-* directories should be counted.
        assert BranchManager._count_stages(tmp_path) == 2
# ===================================================================
# CoPilotController tests
# ===================================================================
class TestCoPilotController:
    """Pause policy, stage presentation, and feedback dispatch."""

    def _make_config(self, **overrides) -> CoPilotConfig:
        """Build a CoPilotConfig from defaults, overridable per test."""
        settings = {
            "mode": "co-pilot",
            "pause_at_gates": True,
            "pause_at_every_stage": False,
            "feedback_timeout_sec": 3600,
            "allow_branching": True,
            "max_branches": 3,
        }
        settings.update(overrides)
        return CoPilotConfig(**settings)

    def _controller(self, tmp_path, **overrides) -> CoPilotController:
        """Shortcut: controller wired to a config built from overrides."""
        return CoPilotController(self._make_config(**overrides), tmp_path)

    def test_should_pause_zero_touch(self, tmp_path: Path):
        controller = self._controller(tmp_path, mode="zero-touch")
        assert controller.should_pause(5, is_gate=True) is False
        assert controller.should_pause(1, is_gate=False) is False

    def test_should_pause_auto_pilot_gate(self, tmp_path: Path):
        controller = self._controller(tmp_path, mode="auto-pilot")
        assert controller.should_pause(5, is_gate=True) is True
        assert controller.should_pause(1, is_gate=False) is False

    def test_should_pause_auto_pilot_gates_disabled(self, tmp_path: Path):
        controller = self._controller(
            tmp_path, mode="auto-pilot", pause_at_gates=False
        )
        assert controller.should_pause(5, is_gate=True) is False

    def test_should_pause_copilot_every_stage(self, tmp_path: Path):
        controller = self._controller(
            tmp_path, mode="co-pilot", pause_at_every_stage=True
        )
        assert controller.should_pause(1, is_gate=False) is True
        assert controller.should_pause(5, is_gate=True) is True

    def test_should_pause_copilot_gates_only(self, tmp_path: Path):
        controller = self._controller(
            tmp_path, mode="co-pilot", pause_at_every_stage=False
        )
        assert controller.should_pause(5, is_gate=True) is True
        assert controller.should_pause(1, is_gate=False) is False

    def test_present_stage_result(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        summary = controller.present_stage_result(
            stage_num=5,
            stage_name="LITERATURE_SCREEN",
            artifacts=["screen_report.json"],
            status="done",
        )
        for fragment in (
            "Stage 5: LITERATURE_SCREEN",
            "Status: done",
            "screen_report.json",
        ):
            assert fragment in summary

    def test_present_stage_result_with_error(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        summary = controller.present_stage_result(
            stage_num=12,
            stage_name="EXPERIMENT_RUN",
            artifacts=[],
            status="failed",
            error="CUDA out of memory",
        )
        assert "Error: CUDA out of memory" in summary

    def test_handle_feedback_approve(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        outcome = controller.handle_feedback(Feedback(action="approve", stage=5))
        assert outcome["instruction"] == "continue"

    def test_handle_feedback_modify(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        feedback = Feedback(
            action="modify",
            stage=5,
            message="Change approach",
            modifications={"key": "value"},
        )
        outcome = controller.handle_feedback(feedback)
        assert outcome["instruction"] == "apply_modifications"
        assert outcome["modifications"] == {"key": "value"}

    def test_handle_feedback_retry(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        outcome = controller.handle_feedback(Feedback(action="retry", stage=12))
        assert outcome["instruction"] == "rerun_stage"

    def test_handle_feedback_skip(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        outcome = controller.handle_feedback(Feedback(action="skip", stage=21))
        assert outcome["instruction"] == "skip_stage"

    def test_handle_feedback_branch(self, tmp_path: Path):
        controller = self._controller(tmp_path, allow_branching=True)
        feedback = Feedback(action="branch", stage=9, branch_name="alt_design")
        outcome = controller.handle_feedback(feedback)
        assert outcome["instruction"] == "branch_created"
        assert outcome["branch_name"] == "alt_design"

    def test_handle_feedback_branch_disabled(self, tmp_path: Path):
        controller = self._controller(tmp_path, allow_branching=False)
        outcome = controller.handle_feedback(Feedback(action="branch", stage=9))
        assert outcome["instruction"] == "branching_disabled"

    def test_handle_feedback_branch_max_reached(self, tmp_path: Path):
        controller = self._controller(tmp_path, allow_branching=True, max_branches=1)
        # The first branch consumes the quota; the second must fail.
        controller.handle_feedback(
            Feedback(action="branch", stage=1, branch_name="b1")
        )
        outcome = controller.handle_feedback(
            Feedback(action="branch", stage=2, branch_name="b2")
        )
        assert outcome["instruction"] == "branch_failed"

    def test_handle_feedback_rollback(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        outcome = controller.handle_feedback(
            Feedback(action="rollback", stage=15, rollback_to=8)
        )
        assert outcome["instruction"] == "rollback"
        assert outcome["rollback_to"] == 8

    def test_handle_feedback_unknown_action(self, tmp_path: Path):
        controller = self._controller(tmp_path)
        # A technically valid action that the dispatch does not handle
        # should fall through to "continue".
        outcome = controller.handle_feedback(Feedback(action="discuss", stage=1))
        assert outcome["instruction"] == "continue"

    def test_from_config_zero_touch_returns_none(self, tmp_path: Path):
        cfg = self._make_config(mode="zero-touch")
        assert CoPilotController.from_config(cfg, tmp_path) is None

    def test_from_config_copilot_returns_controller(self, tmp_path: Path):
        cfg = self._make_config(mode="co-pilot")
        controller = CoPilotController.from_config(cfg, tmp_path)
        assert controller is not None
        assert isinstance(controller, CoPilotController)

    def test_from_config_auto_pilot_returns_controller(self, tmp_path: Path):
        cfg = self._make_config(mode="auto-pilot")
        assert CoPilotController.from_config(cfg, tmp_path) is not None

    def test_handle_feedback_branch_default_name(self, tmp_path: Path):
        controller = self._controller(tmp_path, allow_branching=True)
        # No branch_name supplied: a default of "branch_<stage>" is generated.
        outcome = controller.handle_feedback(Feedback(action="branch", stage=9))
        assert outcome["instruction"] == "branch_created"
        assert outcome["branch_name"] == "branch_9"
================================================
FILE: tests/test_decision_agent.py
================================================
"""Tests for FigureDecisionAgent, NanoBananaAgent, and Docker renderer.
Covers:
- FigureDecisionAgent._parse_decisions() — JSON parsing edge cases
- FigureDecisionAgent._heuristic_decide() — fallback coverage
- FigureDecisionAgent._infer_backend() — backend classification
- FigureDecisionAgent._enforce_bounds() — min/max enforcement
- NanoBananaAgent._build_prompt() — prompt construction
- NanoBananaAgent._get_type_guidelines() — guideline lookup
- RendererAgent._execute_in_docker() — docker command construction
- strip_thinking_tags() — safety verification
- End-to-end decision + orchestration with mock LLM
"""
from __future__ import annotations

import json
import os
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from unittest import mock

import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@dataclass
class _FakeLLMResponse:
content: str = ""
model: str = "gpt-4.1"
prompt_tokens: int = 100
completion_tokens: int = 200
total_tokens: int = 300
finish_reason: str = "stop"
truncated: bool = False
raw: dict = None # type: ignore[assignment]
def __post_init__(self):
if self.raw is None:
self.raw = {}
class _FakeLLM:
    """Minimal mock LLM client: records calls and replies with a canned string."""

    def __init__(self, response: str = "{}"):
        self._response = response
        # One entry per chat() invocation, for later inspection by tests.
        self.calls: list[dict[str, Any]] = []

    def chat(self, messages, *, system=None, max_tokens=None,
             temperature=None, json_mode=False, **kwargs):
        """Record the request and return the canned response."""
        record = {
            "messages": messages,
            "system": system,
            "json_mode": json_mode,
        }
        self.calls.append(record)
        return _FakeLLMResponse(content=self._response)
# =========================================================================
# FigureDecisionAgent._parse_decisions()
# =========================================================================
class TestParseDecisions:
    """Edge cases for JSON parsing in the decision agent."""

    def _agent(self):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        return FigureDecisionAgent(_FakeLLM())

    def test_valid_json_array(self):
        payload = json.dumps([
            {"section": "Method", "figure_type": "architecture_diagram",
             "backend": "image", "description": "Architecture overview",
             "priority": 1},
            {"section": "Results", "figure_type": "bar_comparison",
             "backend": "code", "description": "Main results", "priority": 1},
        ])
        decisions = self._agent()._parse_decisions(payload)
        # Order and backends survive round-tripping.
        assert [entry["backend"] for entry in decisions] == ["image", "code"]

    def test_json_inside_markdown_fences(self):
        payload = (
            '```json\n'
            '[{"section": "Method", "figure_type": "pipeline_overview", '
            '"backend": "image", "description": "Pipeline", "priority": 1}]\n'
            '```'
        )
        decisions = self._agent()._parse_decisions(payload)
        assert len(decisions) == 1
        assert decisions[0]["figure_type"] == "pipeline_overview"

    def test_json_with_surrounding_text(self):
        payload = (
            'Here are the decisions:\n'
            '[{"section": "Results", "figure_type": "heatmap", "backend": "code", '
            '"description": "Heatmap", "priority": 2}]\n'
            'That is all.'
        )
        assert len(self._agent()._parse_decisions(payload)) == 1

    def test_no_json_array_raises(self):
        with pytest.raises(ValueError, match="No JSON array"):
            self._agent()._parse_decisions("This is not JSON at all.")

    def test_empty_array(self):
        assert self._agent()._parse_decisions("[]") == []

    def test_non_dict_items_skipped(self):
        payload = json.dumps([
            "not a dict",
            42,
            {"section": "Method", "figure_type": "architecture_diagram",
             "backend": "image", "description": "Arch", "priority": 1},
        ])
        # Only the single dict entry survives filtering.
        assert len(self._agent()._parse_decisions(payload)) == 1

    def test_invalid_backend_auto_inferred(self):
        payload = json.dumps([
            {"section": "Method", "figure_type": "architecture_diagram",
             "backend": "invalid_backend", "description": "Arch", "priority": 1},
        ])
        decisions = self._agent()._parse_decisions(payload)
        # architecture_diagram maps to the image backend when backend is invalid.
        assert decisions[0]["backend"] == "image"

    def test_missing_fields_get_defaults(self):
        decisions = self._agent()._parse_decisions(json.dumps([{}]))
        assert len(decisions) == 1
        filled = decisions[0]
        assert filled["section"] == "Results"
        assert filled["figure_type"] == "bar_comparison"
        assert filled["backend"] == "code"
        assert filled["priority"] == 2
# =========================================================================
# FigureDecisionAgent._heuristic_decide()
# =========================================================================
class TestHeuristicDecide:
    """Rule-based fallback decision logic."""

    def _agent(self, min_figures=3, max_figures=10):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        return FigureDecisionAgent(
            _FakeLLM(), min_figures=min_figures, max_figures=max_figures
        )

    def test_with_experiments(self):
        decisions = self._agent()._heuristic_decide(
            topic="Graph anomaly detection",
            has_experiments=True,
            condition_summaries={"proposed": {}, "baseline": {}, "ablation": {}},
        )
        # Expect at least: arch diagram, bar comparison, training curve, pipeline.
        assert len(decisions) >= 4
        used_backends = {entry["backend"] for entry in decisions}
        assert "code" in used_backends
        assert "image" in used_backends

    def test_without_experiments(self):
        decisions = self._agent()._heuristic_decide(
            topic="Theoretical framework",
            has_experiments=False,
            condition_summaries={},
        )
        # No experiments: image-only figures (arch diagram + pipeline).
        assert len(decisions) >= 2
        assert all(entry["backend"] == "image" for entry in decisions)

    def test_ablation_trigger(self):
        """When >= 4 conditions, an ablation figure should be added."""
        decisions = self._agent()._heuristic_decide(
            topic="Test",
            has_experiments=True,
            condition_summaries={"a": {}, "b": {}, "c": {}, "d": {}},
        )
        assert any(
            "ablation" in entry["description"].lower() for entry in decisions
        )

    def test_max_figures_respected(self):
        decisions = self._agent(max_figures=2)._heuristic_decide(
            topic="Test",
            has_experiments=True,
            condition_summaries={"a": {}, "b": {}, "c": {}, "d": {}},
        )
        assert len(decisions) <= 2
# =========================================================================
# FigureDecisionAgent._infer_backend()
# =========================================================================
class TestInferBackend:
    """Classification of figure types into rendering backends."""

    # Data-driven figure types and their expected backend.
    CODE_TYPES = (
        "bar_comparison", "line_chart", "heatmap", "confusion_matrix",
        "training_curve", "ablation_chart", "scatter_plot",
    )
    IMAGE_TYPES = (
        "architecture_diagram", "method_flowchart", "pipeline_overview",
        "concept_illustration", "system_diagram",
    )

    def test_code_types(self):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        for t in self.CODE_TYPES:
            assert FigureDecisionAgent._infer_backend(t) == "code", f"Failed for {t}"

    def test_image_types(self):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        for t in self.IMAGE_TYPES:
            assert FigureDecisionAgent._infer_backend(t) == "image", f"Failed for {t}"

    def test_unknown_defaults_to_image(self):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        assert FigureDecisionAgent._infer_backend("unknown_chart_type") == "image"
# =========================================================================
# FigureDecisionAgent._enforce_bounds()
# =========================================================================
class TestEnforceBounds:
    """Min/max figure-count enforcement."""

    def _agent(self, min_figures=3, max_figures=6):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        return FigureDecisionAgent(
            _FakeLLM(), min_figures=min_figures, max_figures=max_figures
        )

    @staticmethod
    def _decision(section, figure_type, backend, description, priority):
        """Build one decision dict in the shape _enforce_bounds expects."""
        return {
            "section": section,
            "figure_type": figure_type,
            "backend": backend,
            "description": description,
            "priority": priority,
        }

    def test_min_padding(self):
        """When fewer than min figures, should pad."""
        agent = self._agent(min_figures=4)
        seed = [self._decision("Results", "bar_comparison", "code", "Test", 1)]
        assert len(agent._enforce_bounds(seed, has_experiments=True)) >= 4

    def test_max_truncation(self):
        """When more than max figures, should truncate."""
        agent = self._agent(max_figures=3)
        seed = [
            self._decision(f"S{i}", "bar_comparison", "code", f"Fig {i}", i)
            for i in range(8)
        ]
        assert len(agent._enforce_bounds(seed, has_experiments=True)) <= 3

    def test_ensures_image_figure(self):
        """Should add architecture diagram if none present."""
        agent = self._agent(min_figures=1)
        seed = [self._decision("Results", "bar_comparison", "code", "Bar", 1)]
        bounded = agent._enforce_bounds(seed, has_experiments=True)
        assert any(entry["backend"] == "image" for entry in bounded)

    def test_ensures_code_figure_with_experiments(self):
        """Should add bar_comparison if experiments exist but no code figure."""
        agent = self._agent(min_figures=1)
        seed = [self._decision("Method", "architecture_diagram", "image", "Arch", 1)]
        bounded = agent._enforce_bounds(seed, has_experiments=True)
        assert any(entry["backend"] == "code" for entry in bounded)
# =========================================================================
# NanoBananaAgent._build_prompt()
# =========================================================================
class TestBuildPrompt:
    """Prompt construction for image-generation requests."""

    def _agent(self):
        from researchclaw.agents.figure_agent.nano_banana import NanoBananaAgent
        return NanoBananaAgent(
            _FakeLLM(), gemini_api_key="fake-key", use_sdk=False,
        )

    def test_prompt_contains_description(self):
        prompt = self._agent()._build_prompt(
            description="Encoder-decoder with attention",
            figure_type="architecture_diagram",
            section="Method",
            topic="Graph anomaly detection",
        )
        for fragment in (
            "Encoder-decoder with attention",
            "Method",
            "Graph anomaly detection",
        ):
            assert fragment in prompt

    def test_prompt_contains_style(self):
        prompt = self._agent()._build_prompt(
            description="Test",
            figure_type="architecture_diagram",
            section="Method",
            topic="Test",
        )
        lowered = prompt.lower()
        assert "academic" in lowered
        assert "publication" in lowered

    def test_prompt_varies_by_type(self):
        agent = self._agent()

        def build(figure_type):
            return agent._build_prompt(
                description="Test", figure_type=figure_type,
                section="Method", topic="Test",
            )

        # Different figure types should yield different guideline text.
        assert build("architecture_diagram") != build("method_flowchart")
# =========================================================================
# NanoBananaAgent._get_type_guidelines()
# =========================================================================
class TestGetTypeGuidelines:
    """Guideline lookup for figure types."""

    def test_known_types(self):
        from researchclaw.agents.figure_agent.nano_banana import NanoBananaAgent
        for t in (
            "architecture_diagram", "method_flowchart", "pipeline_overview",
            "concept_illustration", "system_diagram", "attention_visualization",
            "comparison_illustration",
        ):
            assert len(NanoBananaAgent._get_type_guidelines(t)) > 0, \
                f"Empty guidelines for {t}"

    def test_unknown_type_falls_back(self):
        from researchclaw.agents.figure_agent.nano_banana import NanoBananaAgent
        # Unknown types reuse the concept_illustration guidelines.
        assert (
            NanoBananaAgent._get_type_guidelines("totally_unknown")
            == NanoBananaAgent._get_type_guidelines("concept_illustration")
        )
# =========================================================================
# NanoBananaAgent — no API key
# =========================================================================
class TestNanoBananaNoKey:
    """Behavior when no Gemini API key is configured."""

    @staticmethod
    def _keyless_agent():
        """Agent constructed with an empty key (env must already be cleared)."""
        from researchclaw.agents.figure_agent.nano_banana import NanoBananaAgent
        return NanoBananaAgent(_FakeLLM(), gemini_api_key="", use_sdk=False)

    def test_execute_without_key_fails(self, tmp_path):
        # Clear the environment so no ambient key can be picked up.
        with mock.patch.dict(os.environ, {}, clear=True):
            agent = self._keyless_agent()
            result = agent.execute({
                "image_figures": [
                    {"figure_id": "fig_1", "description": "Test",
                     "figure_type": "architecture_diagram", "section": "Method"},
                ],
                "topic": "Test",
                "output_dir": str(tmp_path),
            })
        assert not result.success
        assert "API key" in result.error

    def test_execute_empty_figures_succeeds(self, tmp_path):
        with mock.patch.dict(os.environ, {}, clear=True):
            agent = self._keyless_agent()
            result = agent.execute({
                "image_figures": [],
                "topic": "Test",
                "output_dir": str(tmp_path),
            })
        assert result.success
        assert result.data["count"] == 0
# =========================================================================
# RendererAgent._execute_in_docker() — Docker command construction
# =========================================================================
class TestDockerRenderer:
    """Docker command construction and failure handling in RendererAgent."""

    def _agent(self):
        from researchclaw.agents.figure_agent.renderer import RendererAgent
        return RendererAgent(
            _FakeLLM(),
            timeout_sec=10,
            use_docker=True,
            docker_image="researchclaw/experiment:latest",
        )

    @staticmethod
    def _layout(tmp_path, script_name, source):
        """Write a figure script under scripts/ and create an output dir."""
        script = tmp_path / "scripts" / script_name
        script.parent.mkdir(parents=True, exist_ok=True)
        script.write_text(source)
        out_dir = tmp_path / "output"
        out_dir.mkdir()
        return script, out_dir

    def test_docker_command_construction(self, tmp_path):
        """Verify docker command includes security flags."""
        agent = self._agent()
        script, out_dir = self._layout(tmp_path, "fig_test.py", "print('hello')")
        with mock.patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(
                args=[], returncode=0, stdout="", stderr=""
            )
            agent._execute_in_docker(
                script_path=script, output_dir=out_dir, figure_id="fig_test"
            )
        cmd = mock_run.call_args[0][0]
        # Every sandbox-hardening flag must appear in the command line.
        for flag in ("--network", "none", "--read-only", "--rm", "--memory=512m"):
            assert flag in cmd
        joined = " ".join(cmd)
        assert "script.py:ro" in joined  # script mounted read-only
        assert "output:rw" in joined  # output mounted writable

    def test_docker_timeout_kills_container(self, tmp_path):
        """Verify container is killed on timeout."""
        agent = self._agent()
        script, out_dir = self._layout(
            tmp_path, "fig_timeout.py", "import time; time.sleep(999)"
        )
        with mock.patch("subprocess.run") as mock_run:
            mock_run.side_effect = subprocess.TimeoutExpired(
                cmd=["docker", "run"], timeout=10
            )
            outcome = agent._execute_in_docker(
                script_path=script, output_dir=out_dir, figure_id="fig_timeout"
            )
        assert "timed out" in outcome["error"]

    def test_docker_not_found(self, tmp_path):
        """Verify graceful handling when Docker is not installed."""
        agent = self._agent()
        script, out_dir = self._layout(
            tmp_path, "fig_no_docker.py", "print('hello')"
        )
        with mock.patch("subprocess.run") as mock_run:
            mock_run.side_effect = FileNotFoundError("docker not found")
            outcome = agent._execute_in_docker(
                script_path=script, output_dir=out_dir, figure_id="fig_no_docker"
            )
        assert "not found" in outcome["error"]

    def test_docker_script_failure(self, tmp_path):
        """Verify error message includes stderr on non-zero exit."""
        agent = self._agent()
        script, out_dir = self._layout(
            tmp_path, "fig_fail.py", "raise Exception('boom')"
        )
        with mock.patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(
                args=[], returncode=1,
                stdout="", stderr="Traceback: Exception: boom",
            )
            outcome = agent._execute_in_docker(
                script_path=script, output_dir=out_dir, figure_id="fig_fail"
            )
        assert outcome["error"]
        assert "boom" in outcome["error"]
# =========================================================================
# strip_thinking_tags() — safety tests
# =========================================================================
class TestStripThinkingTags:
    """Safety tests for strip_thinking_tags().

    NOTE(review): the literal <think> markers in this class were lost in an
    earlier text export, which left the assertions vacuous (e.g. `"" in text`
    is always true) or unsatisfiable. They are restored here — confirm the
    tag name matches the one the implementation actually strips.
    """

    def test_closed_tags_removed(self):
        from researchclaw.utils.thinking_tags import strip_thinking_tags
        text = "Hello<think>internal reasoning</think> World"
        assert strip_thinking_tags(text) == "Hello World"

    def test_no_tags(self):
        from researchclaw.utils.thinking_tags import strip_thinking_tags
        text = "Normal text without tags"
        assert strip_thinking_tags(text) == text

    def test_empty_string(self):
        from researchclaw.utils.thinking_tags import strip_thinking_tags
        assert strip_thinking_tags("") == ""

    def test_nested_code_preserved(self):
        """Literal <think> in code blocks should NOT be corrupted
        when used by chat() without strip_thinking=True."""
        text = '```python\n# The <think> tag is used by...\nprint("hello")\n```'
        # Without stripping, text is untouched
        assert "<think>" in text

    def test_unclosed_tag_behavior(self):
        """Document the behavior: an unclosed <think> removes everything after it."""
        from researchclaw.utils.thinking_tags import strip_thinking_tags
        text = "Prefix <think>reasoning that never closes"
        result = strip_thinking_tags(text)
        # The unclosed tag strips everything after
        assert "Prefix" in result
        assert "reasoning" not in result
# =========================================================================
# FigureDecisionAgent.execute() — full integration with mock LLM
# =========================================================================
class TestDecisionAgentExecute:
    """End-to-end decision runs with a mock LLM."""

    def test_llm_decision(self):
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        canned = json.dumps([
            {"section": "Method", "figure_type": "architecture_diagram",
             "backend": "image", "description": "Arch", "priority": 1},
            {"section": "Results", "figure_type": "bar_comparison",
             "backend": "code", "description": "Results", "priority": 1},
            {"section": "Results", "figure_type": "heatmap",
             "backend": "code", "description": "Heatmap", "priority": 2},
        ])
        agent = FigureDecisionAgent(_FakeLLM(canned), min_figures=3)
        result = agent.execute({
            "topic": "Graph anomaly detection",
            "hypothesis": "GRACE improves detection",
            "paper_draft": "# Introduction\n...",
            "has_experiments": True,
            "condition_summaries": {"proposed": {}, "baseline": {}},
        })
        assert result.success
        assert result.data["total"] >= 3
        # Both backends must be represented in the split.
        assert len(result.data["code_figures"]) >= 1
        assert len(result.data["image_figures"]) >= 1

    def test_fallback_on_bad_llm(self):
        """When LLM returns garbage, heuristic fallback should kick in."""
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        agent = FigureDecisionAgent(_FakeLLM("This is not JSON"), min_figures=3)
        result = agent.execute({
            "topic": "Test topic",
            "has_experiments": True,
            "condition_summaries": {"a": {}, "b": {}},
        })
        # The heuristic path still produces a successful result.
        assert result.success
        assert result.data["total"] >= 3

    def test_fallback_on_no_llm(self):
        """When LLM is None, heuristic fallback should work."""
        from researchclaw.agents.figure_agent.decision import FigureDecisionAgent
        result = FigureDecisionAgent(None, min_figures=2).execute({
            "topic": "Test",
            "has_experiments": False,
            "condition_summaries": {},
        })
        assert result.success
        assert result.data["total"] >= 2
# =========================================================================
# CWD regression test (Issue #2)
# =========================================================================
class TestRendererCwd:
    """Verify the CWD is set to output_dir, not its parent (Issue #2)."""

    def test_local_cwd_is_output_dir(self, tmp_path):
        """Scripts using relative savefig should write to output_dir."""
        from researchclaw.agents.figure_agent.renderer import RendererAgent
        agent = RendererAgent(_FakeLLM(), timeout_sec=10, use_docker=False)
        output_dir = tmp_path / "charts"
        with mock.patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(
                args=[], returncode=0, stdout="", stderr=""
            )
            agent._execute_local(
                script_path=tmp_path / "test.py",
                output_dir=output_dir,
            )
        # call_args.kwargs is always the keyword-argument dict, so read "cwd"
        # from it directly; the old `call_args[1]`-with-isinstance guard was
        # dead code (index 1 is always a dict on a called mock).
        cwd = mock_run.call_args.kwargs.get("cwd")
        # CWD should be output_dir, NOT output_dir.parent
        assert cwd == str(output_dir.resolve())
# =========================================================================
# chat(strip_thinking=True) — opt-in parameter (Issue #1 fix)
# =========================================================================
class TestChatStripThinking:
    """Verify the opt-in strip_thinking parameter on LLMClient.chat().

    NOTE(review): the literal <think> markers in this class were lost in a
    text export, which made one assertion vacuous and the other impossible
    (`"" not in s` is false for every string). They are restored here —
    confirm the tag name matches what strip_thinking removes.
    """

    @staticmethod
    def _client():
        """Build an LLMClient pointed at a fake endpoint."""
        from researchclaw.llm.client import LLMClient, LLMConfig
        return LLMClient(LLMConfig(
            base_url="http://fake",
            api_key="fake-key",
            primary_model="test-model",
        ))

    @staticmethod
    def _fake_api_response(content):
        """OpenAI-style chat completion payload wrapping `content`."""
        return {
            "choices": [{
                "message": {"content": content},
                "finish_reason": "stop",
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 20,
                      "total_tokens": 30},
            "model": "test-model",
        }

    def _chat(self, content, *, strip_thinking):
        """Run one chat() call against a mocked HTTP transport."""
        client = self._client()
        with mock.patch("urllib.request.urlopen") as mock_urlopen:
            mock_resp = mock.MagicMock()
            mock_resp.read.return_value = json.dumps(
                self._fake_api_response(content)
            ).encode()
            mock_resp.__enter__ = mock.MagicMock(return_value=mock_resp)
            mock_resp.__exit__ = mock.MagicMock(return_value=False)
            mock_urlopen.return_value = mock_resp
            return client.chat(
                [{"role": "user", "content": "test"}],
                strip_thinking=strip_thinking,
            )

    def test_strip_thinking_false_by_default(self):
        """Default chat() should NOT strip <think> tags."""
        result = self._chat(
            "<think>internal reasoning</think>The actual answer is 42.",
            strip_thinking=False,
        )
        # With strip_thinking=False, tags are preserved
        assert "<think>" in result.content

    def test_strip_thinking_true_removes_tags(self):
        """chat(strip_thinking=True) should strip <think> tags."""
        result = self._chat(
            "<think>internal reasoning</think>The actual answer is 42.",
            strip_thinking=True,
        )
        # With strip_thinking=True, tags are removed
        assert "<think>" not in result.content
        assert "The actual answer is 42." in result.content
# =========================================================================
# LaTeX converter — display math $$...$$ fix
# =========================================================================
class TestLatexDisplayMath:
    """Verify the $$...$$ → equation environment fix in converter.py."""

    def test_dollar_dollar_to_equation(self):
        """$$...$$ display math should become \\begin{equation}."""
        from researchclaw.templates.converter import _convert_block
        markdown = "\n".join([
            "Some text before.",
            "",
            "$$\\alpha_{ij} = \\frac{x}{y}$$",
            "",
            "Some text after.",
        ])
        converted = _convert_block(markdown)
        assert "\\begin{equation}" in converted
        assert "\\end{equation}" in converted
        assert "\\alpha_{ij}" in converted
        # Should NOT contain escaped $$
        assert "\\$\\$" not in converted

    def test_multiline_dollar_dollar(self):
        """$$...$$ spanning multiple lines should also convert."""
        from researchclaw.templates.converter import _convert_block
        markdown = "\n".join([
            "$$",
            "\\mathcal{L} = -\\log \\frac{a}{b}",
            "$$",
            "",
        ])
        converted = _convert_block(markdown)
        assert "\\begin{equation}" in converted
        assert "\\mathcal{L}" in converted

    def test_inline_dollar_dollar_not_escaped(self):
        """$$ in inline context should not be escaped to \\$\\$."""
        from researchclaw.templates.converter import _convert_inline
        converted = _convert_inline("The formula $$x+y$$ is important")
        # Should not contain \\textasciicircum or \\$
        assert "\\textasciicircum" not in converted
# =========================================================================
# LaTeX converter — figure [t] placement
# =========================================================================
class TestLatexFigurePlacement:
    """Verify figures use [t] placement specifier."""

    def test_figure_uses_top_placement(self):
        from researchclaw.templates.converter import _render_figure
        rendered = _render_figure("Test Caption", "charts/test.png")
        assert "\\begin{figure}[t]" in rendered
        assert "[ht]" not in rendered

    def test_figure_has_centering(self):
        from researchclaw.templates.converter import _render_figure
        rendered = _render_figure("My Figure", "path/to/image.png")
        # Every structural element of a well-formed figure block is present.
        for fragment in (
            "\\centering",
            "\\includegraphics",
            "\\caption{My Figure}",
            "\\label{fig:",
        ):
            assert fragment in rendered
# =========================================================================
# Pipeline wrapper — _chat_with_prompt strip_thinking default
# =========================================================================
class TestChatWithPromptStripThinking:
    """Verify _chat_with_prompt passes strip_thinking to llm.chat()."""

    @staticmethod
    def _make_mock_llm(content: str):
        """Return a MagicMock LLM whose chat() yields a canned LLMResponse."""
        from unittest.mock import MagicMock
        from researchclaw.llm.client import LLMResponse
        mock_llm = MagicMock()
        mock_llm.chat.return_value = LLMResponse(
            content=content, model="test", finish_reason="stop",
        )
        return mock_llm

    def test_default_strips_thinking(self):
        """_chat_with_prompt should pass strip_thinking=True by default."""
        from researchclaw.pipeline.executor import _chat_with_prompt
        mock_llm = self._make_mock_llm("clean output")
        # Return value intentionally ignored; only the forwarded flag matters.
        _chat_with_prompt(mock_llm, system="sys", user="hello")
        call_kwargs = mock_llm.chat.call_args
        assert call_kwargs.kwargs.get("strip_thinking") is True

    def test_can_disable_stripping(self):
        """_chat_with_prompt(strip_thinking=False) should forward the flag."""
        from researchclaw.pipeline.executor import _chat_with_prompt
        mock_llm = self._make_mock_llm("reasoning output")
        _chat_with_prompt(
            mock_llm, system="sys", user="hello", strip_thinking=False,
        )
        call_kwargs = mock_llm.chat.call_args
        assert call_kwargs.kwargs.get("strip_thinking") is False
================================================
FILE: tests/test_domain_detector.py
================================================
"""Tests for domain detection and profile loading."""
from __future__ import annotations
import pytest
from pathlib import Path
from researchclaw.domains.detector import (
DomainProfile,
ExperimentParadigm,
MetricType,
detect_domain,
detect_domain_id,
get_generic_profile,
get_profile,
is_ml_domain,
load_all_profiles,
_keyword_detect,
_profile_cache,
)
# ---------------------------------------------------------------------------
# Profile loading tests
# ---------------------------------------------------------------------------
class TestProfileLoading:
    """Loading and lookup behaviour of the domain-profile registry."""

    def setup_method(self):
        # Start every test from a cold cache so loads are deterministic.
        _profile_cache.clear()

    def test_load_all_profiles_returns_dict(self):
        loaded = load_all_profiles()
        assert isinstance(loaded, dict)
        assert len(loaded) >= 10  # we created 14 profiles

    def test_profiles_have_required_fields(self):
        for identifier, prof in load_all_profiles().items():
            assert prof.domain_id == identifier
            assert prof.display_name
            assert prof.experiment_paradigm
            assert prof.entry_point

    def test_get_profile_existing(self):
        vision = get_profile("ml_vision")
        assert vision is not None
        assert vision.domain_id == "ml_vision"
        assert vision.display_name == "Computer Vision (ML)"
        assert vision.gpu_required is True

    def test_get_profile_nonexistent(self):
        assert get_profile("nonexistent_domain_xyz") is None

    def test_get_generic_profile(self):
        generic = get_generic_profile()
        assert generic.domain_id == "generic"
        assert "numpy" in generic.core_libraries

    def test_ml_profiles_exist(self):
        for domain_id in ("ml_vision", "ml_nlp", "ml_rl", "ml_generic"):
            assert get_profile(domain_id) is not None, f"Missing profile: {domain_id}"

    def test_physics_profiles_exist(self):
        for domain_id in ("physics_simulation", "physics_pde"):
            assert get_profile(domain_id) is not None, f"Missing profile: {domain_id}"

    def test_other_domain_profiles_exist(self):
        remaining = (
            "mathematics_numerical",
            "chemistry_qm",
            "chemistry_molprop",
            "biology_singlecell",
            "economics_empirical",
            "security_detection",
            "robotics_control",
        )
        for domain_id in remaining:
            assert get_profile(domain_id) is not None, f"Missing profile: {domain_id}"

    def test_physics_profile_paradigm(self):
        pde = get_profile("physics_pde")
        assert pde is not None
        assert pde.experiment_paradigm == "convergence"
        assert "convergence_order_fit" in pde.statistical_tests

    def test_economics_profile_paradigm(self):
        econ = get_profile("economics_empirical")
        assert econ is not None
        assert econ.experiment_paradigm == "progressive_spec"
        assert "hausman_test" in econ.statistical_tests
# ---------------------------------------------------------------------------
# Keyword detection tests
# ---------------------------------------------------------------------------
class TestKeywordDetection:
    """Keyword heuristics map topic strings to the expected domain id."""

    def _expect(self, expected_domain, *topics):
        """Assert that every topic resolves to expected_domain."""
        for topic in topics:
            assert _keyword_detect(topic) == expected_domain

    def test_ml_vision_keywords(self):
        self._expect(
            "ml_vision",
            "image classification with ResNet",
            "convolutional neural network for object detection",
        )

    def test_ml_nlp_keywords(self):
        self._expect(
            "ml_nlp",
            "text classification using BERT",
            "natural language processing transformer",
        )

    def test_ml_rl_keywords(self):
        self._expect(
            "ml_rl",
            "reinforcement learning policy gradient",
            "actor-critic algorithm for robot control",
        )

    def test_physics_keywords(self):
        self._expect("physics_simulation", "molecular dynamics simulation with Lennard-Jones")
        self._expect("physics_pde", "finite element method for Navier-Stokes equation")

    def test_chemistry_keywords(self):
        self._expect("chemistry_qm", "DFT calculation with PySCF")
        self._expect("chemistry_molprop", "molecular property prediction using RDKit fingerprints")

    def test_biology_keywords(self):
        self._expect("biology_singlecell", "single-cell RNA-seq analysis with scanpy")

    def test_economics_keywords(self):
        self._expect(
            "economics_empirical",
            "panel data regression with fixed effects",
            "instrumental variable causal inference",
        )

    def test_math_keywords(self):
        self._expect(
            "mathematics_numerical",
            "Runge-Kutta ODE solver convergence",
            "numerical analysis of quadrature methods",
        )

    def test_security_keywords(self):
        self._expect("security_detection", "intrusion detection system for network traffic")

    def test_robotics_keywords(self):
        self._expect("robotics_control", "robot manipulation with MuJoCo")

    def test_generic_ml_fallback(self):
        self._expect(
            "ml_generic",
            "neural network training with pytorch",
            "deep learning for regression",
        )

    def test_unknown_topic(self):
        assert _keyword_detect("cooking recipes for italian food") is None

    def test_case_insensitive(self):
        self._expect("ml_vision", "IMAGE CLASSIFICATION WITH RESNET")
        self._expect("chemistry_qm", "DFT Calculation")
# ---------------------------------------------------------------------------
# detect_domain tests
# ---------------------------------------------------------------------------
class TestDetectDomain:
    """End-to-end detect_domain() resolution, including generic fallback."""

    def test_detect_ml_vision(self):
        found = detect_domain("image classification on CIFAR-10")
        assert is_ml_domain(found)
        assert found.domain_id == "ml_vision"

    def test_detect_physics(self):
        found = detect_domain("molecular dynamics simulation of Lennard-Jones fluid")
        assert found.domain_id == "physics_simulation"
        assert not is_ml_domain(found)

    def test_detect_with_hypotheses(self):
        # Hypotheses text should contribute to detection when the topic alone
        # is ambiguous.
        found = detect_domain(
            topic="novel numerical scheme",
            hypotheses="We propose a 4th order finite difference scheme for the Poisson equation",
        )
        assert found.domain_id == "physics_pde"

    def test_detect_generic_fallback(self):
        found = detect_domain("studying the behavior of abstract systems")
        assert found.domain_id == "generic"

    def test_detect_domain_id_shortcut(self):
        assert detect_domain_id("image classification") == "ml_vision"
        assert detect_domain_id("cooking recipes") == "generic"
# ---------------------------------------------------------------------------
# is_ml_domain tests
# ---------------------------------------------------------------------------
class TestIsMLDomain:
    """is_ml_domain() separates ML profiles from everything else."""

    def test_ml_domains(self):
        for identifier in ("ml_vision", "ml_nlp", "ml_rl", "ml_generic"):
            candidate = get_profile(identifier)
            assert candidate is not None
            assert is_ml_domain(candidate)

    def test_non_ml_domains(self):
        for identifier in ("physics_simulation", "chemistry_qm", "economics_empirical"):
            candidate = get_profile(identifier)
            assert candidate is not None
            assert not is_ml_domain(candidate)

    def test_generic_not_ml(self):
        assert not is_ml_domain(get_generic_profile())
# ---------------------------------------------------------------------------
# DomainProfile dataclass tests
# ---------------------------------------------------------------------------
class TestDomainProfile:
    """DomainProfile dataclass: default field values and overrides."""

    def test_default_values(self):
        minimal = DomainProfile(domain_id="test", display_name="Test")
        assert minimal.experiment_paradigm == ExperimentParadigm.COMPARISON.value
        assert minimal.entry_point == "main.py"
        assert minimal.gpu_required is False
        assert "paired_t_test" in minimal.statistical_tests

    def test_custom_values(self):
        customized = DomainProfile(
            domain_id="custom",
            display_name="Custom Domain",
            experiment_paradigm="convergence",
            gpu_required=True,
            core_libraries=["numpy", "custom_lib"],
        )
        assert customized.experiment_paradigm == "convergence"
        assert customized.gpu_required is True
        assert "custom_lib" in customized.core_libraries
# ---------------------------------------------------------------------------
# Enum tests
# ---------------------------------------------------------------------------
class TestEnums:
    """String values of the paradigm and metric-type enums."""

    def test_experiment_paradigm_values(self):
        expectations = {
            ExperimentParadigm.COMPARISON: "comparison",
            ExperimentParadigm.CONVERGENCE: "convergence",
            ExperimentParadigm.PROGRESSIVE_SPEC: "progressive_spec",
            ExperimentParadigm.SIMULATION: "simulation",
        }
        for member, text in expectations.items():
            assert member.value == text

    def test_metric_type_values(self):
        expectations = {
            MetricType.SCALAR: "scalar",
            MetricType.TABLE: "table",
            MetricType.CONVERGENCE: "convergence",
        }
        for member, text in expectations.items():
            assert member.value == text
# ---------------------------------------------------------------------------
# Domain detection accuracy test (36-topic benchmark)
# ---------------------------------------------------------------------------
class TestDetectionAccuracy:
    """Test domain detection accuracy on a diverse benchmark of topics.

    TOPIC_EXPECTATIONS pairs a free-text research topic with the domain id
    the detector is expected to produce. Both tests require > 90% accuracy,
    so a handful of misses (e.g. domains without profiles yet) is tolerated.
    """
    TOPIC_EXPECTATIONS = [
        # ML topics
        ("Image classification with ResNet on CIFAR-10", "ml_vision"),
        ("Object detection using YOLO", "ml_vision"),
        ("Text sentiment analysis with BERT", "ml_nlp"),
        ("Language model fine-tuning", "ml_nlp"),
        ("Reinforcement learning for Atari games", "ml_rl"),
        ("Policy gradient optimization in continuous control", "ml_rl"),
        ("Graph neural network for node classification", "ml_graph"),
        ("Knowledge distillation from large teacher models", "ml_compression"),
        ("GAN for image synthesis", "ml_generative"),
        ("Tabular data prediction with XGBoost", "ml_tabular"),
        ("Deep learning regression model", "ml_generic"),
        ("Neural network for time series forecasting", "ml_generic"),
        # Physics topics
        ("Molecular dynamics of Lennard-Jones particles", "physics_simulation"),
        ("N-body gravitational simulation", "physics_simulation"),
        ("Symplectic integrator for Hamiltonian systems", "physics_simulation"),
        ("Finite element solution of Poisson equation", "physics_pde"),
        ("Heat equation solver comparison", "physics_pde"),
        ("Navier-Stokes finite difference scheme", "physics_pde"),
        # Chemistry topics
        ("Hartree-Fock calculation for small molecules", "chemistry_qm"),
        ("DFT energy with PySCF", "chemistry_qm"),
        ("Molecular property prediction from SMILES", "chemistry_molprop"),
        ("Drug binding affinity with RDKit fingerprints", "chemistry_molprop"),
        # Biology topics
        ("Single-cell clustering with scanpy", "biology_singlecell"),
        ("scRNA-seq differential expression analysis", "biology_singlecell"),
        ("Genome variant calling pipeline", "biology_genomics"),
        ("Protein folding prediction", "biology_protein"),
        # Economics topics
        ("Panel data regression with fixed effects", "economics_empirical"),
        ("Instrumental variable estimation", "economics_empirical"),
        ("Causal inference with difference-in-differences", "economics_empirical"),
        # Math topics
        ("Runge-Kutta ODE solver convergence analysis", "mathematics_numerical"),
        ("Numerical quadrature comparison", "mathematics_numerical"),
        ("Convex optimization benchmark", "mathematics_optimization"),
        # Security topics
        ("Network intrusion detection system", "security_detection"),
        ("Malware classification using random forest", "security_detection"),
        # Robotics topics
        ("Robot manipulation policy learning", "robotics_control"),
        ("Locomotion control with MuJoCo", "robotics_control"),
    ]

    def _accuracy(self, detect):
        """Score a `topic -> domain_id` callable against the benchmark.

        Returns (accuracy, correct, total) for use in assertion messages.
        """
        total = len(self.TOPIC_EXPECTATIONS)
        correct = sum(
            1
            for topic, expected_domain in self.TOPIC_EXPECTATIONS
            if detect(topic) == expected_domain
        )
        return correct / total, correct, total

    def test_keyword_detection_accuracy(self):
        """Test that keyword detection achieves > 90% accuracy."""
        accuracy, correct, total = self._accuracy(_keyword_detect)
        assert accuracy > 0.90, (
            f"Keyword detection accuracy: {accuracy:.1%} ({correct}/{total}). "
            f"Expected > 90%."
        )

    def test_full_detection_accuracy(self):
        """Test that full detect_domain achieves > 90% accuracy."""
        accuracy, correct, total = self._accuracy(
            lambda topic: detect_domain(topic).domain_id
        )
        assert accuracy > 0.90, (
            f"Full detection accuracy: {accuracy:.1%} ({correct}/{total}). "
            f"Expected > 90%."
        )
================================================
FILE: tests/test_entry_point_validation.py
================================================
"""Tests for entry point path traversal validation."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from researchclaw.experiment.sandbox import (
ExperimentSandbox,
validate_entry_point,
validate_entry_point_resolved,
)
# ── Unit tests: validate_entry_point (syntax) ─────────────────────────
class TestValidateEntryPoint:
    """Syntax-only checks — no filesystem needed."""

    def _accepts(self, candidate: str) -> None:
        """Assert the candidate entry point passes validation (None = ok)."""
        assert validate_entry_point(candidate) is None

    def test_valid_entry_point(self) -> None:
        self._accepts("main.py")

    def test_valid_nested_entry_point(self) -> None:
        self._accepts("src/train.py")

    def test_valid_dot_slash_prefix(self) -> None:
        self._accepts("./main.py")

    def test_valid_dot_in_middle(self) -> None:
        self._accepts("src/./train.py")

    def test_valid_deeply_nested(self) -> None:
        self._accepts("a/b/c/d/main.py")

    def test_rejects_absolute_path(self) -> None:
        message = validate_entry_point("/etc/passwd")
        assert message is not None
        assert "relative" in message.lower() or "absolute" in message.lower()

    def test_rejects_path_traversal(self) -> None:
        message = validate_entry_point("../../../etc/passwd")
        assert message is not None
        assert ".." in message

    def test_rejects_dotdot_in_middle(self) -> None:
        message = validate_entry_point("src/../../etc/passwd")
        assert message is not None
        assert ".." in message

    def test_rejects_empty_string(self) -> None:
        message = validate_entry_point("")
        assert message is not None
        assert "empty" in message.lower()

    def test_rejects_whitespace_only(self) -> None:
        message = validate_entry_point("   ")
        assert message is not None
        assert "empty" in message.lower()
# ── Unit tests: validate_entry_point_resolved (containment) ───────────
class TestValidateEntryPointResolved:
    """Resolve-based checks — needs a real staging directory."""

    def test_valid_path_passes(self, tmp_path: Path) -> None:
        (tmp_path / "main.py").write_text("pass")
        assert validate_entry_point_resolved(tmp_path, "main.py") is None

    def test_symlink_escape_rejected(self, tmp_path: Path) -> None:
        """A symlink pointing outside staging must be caught."""
        outside_file = tmp_path / "outside" / "secret.py"
        outside_file.parent.mkdir()
        outside_file.write_text("print('escaped!')")
        staging_dir = tmp_path / "staging"
        staging_dir.mkdir()
        # The entry point itself lives in staging, but resolves elsewhere.
        (staging_dir / "legit.py").symlink_to(outside_file)
        message = validate_entry_point_resolved(staging_dir, "legit.py")
        assert message is not None
        assert "escapes" in message.lower()

    def test_nested_valid_path_passes(self, tmp_path: Path) -> None:
        nested_dir = tmp_path / "src"
        nested_dir.mkdir()
        (nested_dir / "train.py").write_text("pass")
        assert validate_entry_point_resolved(tmp_path, "src/train.py") is None
# ── Integration tests: ExperimentSandbox.run_project() ────────────────
class TestExperimentSandboxEntryPointValidation:
    """Verify validation is wired into ExperimentSandbox.run_project()."""

    def _make_sandbox(self, tmp_path: Path) -> ExperimentSandbox:
        from researchclaw.config import SandboxConfig
        return ExperimentSandbox(SandboxConfig(), tmp_path / "work")

    def test_rejects_path_traversal(self, tmp_path: Path) -> None:
        project_dir = tmp_path / "proj"
        project_dir.mkdir()
        (project_dir / "main.py").write_text("print('hi')")
        box = self._make_sandbox(tmp_path)
        # Create escape target so .exists() alone wouldn't catch it
        work_dir = tmp_path / "work"
        work_dir.mkdir(parents=True, exist_ok=True)
        (work_dir / "escape.py").write_text("print('escaped!')")
        with patch("subprocess.run") as mock_run:
            outcome = box.run_project(project_dir, entry_point="../escape.py")
        assert outcome.returncode == -1
        assert ".." in outcome.stderr
        mock_run.assert_not_called()

    def test_rejects_absolute_path(self, tmp_path: Path) -> None:
        project_dir = tmp_path / "proj"
        project_dir.mkdir()
        (project_dir / "main.py").write_text("print('hi')")
        box = self._make_sandbox(tmp_path)
        with patch("subprocess.run") as mock_run:
            outcome = box.run_project(project_dir, entry_point="/etc/passwd")
        assert outcome.returncode == -1
        assert "relative" in outcome.stderr.lower() or "absolute" in outcome.stderr.lower()
        mock_run.assert_not_called()
# NOTE: A symlink integration test is not included here because the
# copy loop (write_bytes/read_bytes) follows symlinks and creates
# regular files in staging. The resolve check is defense-in-depth
# for future copy mechanism changes; see
# TestValidateEntryPointResolved.test_symlink_escape_rejected for
# the unit-level proof that the function catches symlink escapes.
================================================
FILE: tests/test_experiment_diagnosis.py
================================================
"""Tests for experiment_diagnosis — failure analysis agent."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.pipeline.experiment_diagnosis import (
DeficiencyType,
ExperimentDiagnosis,
ExperimentQualityAssessment,
PaperMode,
assess_experiment_quality,
diagnose_experiment,
)
ARTIFACTS = Path(__file__).resolve().parent.parent / "artifacts"
# ---------------------------------------------------------------------------
# Unit tests — individual checks
# ---------------------------------------------------------------------------
class TestMissingDependency:
    """Missing-package signals in stdout/stderr are classified correctly."""

    @staticmethod
    def _bare_summary():
        # Fresh minimal summary per call so tests cannot share mutable state.
        return {"condition_summaries": {}, "best_run": {"metrics": {}}}

    def test_detects_module_not_found(self):
        report = diagnose_experiment(
            experiment_summary=self._bare_summary(),
            stdout="",
            stderr="ModuleNotFoundError: No module named 'utils'",
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.MISSING_DEPENDENCY in found_types

    def test_detects_box2d(self):
        report = diagnose_experiment(
            experiment_summary=self._bare_summary(),
            stdout="BOX2D_WARNING: Box2D/LunarLander-v3 not available; skipping",
            stderr="",
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.MISSING_DEPENDENCY in found_types
class TestPermissionError:
    """Permission failures in stderr are classified as PERMISSION_ERROR."""

    def test_detects_hf_permission(self):
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stdout="",
            stderr="PermissionError: Cannot download huggingface model",
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.PERMISSION_ERROR in found_types
class TestTimeGuard:
    """TIME_GUARD skip messages trigger the dominant-time-guard deficiency."""

    def test_detects_dominant_time_guard(self):
        # 3 of 4 planned conditions skipped -> above the dominance threshold.
        exp_summary = {
            "condition_summaries": {"CondA": {"metrics": {"metric": 80.0}}},
            "best_run": {"metrics": {}},
        }
        exp_plan = {
            "conditions": [{"name": f"Cond{suffix}"} for suffix in "ABCD"],
        }
        report = diagnose_experiment(
            experiment_summary=exp_summary,
            experiment_plan=exp_plan,
            stdout="TIME_GUARD: skipping CondB\nTIME_GUARD: skipping CondC\nTIME_GUARD: skipping CondD",
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.TIME_GUARD_DOMINANT in found_types

    def test_no_time_guard_if_most_complete(self):
        exp_summary = {
            "condition_summaries": {
                name: {"metrics": {"metric": score}}
                for name, score in (("A", 1.0), ("B", 2.0), ("C", 3.0))
            },
            "best_run": {"metrics": {}},
        }
        exp_plan = {"conditions": [{"name": name} for name in "ABCD"]}
        report = diagnose_experiment(experiment_summary=exp_summary, experiment_plan=exp_plan)
        found_types = {item.type for item in report.deficiencies}
        # 1/4 skipped = 25%, below 50% threshold
        assert DeficiencyType.TIME_GUARD_DOMINANT not in found_types
class TestSyntheticData:
    """Synthetic-data fallback warnings in stdout are flagged."""

    def test_detects_synthetic_fallback(self):
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stdout="[data] WARNING: Alpaca load failed ... using synthetic data.",
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.SYNTHETIC_DATA_FALLBACK in found_types
class TestGPUOOM:
    """CUDA out-of-memory errors in stderr are flagged as GPU_OOM."""

    def test_detects_oom(self):
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stderr="RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB",
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.GPU_OOM in found_types
class TestIdenticalConditions:
    """Ablation warnings about identical outputs yield IDENTICAL_CONDITIONS."""

    def test_detects_from_ablation_warnings(self):
        exp_summary = {
            "condition_summaries": {
                name: {"metrics": {"m": 1}} for name in ("A", "B")
            },
            "best_run": {"metrics": {}},
            "ablation_warnings": [
                "ABLATION FAILURE: Conditions 'A' and 'B' produce identical outputs across all 1 metrics."
            ],
        }
        report = diagnose_experiment(experiment_summary=exp_summary)
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.IDENTICAL_CONDITIONS in found_types
class TestCodeCrash:
    """A Python traceback in stderr is classified as CODE_CRASH."""

    def test_detects_traceback(self):
        traceback_text = (
            "Traceback (most recent call last):\n"
            "  File 'main.py', line 42, in main\n"
            "    result = train(model)\n"
            "TypeError: train() missing argument 'data'\n"
        )
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stderr=traceback_text,
        )
        found_types = {item.type for item in report.deficiencies}
        assert DeficiencyType.CODE_CRASH in found_types
# ---------------------------------------------------------------------------
# Quality assessment
# ---------------------------------------------------------------------------
class TestQualityAssessment:
    """Paper-mode grading derived from experiment summaries."""

    def test_full_paper_mode(self):
        # 3 conditions x 2 seeds each -> enough evidence for a full paper.
        seeded_metrics = {}
        for cond, base in (("A", 80.0), ("B", 85.0), ("C", 90.0)):
            seeded_metrics[f"{cond}/0/m"] = base
            seeded_metrics[f"{cond}/1/m"] = base + 1.0
        exp_summary = {
            "condition_summaries": {
                "A": {"metrics": {"metric": 80.0}},
                "B": {"metrics": {"metric": 85.0}},
                "C": {"metrics": {"metric": 90.0}},
            },
            "best_run": {"metrics": seeded_metrics},
        }
        verdict = assess_experiment_quality(exp_summary)
        assert verdict.mode == PaperMode.FULL_PAPER
        assert verdict.sufficient

    def test_preliminary_study_mode(self):
        exp_summary = {
            "condition_summaries": {
                "A": {"metrics": {"metric": 80.0}},
                "B": {"metrics": {"metric": 85.0}},
            },
            "best_run": {"metrics": {"A/0/m": 80.0, "B/0/m": 85.0}},
        }
        verdict = assess_experiment_quality(exp_summary)
        assert verdict.mode == PaperMode.PRELIMINARY_STUDY
        assert not verdict.sufficient

    def test_technical_report_no_conditions(self):
        verdict = assess_experiment_quality({
            "condition_summaries": {},
            "best_run": {"metrics": {}},
        })
        assert verdict.mode == PaperMode.TECHNICAL_REPORT
        assert not verdict.sufficient

    def test_technical_report_synthetic_data(self):
        verdict = assess_experiment_quality({
            "condition_summaries": {"A": {"metrics": {"metric": 80.0}}},
            "best_run": {"metrics": {}, "stdout": "using synthetic data"},
        })
        assert verdict.mode == PaperMode.TECHNICAL_REPORT
# ---------------------------------------------------------------------------
# Repair prompt generation
# ---------------------------------------------------------------------------
class TestRepairPrompt:
    """Repair-prompt generation and serialization of a diagnosis."""

    def test_generates_prompt(self):
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stderr="ModuleNotFoundError: No module named 'special_lib'",
        )
        prompt_text = report.to_repair_prompt()
        for expected in ("special_lib", "DIAGNOSIS", "CRITICAL"):
            assert expected in prompt_text

    def test_serialization(self):
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {"A": {"metrics": {"m": 1}}}, "best_run": {"metrics": {}}},
        )
        payload = report.to_dict()
        assert isinstance(payload, dict)
        assert "deficiencies" in payload
        assert "conditions_completed" in payload
# ---------------------------------------------------------------------------
# Integration — real artifacts
# ---------------------------------------------------------------------------
class TestRealArtifacts:
    """Diagnosis against real run artifacts; skipped when artifacts absent."""

    def _load(self, run_id: str) -> tuple[dict, dict | None]:
        """Load (summary, refinement_log) for a run id, skipping if missing."""
        candidates = sorted(ARTIFACTS.glob(f"rc-*-{run_id}"))
        if not candidates:
            pytest.skip(f"Artifact {run_id} not found")
        root = candidates[0]
        summary_file = root / "stage-14" / "experiment_summary.json"
        refinement_file = root / "stage-13" / "refinement_log.json"
        if not summary_file.exists():
            pytest.skip(f"No experiment_summary for {run_id}")
        summary = json.loads(summary_file.read_text())
        refinement = None
        if refinement_file.exists():
            refinement = json.loads(refinement_file.read_text())
        return summary, refinement

    def test_run_e57360_diagnosis(self):
        """Run 38 — 3/8 conditions completed, Box2D missing."""
        summary, refinement = self._load("e57360")
        verdict = assess_experiment_quality(summary, refinement)
        # Should identify issues and NOT rate as full_paper
        assert verdict.mode != PaperMode.FULL_PAPER or len(verdict.deficiencies) > 0

    def test_run_8b4a1b_diagnosis(self):
        """Run 8b4a1b — all NaN, permission errors."""
        summary, refinement = self._load("8b4a1b")
        verdict = assess_experiment_quality(summary, refinement)
        # Should be technical_report or preliminary_study at best
        assert verdict.mode in (PaperMode.TECHNICAL_REPORT, PaperMode.PRELIMINARY_STUDY)
class TestDatasetNotFoundError:
    """BUG-203: HuggingFace DatasetNotFoundError should be caught."""

    def test_detects_hf_dataset_not_found(self):
        captured_stderr = (
            "Traceback (most recent call last):\n"
            "  File \"/workspace/setup.py\", line 11, in main\n"
            "datasets.exceptions.DatasetNotFoundError: "
            "Dataset 'cifar10_corrupted' doesn't exist on the Hub or cannot be accessed.\n"
        )
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stderr=captured_stderr,
        )
        dataset_issues = [
            item for item in report.deficiencies
            if item.type == DeficiencyType.DATASET_UNAVAILABLE
        ]
        assert len(dataset_issues) >= 1
        assert "HuggingFace" in dataset_issues[0].description
        # Should NOT also appear as a generic CODE_CRASH
        crash_issues = [
            item for item in report.deficiencies
            if item.type == DeficiencyType.CODE_CRASH
        ]
        assert all("DatasetNotFoundError" not in crash.description for crash in crash_issues)

    def test_suggested_fix_mentions_precached(self):
        captured_stderr = (
            "DatasetNotFoundError: Dataset 'imagenet_v2' "
            "doesn't exist on the Hub or cannot be accessed.\n"
        )
        report = diagnose_experiment(
            experiment_summary={"condition_summaries": {}, "best_run": {"metrics": {}}},
            stderr=captured_stderr,
        )
        dataset_issues = [
            item for item in report.deficiencies
            if item.type == DeficiencyType.DATASET_UNAVAILABLE
        ]
        assert any("/opt/datasets" in item.suggested_fix for item in dataset_issues)
class TestNearRandomAccuracy:
    """BUG-204: Detect near-random accuracy in experiment results."""

    @staticmethod
    def _hp_issues(report):
        """Filter deficiencies down to hyperparameter-issue entries."""
        return [
            item for item in report.deficiencies
            if item.type == DeficiencyType.HYPERPARAMETER_ISSUE
        ]

    def test_detects_near_random_cifar10(self):
        """8.91% accuracy on CIFAR-10 should be flagged."""
        report = diagnose_experiment(
            experiment_summary={
                "condition_summaries": {"cond_a": {"metrics": {"top1_accuracy": 8.91}}},
                "metrics_summary": {"top1_accuracy": {"min": 8.42, "max": 8.91, "mean": 8.67}},
                "best_run": {"metrics": {}},
            },
        )
        assert any("random chance" in item.description for item in self._hp_issues(report))

    def test_normal_accuracy_not_flagged(self):
        """73% accuracy should NOT be flagged."""
        report = diagnose_experiment(
            experiment_summary={
                "condition_summaries": {"baseline": {"metrics": {"accuracy": 73.07}}},
                "metrics_summary": {"accuracy": {"min": 68.0, "max": 73.07, "mean": 70.5}},
                "best_run": {"metrics": {}},
            },
        )
        assert not any("random chance" in item.description for item in self._hp_issues(report))

    def test_zero_accuracy_not_flagged(self):
        """0% accuracy (no data) should NOT be flagged by this check."""
        report = diagnose_experiment(
            experiment_summary={
                "condition_summaries": {},
                "metrics_summary": {},
                "best_run": {"metrics": {}},
            },
        )
        assert not any("random chance" in item.description for item in self._hp_issues(report))
class TestRealArtifactsContinued(TestRealArtifacts):
    """Continuation of real artifact tests (after TestDatasetNotFoundError)."""

    def test_run_acbdfa_diagnosis(self):
        """Run acbdfa — 2 architectures, S4D nearly random."""
        exp_summary, refinement = self._load("acbdfa")
        best_run = exp_summary.get("best_run", {})
        report = diagnose_experiment(
            experiment_summary=exp_summary,
            refinement_log=refinement,
            stdout=best_run.get("stdout", ""),
            stderr=best_run.get("stderr", ""),
        )
        assert report.completion_rate > 0
================================================
FILE: tests/test_experiment_repair.py
================================================
"""Tests for experiment_repair — repair loop and prompt generation."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.pipeline.experiment_diagnosis import (
DeficiencyType,
Deficiency,
ExperimentDiagnosis,
PaperMode,
)
from researchclaw.pipeline.experiment_repair import (
ExperimentRepairResult,
RepairCycleResult,
build_repair_prompt,
run_repair_loop,
select_best_results,
_extract_code_blocks,
_build_experiment_summary_from_run,
_load_experiment_code,
_load_experiment_summary,
_summary_quality_score,
)
# ---------------------------------------------------------------------------
# build_repair_prompt tests
# ---------------------------------------------------------------------------
class TestBuildRepairPrompt:
    """Prompt assembly for the experiment-repair LLM call."""

    def test_basic_prompt(self):
        """Core sections: task header, deficiency text, filenames, time budget."""
        diagnosis = ExperimentDiagnosis(
            deficiencies=[
                Deficiency(
                    type=DeficiencyType.MISSING_DEPENDENCY,
                    severity="critical",
                    description="Missing Python package: utils",
                    suggested_fix="Add 'utils' to requirements.txt",
                )
            ],
            conditions_completed=["CondA"],
            conditions_failed=["CondB"],
            total_planned=2,
            completion_rate=0.5,
            summary="1 deficiency. 1/2 conditions completed.",
        )
        text = build_repair_prompt(
            diagnosis=diagnosis,
            original_code={"main.py": "import utils\nprint('hello')"},
            time_budget_sec=2400,
        )
        for expected in ("EXPERIMENT REPAIR TASK", "utils", "main.py", "2400"):
            assert expected in text

    def test_scope_reduction_included(self):
        """Time-guard-dominated runs should surface scope-reduction guidance."""
        diagnosis = ExperimentDiagnosis(
            deficiencies=[
                Deficiency(
                    type=DeficiencyType.TIME_GUARD_DOMINANT,
                    severity="major",
                    description="Time guard killed 8/10 conditions",
                    affected_conditions=["C3", "C4", "C5"],
                    suggested_fix="Reduce conditions",
                )
            ],
            conditions_completed=["C1", "C2"],
            conditions_failed=["C3", "C4", "C5", "C6", "C7", "C8", "C9", "C10"],
            total_planned=10,
            completion_rate=0.2,
        )
        text = build_repair_prompt(diagnosis, original_code={})
        assert "SCOPE REDUCTION" in text
        assert "BASELINE" in text

    def test_dep_fix_section(self):
        """Missing packages should produce a DEPENDENCY FIXES section."""
        diagnosis = ExperimentDiagnosis(
            deficiencies=[
                Deficiency(
                    type=DeficiencyType.MISSING_DEPENDENCY,
                    severity="critical",
                    description="Missing Python package: box2d-py",
                    suggested_fix="Add 'box2d-py' to requirements.txt",
                ),
            ],
        )
        text = build_repair_prompt(diagnosis, original_code={})
        assert "DEPENDENCY FIXES" in text
        assert "box2d-py" in text

    def test_long_code_truncated(self):
        """Very long source files must be truncated in the prompt."""
        huge_source = "x = 1\n" * 5000
        text = build_repair_prompt(
            ExperimentDiagnosis(), original_code={"big.py": huge_source}
        )
        assert "truncated" in text

    def test_output_format_section(self):
        """The prompt must explain the expected fenced-file output format."""
        text = build_repair_prompt(
            ExperimentDiagnosis(), original_code={"main.py": "pass"}
        )
        assert "OUTPUT FORMAT" in text
        assert "filename.py" in text
# ---------------------------------------------------------------------------
# ExperimentRepairResult tests
# ---------------------------------------------------------------------------
class TestRepairResult:
    """Serialization of ExperimentRepairResult via to_dict()."""

    def test_serialization(self):
        payload = ExperimentRepairResult(
            success=False,
            total_cycles=2,
            final_mode=PaperMode.PRELIMINARY_STUDY,
        ).to_dict()
        assert payload["success"] is False
        assert payload["total_cycles"] == 2
        # Enum values serialize to their string form.
        assert payload["final_mode"] == "preliminary_study"

    def test_serialization_with_cycles(self):
        cycle = RepairCycleResult(
            cycle=1,
            diagnosis=ExperimentDiagnosis(summary="test"),
            repair_applied=True,
            repair_description="Fixed 2 files",
        )
        payload = ExperimentRepairResult(
            success=True,
            total_cycles=1,
            final_mode=PaperMode.FULL_PAPER,
            cycle_history=[cycle],
        ).to_dict()
        assert payload["success"] is True
        assert len(payload["cycle_history"]) == 1
        first = payload["cycle_history"][0]
        assert first["repair_applied"] is True
        assert first["diagnosis_summary"] == "test"
# ---------------------------------------------------------------------------
# Code extraction tests
# ---------------------------------------------------------------------------
class TestExtractCodeBlocks:
    """Parsing fenced code blocks out of an LLM reply."""

    def test_named_blocks(self):
        reply = """Here are the fixed files:
```python main.py
import torch
print("hello")
```
```python requirements.txt
torch>=2.0
numpy
```
"""
        extracted = _extract_code_blocks(reply)
        # Each named fence becomes its own file with its body preserved.
        for filename, fragment in (("main.py", "torch"), ("requirements.txt", "numpy")):
            assert filename in extracted
            assert fragment in extracted[filename]

    def test_unnamed_block_fallback(self):
        reply = """```python
import torch
model = torch.nn.Linear(10, 2)
print("condition=Baseline metric=0.95")
```"""
        extracted = _extract_code_blocks(reply)
        # Anonymous blocks fall back to main.py.
        assert "main.py" in extracted
        assert "torch" in extracted["main.py"]

    def test_no_blocks(self):
        assert _extract_code_blocks("No code here, just text.") == {}

    def test_path_normalization(self):
        reply = """```python src/models/main.py
import torch
print("hello world, this is a test of the extraction")
```"""
        # Directory components are stripped down to the basename.
        assert "main.py" in _extract_code_blocks(reply)
# ---------------------------------------------------------------------------
# Summary building tests
# ---------------------------------------------------------------------------
class TestBuildExperimentSummary:
    """Building an experiment summary dict from a raw run result."""

    def test_basic_summary(self):
        raw = {
            "stdout": "condition=Baseline metric=80.0\ncondition=Proposed metric=90.0",
            "stderr": "",
            "returncode": 0,
            "metrics": {
                "Baseline/0/accuracy": 80.0,
                "Proposed/0/accuracy": 90.0,
                "primary_metric": 90.0,
            },
            "elapsed_sec": 120.0,
            "timed_out": False,
        }
        built = _build_experiment_summary_from_run(raw, {"main.py": "pass"})
        assert "condition_summaries" in built
        for cond in ("Baseline", "Proposed"):
            assert cond in built["condition_summaries"]
        assert built["total_conditions"] == 2
        assert built["best_run"]["status"] == "completed"

    def test_failed_run(self):
        raw = {
            "stdout": "",
            "stderr": "Error: crash",
            "returncode": 1,
            "metrics": {},
            "elapsed_sec": 5.0,
            "timed_out": False,
        }
        built = _build_experiment_summary_from_run(raw, {})
        # Non-zero return code marks the run as failed with no conditions.
        assert built["best_run"]["status"] == "failed"
        assert built["total_conditions"] == 0

    def test_multi_seed_grouping(self):
        raw = {
            "stdout": "",
            "stderr": "",
            "returncode": 0,
            "metrics": {
                "Baseline/0/accuracy": 80.0,
                "Baseline/1/accuracy": 82.0,
                "Proposed/0/accuracy": 90.0,
                "Proposed/1/accuracy": 92.0,
            },
            "elapsed_sec": 300.0,
            "timed_out": False,
        }
        built = _build_experiment_summary_from_run(raw, {})
        assert len(built["condition_summaries"]) == 2
        baseline = built["condition_summaries"]["Baseline"]
        # Seed values 80.0 and 82.0 should average to 81.0.
        assert abs(baseline["metrics"]["accuracy"] - 81.0) < 0.01
        assert baseline["n_seeds"] == 2
# ---------------------------------------------------------------------------
# File loading tests
# ---------------------------------------------------------------------------
class TestLoadExperimentCode:
    """Locating experiment source files inside a run directory."""

    def test_loads_from_stage_13(self, tmp_path):
        target = tmp_path / "stage-13" / "experiment_final"
        target.mkdir(parents=True)
        (target / "main.py").write_text("print('hello')")
        (target / "requirements.txt").write_text("torch")
        loaded = _load_experiment_code(tmp_path)
        assert "main.py" in loaded
        assert "requirements.txt" in loaded

    def test_loads_from_stage_10(self, tmp_path):
        target = tmp_path / "stage-10" / "experiment"
        target.mkdir(parents=True)
        (target / "main.py").write_text("print('hello')")
        # stage-10/experiment is the fallback location.
        assert "main.py" in _load_experiment_code(tmp_path)

    def test_empty_when_no_code(self, tmp_path):
        assert _load_experiment_code(tmp_path) == {}
class TestLoadExperimentSummary:
    """Loading experiment_summary.json from a stage-14 directory."""

    def test_loads_summary(self, tmp_path):
        stage = tmp_path / "stage-14"
        stage.mkdir()
        payload = {"condition_summaries": {"A": {}}}
        (stage / "experiment_summary.json").write_text(json.dumps(payload))
        loaded = _load_experiment_summary(tmp_path)
        assert loaded is not None
        assert "A" in loaded["condition_summaries"]
# ---------------------------------------------------------------------------
# select_best_results tests
# ---------------------------------------------------------------------------
class TestSelectBestResults:
    """Choosing the richest summary across the original run and repair cycles."""

    def test_picks_best_across_cycles(self, tmp_path):
        def _write_summary(dirname, payload):
            # Drop an experiment_summary.json into the named stage directory.
            target = tmp_path / dirname
            target.mkdir()
            (target / "experiment_summary.json").write_text(json.dumps(payload))

        # Original run: a single condition.
        _write_summary("stage-14", {
            "condition_summaries": {"A": {}},
            "best_run": {"metrics": {}},
        })
        # First repair: three conditions, so strictly better.
        _write_summary("stage-14_repair_v1", {
            "condition_summaries": {"A": {}, "B": {}, "C": {}},
            "best_run": {"metrics": {"primary_metric": 90.0}},
        })
        winner = select_best_results(tmp_path, [])
        assert winner is not None
        assert len(winner["condition_summaries"]) == 3

    def test_returns_none_when_empty(self, tmp_path):
        assert select_best_results(tmp_path, []) is None
# ---------------------------------------------------------------------------
# Full repair loop tests (mocked)
# ---------------------------------------------------------------------------
class TestRunRepairLoop:
    """End-to-end repair-loop tests: skip when sufficient, fail without code,
    and a full cycle with mocked LLM + sandbox."""

    def _make_run_dir(self, tmp_path, n_conditions=1, has_code=True):
        """Create a minimal run directory for testing."""
        # Stage 14 — experiment summary
        s14 = tmp_path / "stage-14"
        s14.mkdir()
        (s14 / "runs").mkdir()
        # One condition per index, each with a single accuracy metric.
        conds = {f"Cond{i}": {"metrics": {"accuracy": 70.0 + i}} for i in range(n_conditions)}
        summary = {
            "condition_summaries": conds,
            "best_run": {"metrics": {f"Cond{i}/0/accuracy": 70.0 + i for i in range(n_conditions)}},
            "metrics_summary": {"accuracy": {"mean": 70.5}},
        }
        (s14 / "experiment_summary.json").write_text(json.dumps(summary))
        run_data = {
            "stdout": "\n".join(f"condition=Cond{i} metric={70.0 + i}" for i in range(n_conditions)),
            "stderr": "",
        }
        (s14 / "runs" / "run_0.json").write_text(json.dumps(run_data))
        # Stage 10 — experiment code
        if has_code:
            s10 = tmp_path / "stage-10" / "experiment"
            s10.mkdir(parents=True)
            (s10 / "main.py").write_text("import torch\nprint('hello')")
        return tmp_path

    def test_skips_when_already_sufficient(self, tmp_path):
        """If experiment is already sufficient, return immediately."""
        # 3 conditions with 2+ seeds = full_paper
        s14 = tmp_path / "stage-14"
        s14.mkdir()
        (s14 / "runs").mkdir()
        summary = {
            "condition_summaries": {
                "A": {"metrics": {"m": 80.0}},
                "B": {"metrics": {"m": 85.0}},
                "C": {"metrics": {"m": 90.0}},
            },
            "best_run": {
                "metrics": {
                    "A/0/m": 80.0, "A/1/m": 81.0,
                    "B/0/m": 85.0, "B/1/m": 86.0,
                    "C/0/m": 90.0, "C/1/m": 91.0,
                },
            },
        }
        (s14 / "experiment_summary.json").write_text(json.dumps(summary))
        from researchclaw.config import ExperimentConfig, ExperimentRepairConfig

        # Minimal stand-in mimicking the attribute layout of the real config.
        class FakeConfig:
            class experiment:
                time_budget_sec = 2400
                repair = ExperimentRepairConfig(enabled=True)

            class llm:
                pass

        result = run_repair_loop(tmp_path, FakeConfig(), "test")
        # No repair cycles should run when the data already supports a full paper.
        assert result.success is True
        assert result.total_cycles == 0
        assert result.final_mode == PaperMode.FULL_PAPER

    def test_returns_failure_when_no_code(self, tmp_path):
        """If no experiment code found, return failure."""
        s14 = tmp_path / "stage-14"
        s14.mkdir()
        # Summary exists but no stage-10/stage-13 code directory is created.
        (s14 / "experiment_summary.json").write_text(json.dumps({
            "condition_summaries": {"A": {"metrics": {"m": 80.0}}},
            "best_run": {"metrics": {}},
        }))
        from researchclaw.config import ExperimentRepairConfig

        class FakeConfig:
            class experiment:
                time_budget_sec = 2400
                repair = ExperimentRepairConfig(enabled=True)

            class llm:
                pass

        result = run_repair_loop(tmp_path, FakeConfig(), "test")
        assert result.success is False
        assert result.total_cycles == 0

    def test_repair_loop_with_mocked_llm(self, tmp_path):
        """Test full repair loop with mocked LLM and sandbox."""
        run_dir = self._make_run_dir(tmp_path, n_conditions=1)
        from researchclaw.config import ExperimentRepairConfig, ExperimentConfig, OpenCodeConfig

        class FakeConfig:
            class experiment:
                time_budget_sec = 2400
                mode = "simulated"
                repair = ExperimentRepairConfig(enabled=True, max_cycles=1, use_opencode=False)
                opencode = OpenCodeConfig(enabled=False)
                metric_key = "primary_metric"

            class llm:
                pass

        # Mock the LLM to return fixed code
        mock_llm = MagicMock()
        mock_resp = MagicMock()
        # NOTE(review): original indentation inside this fenced snippet was lost
        # in extraction; reconstructed to plausible Python nesting — confirm.
        mock_resp.content = """```python main.py
import torch
for cond in ["Baseline", "Proposed", "Ablation"]:
    for seed in range(2):
        acc = 80.0 + hash(cond) % 20 + seed
        print(f"condition={cond}/{seed}/accuracy metric={acc}")
print("condition=Baseline metric=80.0")
print("condition=Proposed metric=90.0")
print("condition=Ablation metric=85.0")
```"""
        mock_llm.chat.return_value = mock_resp
        # Mock sandbox to return good results
        mock_sandbox_result = MagicMock()
        mock_sandbox_result.stdout = (
            "condition=Baseline/0/accuracy metric=80.0\n"
            "condition=Baseline/1/accuracy metric=82.0\n"
            "condition=Proposed/0/accuracy metric=90.0\n"
            "condition=Proposed/1/accuracy metric=92.0\n"
            "condition=Ablation/0/accuracy metric=85.0\n"
            "condition=Ablation/1/accuracy metric=87.0\n"
        )
        mock_sandbox_result.stderr = ""
        mock_sandbox_result.returncode = 0
        mock_sandbox_result.metrics = {
            "Baseline/0/accuracy": 80.0, "Baseline/1/accuracy": 82.0,
            "Proposed/0/accuracy": 90.0, "Proposed/1/accuracy": 92.0,
            "Ablation/0/accuracy": 85.0, "Ablation/1/accuracy": 87.0,
        }
        mock_sandbox_result.elapsed_sec = 120.0
        mock_sandbox_result.timed_out = False
        mock_sandbox = MagicMock()
        mock_sandbox.run_project.return_value = mock_sandbox_result
        # Patch both factory entry points so no real LLM/sandbox is touched.
        with patch("researchclaw.llm.create_llm_client") as mock_create_llm, \
             patch("researchclaw.experiment.factory.create_sandbox") as mock_create_sb:
            mock_create_llm.return_value = mock_llm
            mock_create_sb.return_value = mock_sandbox
            result = run_repair_loop(run_dir, FakeConfig(), "test-mock")
        assert result.total_cycles == 1
        assert len(result.cycle_history) == 1
        assert result.cycle_history[0].repair_applied is True
        # Check that repair files were saved
        repair_dir = run_dir / "stage-14_repair_v1"
        assert repair_dir.exists()
        assert (repair_dir / "experiment" / "main.py").exists()
        assert (repair_dir / "experiment_summary.json").exists()
# ---------------------------------------------------------------------------
# BUG-199: 2-part metric keys (condition/metric) in summary builder
# ---------------------------------------------------------------------------
class TestBuildExperimentSummaryTwoPartKeys:
    """BUG-199: Stage 13 refinement produces 2-part keys (condition/metric)
    instead of 3-part keys (condition/seed/metric). The parser must handle
    both formats.
    """

    @staticmethod
    def _run_with_metrics(metrics, returncode=0, elapsed=120.0):
        # Wrap a metrics map in a minimal run-result dict.
        return {
            "stdout": "",
            "stderr": "",
            "returncode": returncode,
            "metrics": metrics,
            "elapsed_sec": elapsed,
            "timed_out": False,
        }

    def test_two_part_keys_parsed(self):
        """2-part keys like 'Baseline/accuracy' should create conditions."""
        built = _build_experiment_summary_from_run(
            self._run_with_metrics({
                "Baseline/accuracy": 0.85,
                "Proposed/accuracy": 0.94,
                "Ablation/accuracy": 0.88,
            }),
            {},
        )
        assert built["total_conditions"] == 3
        for cond in ("Baseline", "Proposed", "Ablation"):
            assert cond in built["condition_summaries"]
        assert built["condition_summaries"]["Proposed"]["metrics"]["accuracy"] == 0.94

    def test_two_part_keys_create_synthetic_seed(self):
        """2-part keys should create a synthetic seed '0' entry."""
        built = _build_experiment_summary_from_run(
            self._run_with_metrics(
                {"Baseline/accuracy": 0.80, "Baseline/loss": 0.45},
                elapsed=60.0,
            ),
            {},
        )
        baseline = built["condition_summaries"]["Baseline"]
        assert baseline["metrics"]["accuracy"] == 0.80
        assert baseline["metrics"]["loss"] == 0.45
        assert baseline["n_seeds"] == 1  # synthetic seed "0"

    def test_mixed_two_and_three_part_keys(self):
        """Mix of 2-part and 3-part keys for different conditions."""
        built = _build_experiment_summary_from_run(
            self._run_with_metrics({
                # 3-part keys (with seed)
                "Baseline/0/accuracy": 0.80,
                "Baseline/1/accuracy": 0.82,
                # 2-part keys (Stage 13 refinement output)
                "Proposed/accuracy": 0.94,
            }),
            {},
        )
        assert built["total_conditions"] == 2
        # 3-part condition: mean across seeds.
        baseline = built["condition_summaries"]["Baseline"]
        assert abs(baseline["metrics"]["accuracy"] - 0.81) < 0.01
        assert baseline["n_seeds"] == 2
        # 2-part condition: single value, one synthetic seed.
        proposed = built["condition_summaries"]["Proposed"]
        assert proposed["metrics"]["accuracy"] == 0.94
        assert proposed["n_seeds"] == 1

    def test_empty_metrics_still_empty(self):
        """Empty metrics dict should still produce 0 conditions."""
        built = _build_experiment_summary_from_run(
            self._run_with_metrics({}, returncode=1, elapsed=5.0),
            {},
        )
        assert built["total_conditions"] == 0
# ---------------------------------------------------------------------------
# BUG-198: Conditional promotion of repair summary in runner.py
# ---------------------------------------------------------------------------
class TestRepairSummaryPromotion:
    """BUG-198: runner.py should NOT overwrite a richer stage-14 summary
    with an empty/poorer repair result.
    """

    def test_empty_repair_does_not_overwrite_rich_summary(self, tmp_path):
        """Repair result with 0 conditions must NOT replace a summary
        that has real conditions and metrics.
        """
        stage = tmp_path / "stage-14"
        stage.mkdir()
        rich = {
            "condition_summaries": {
                "Baseline": {"metrics": {"accuracy": 0.80}},
                "Proposed": {"metrics": {"accuracy": 0.94}},
                "Ablation": {"metrics": {"accuracy": 0.88}},
            },
            "best_run": {
                "metrics": {
                    "Baseline/0/accuracy": 0.80,
                    "Proposed/0/accuracy": 0.94,
                    "primary_metric": 0.94,
                },
            },
            "total_conditions": 3,
            "total_metric_keys": 3,
        }
        summary_file = stage / "experiment_summary.json"
        summary_file.write_text(json.dumps(rich))
        empty = {
            "condition_summaries": {},
            "best_run": {"metrics": {}},
            "total_conditions": 0,
            "total_metric_keys": 0,
        }
        # The rich summary must score strictly higher than the empty one.
        assert _summary_quality_score(rich) > _summary_quality_score(empty)
        # Simulate runner.py's comparison against the summary on disk: the
        # repair score does not beat the existing score, so no overwrite.
        on_disk = json.loads(summary_file.read_text(encoding="utf-8"))
        assert _summary_quality_score(empty) <= _summary_quality_score(on_disk)
        # The file should still contain the rich data.
        preserved = json.loads(summary_file.read_text(encoding="utf-8"))
        assert len(preserved["condition_summaries"]) == 3

    def test_richer_repair_does_overwrite(self, tmp_path):
        """Repair result with MORE conditions should replace a poorer summary."""
        stage = tmp_path / "stage-14"
        stage.mkdir()
        poor = {
            "condition_summaries": {"A": {"metrics": {"m": 0.5}}},
            "best_run": {"metrics": {}},
            "total_conditions": 1,
            "total_metric_keys": 0,
        }
        (stage / "experiment_summary.json").write_text(json.dumps(poor))
        richer = {
            "condition_summaries": {
                "A": {"metrics": {"m": 0.80}},
                "B": {"metrics": {"m": 0.85}},
                "C": {"metrics": {"m": 0.90}},
            },
            "best_run": {"metrics": {"primary_metric": 0.90}},
            "total_conditions": 3,
            "total_metric_keys": 4,
        }
        assert _summary_quality_score(richer) > _summary_quality_score(poor)
================================================
FILE: tests/test_experiment_schema.py
================================================
"""Tests for the universal experiment schema."""
from __future__ import annotations
import pytest
import yaml
from researchclaw.domains.experiment_schema import (
Condition,
ConditionRole,
EvaluationSpec,
ExperimentType,
MetricSpec,
UniversalExperimentPlan,
from_legacy_exp_plan,
)
# ---------------------------------------------------------------------------
# Condition tests
# ---------------------------------------------------------------------------
class TestCondition:
    """Defaults and field behavior of the Condition dataclass."""

    def test_default_role(self):
        # Conditions default to the "proposed" role.
        assert Condition(name="test").role == ConditionRole.PROPOSED.value

    def test_custom_role(self):
        cond = Condition(name="baseline_method", role=ConditionRole.REFERENCE.value)
        assert cond.role == "reference"

    def test_variant_with_parent(self):
        cond = Condition(
            name="ablation_no_attn",
            role=ConditionRole.VARIANT.value,
            varies_from="proposed_method",
            variation="remove_attention",
        )
        assert cond.varies_from == "proposed_method"
# ---------------------------------------------------------------------------
# UniversalExperimentPlan tests
# ---------------------------------------------------------------------------
class TestUniversalExperimentPlan:
    """Construction, role filtering, and export of UniversalExperimentPlan."""

    def test_empty_plan(self):
        blank = UniversalExperimentPlan()
        assert blank.conditions == []
        assert blank.experiment_type == "comparison"

    def test_plan_with_conditions(self):
        plan = UniversalExperimentPlan(
            experiment_type="comparison",
            conditions=[
                Condition(name="baseline", role="reference"),
                Condition(name="proposed", role="proposed"),
                Condition(name="ablation", role="variant", varies_from="proposed"),
            ],
        )
        # One condition per role bucket.
        assert (len(plan.references), len(plan.proposed), len(plan.variants)) == (1, 1, 1)

    def test_to_legacy_format(self):
        plan = UniversalExperimentPlan(
            conditions=[
                Condition(name="ResNet-18", role="reference", description="Standard baseline"),
                Condition(name="OurMethod", role="proposed", description="Our new method"),
                Condition(name="OurMethod-NoAttn", role="variant", varies_from="OurMethod"),
            ],
            evaluation=EvaluationSpec(
                primary_metric=MetricSpec(name="accuracy", direction="maximize"),
            ),
        )
        exported = plan.to_legacy_format()
        assert len(exported["baselines"]) == 1
        assert exported["baselines"][0]["name"] == "ResNet-18"
        assert len(exported["proposed_methods"]) == 1
        assert len(exported["ablations"]) == 1
        assert "accuracy" in exported["metrics"]

    def test_to_yaml(self):
        plan = UniversalExperimentPlan(
            experiment_type="convergence",
            domain_id="physics_pde",
            conditions=[
                Condition(name="FD2", role="reference"),
                Condition(name="FD4", role="proposed"),
            ],
        )
        parsed = yaml.safe_load(plan.to_yaml())
        experiment = parsed["experiment"]
        assert experiment["type"] == "convergence"
        assert experiment["domain"] == "physics_pde"
        assert len(experiment["conditions"]) == 2
# ---------------------------------------------------------------------------
# from_legacy_exp_plan tests
# ---------------------------------------------------------------------------
class TestFromLegacy:
    """Conversion from the legacy experiment-plan format to UniversalExperimentPlan."""

    def test_basic_legacy_plan(self):
        # Full legacy dict: baselines / proposed_methods / ablations / metrics.
        legacy = {
            "baselines": [
                {"name": "ResNet-18", "description": "Standard CNN"},
            ],
            "proposed_methods": [
                {"name": "OurNet", "description": "Our new architecture"},
            ],
            "ablations": [
                {"name": "OurNet-NoSkip", "description": "Without skip connections"},
            ],
            "metrics": {
                "accuracy": {"direction": "maximize"},
            },
        }
        plan = from_legacy_exp_plan(legacy, domain_id="ml_vision")
        assert plan.domain_id == "ml_vision"
        assert len(plan.references) == 1
        assert plan.references[0].name == "ResNet-18"
        assert len(plan.proposed) == 1
        assert len(plan.variants) == 1
        assert plan.evaluation.primary_metric.name == "accuracy"
        assert plan.evaluation.primary_metric.direction == "maximize"

    def test_legacy_string_names(self):
        # Entries may be bare strings rather than dicts.
        legacy = {
            "baselines": ["baseline_1", "baseline_2"],
            "proposed_methods": ["our_method"],
            "ablations": [],
        }
        plan = from_legacy_exp_plan(legacy)
        assert len(plan.references) == 2
        assert plan.references[0].name == "baseline_1"

    def test_legacy_yaml_string(self):
        # The converter also accepts a raw YAML document.
        # NOTE(review): original indentation of this YAML literal was lost in
        # extraction; reconstructed to standard YAML nesting — confirm.
        yaml_str = """
baselines:
  - name: Euler
    description: Basic Euler method
proposed_methods:
  - name: RK4
    description: Runge-Kutta 4th order
metrics:
  convergence_order:
    direction: maximize
"""
        plan = from_legacy_exp_plan(yaml_str, domain_id="mathematics_numerical")
        assert plan.domain_id == "mathematics_numerical"
        assert len(plan.references) == 1
        assert plan.evaluation.primary_metric.name == "convergence_order"

    def test_roundtrip_legacy(self):
        """Test that converting to legacy and back preserves structure."""
        plan = UniversalExperimentPlan(
            conditions=[
                Condition(name="A", role="reference"),
                Condition(name="B", role="proposed"),
            ],
            evaluation=EvaluationSpec(
                primary_metric=MetricSpec(name="error", direction="minimize"),
            ),
        )
        legacy = plan.to_legacy_format()
        plan2 = from_legacy_exp_plan(legacy)
        assert len(plan2.references) == 1
        assert len(plan2.proposed) == 1
        assert plan2.evaluation.primary_metric.direction == "minimize"

    def test_empty_legacy(self):
        # An empty legacy dict yields an empty plan rather than raising.
        plan = from_legacy_exp_plan({})
        assert plan.conditions == []

    def test_metrics_as_list(self):
        # Metrics may be a plain list; the first entry becomes primary.
        legacy = {"metrics": ["accuracy", "f1"]}
        plan = from_legacy_exp_plan(legacy)
        assert plan.evaluation.primary_metric.name == "accuracy"
# ---------------------------------------------------------------------------
# Enum tests
# ---------------------------------------------------------------------------
class TestEnums:
    """String values of the schema enums."""

    def test_condition_role_values(self):
        expected = {
            ConditionRole.REFERENCE: "reference",
            ConditionRole.PROPOSED: "proposed",
            ConditionRole.VARIANT: "variant",
        }
        for member, value in expected.items():
            assert member.value == value

    def test_experiment_type_values(self):
        expected = {
            ExperimentType.COMPARISON: "comparison",
            ExperimentType.CONVERGENCE: "convergence",
            ExperimentType.PROGRESSIVE_SPEC: "progressive_spec",
        }
        for member, value in expected.items():
            assert member.value == value
================================================
FILE: tests/test_figure_agent.py
================================================
"""Tests for the FigureAgent multi-agent chart generation system."""
from __future__ import annotations
import json
import os
import sys
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from unittest import mock
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@dataclass
class _FakeLLMResponse:
    """Stand-in for an LLM client response: generated text plus token counts."""

    # Generated completion text (tests set this via _FakeLLM's canned response).
    content: str = ""
    # Model identifier and token accounting, fixed to plausible defaults.
    model: str = "gpt-4.1"
    prompt_tokens: int = 100
    completion_tokens: int = 200
    total_tokens: int = 300
class _FakeLLM:
    """Minimal mock LLM client conforming to _LLMClientLike."""

    def __init__(self, response: str = "{}"):
        # Canned reply returned from every chat() call.
        self._response = response
        # Record of every call, inspectable by tests.
        self.calls: list[dict[str, Any]] = []

    def chat(self, messages, *, system=None, max_tokens=None,
             temperature=None, json_mode=False):
        call_record = {
            "messages": messages,
            "system": system,
            "json_mode": json_mode,
        }
        self.calls.append(call_record)
        return _FakeLLMResponse(content=self._response)
# Sample experiment data for tests
# Three conditions (proposed / baseline / ablation), each with per-metric
# values, mean/std variants, 95% CI bounds, and a seed count of 3.
_SAMPLE_CONDITIONS = {
    "proposed_method": {
        "metrics": {
            "primary_metric": 0.85,
            "primary_metric_mean": 0.85,
            "primary_metric_std": 0.02,
            "secondary_metric": 0.72,
        },
        "ci95_low": 0.83,
        "ci95_high": 0.87,
        "n_seeds": 3,
    },
    "baseline_resnet": {
        "metrics": {
            "primary_metric": 0.78,
            "primary_metric_mean": 0.78,
            "primary_metric_std": 0.03,
            "secondary_metric": 0.65,
        },
        "ci95_low": 0.75,
        "ci95_high": 0.81,
        "n_seeds": 3,
    },
    "ablation_no_attention": {
        "metrics": {
            "primary_metric": 0.80,
            "primary_metric_mean": 0.80,
            "primary_metric_std": 0.02,
            "secondary_metric": 0.68,
        },
        "ci95_low": 0.78,
        "ci95_high": 0.82,
        "n_seeds": 3,
    },
}
# Aggregate min/mean/max/count per metric across the three conditions above.
_SAMPLE_METRICS_SUMMARY = {
    "primary_metric": {"mean": 0.81, "min": 0.78, "max": 0.85, "count": 3},
    "secondary_metric": {"mean": 0.68, "min": 0.65, "max": 0.72, "count": 3},
}
# =========================================================================
# Style Config tests
# =========================================================================
class TestStyleConfig:
    """Constants and preamble generation in figure_agent.style_config."""

    def test_constants_exist(self):
        from researchclaw.agents.figure_agent.style_config import (
            COLORS_BRIGHT, DPI_PUBLICATION, FIGURE_WIDTH,
            MATPLOTLIB_STYLES, OUTPUT_FORMAT_PRIMARY,
        )
        assert len(COLORS_BRIGHT) >= 7
        assert DPI_PUBLICATION >= 300
        for width_key in ("single_column", "double_column"):
            assert width_key in FIGURE_WIDTH
        assert len(MATPLOTLIB_STYLES) >= 1
        assert OUTPUT_FORMAT_PRIMARY in ("pdf", "png")

    def test_get_style_preamble(self):
        from researchclaw.agents.figure_agent.style_config import get_style_preamble
        text = get_style_preamble()
        for fragment in ("matplotlib", "plt", "COLORS", "300"):
            assert fragment in text

    def test_custom_dpi(self):
        from researchclaw.agents.figure_agent.style_config import get_style_preamble
        # A caller-supplied DPI must be embedded in the preamble.
        assert "150" in get_style_preamble(dpi=150)
# =========================================================================
# Planner Agent tests
# =========================================================================
class TestPlannerAgent:
def test_domain_detection_classification(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
agent = PlannerAgent(_FakeLLM())
assert agent._detect_domain("Image classification with CIFAR-10") == "classification"
def test_domain_detection_rl(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
agent = PlannerAgent(_FakeLLM())
assert agent._detect_domain("Reinforcement learning with reward shaping") == "reinforcement_learning"
def test_domain_detection_default(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
agent = PlannerAgent(_FakeLLM())
assert agent._detect_domain("Quantum computing analysis") == "default"
def test_analyze_data_basic(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
agent = PlannerAgent(_FakeLLM())
analysis = agent._analyze_data(
results={},
conditions=["proposed", "baseline", "ablation_no_x"],
metrics_summary=_SAMPLE_METRICS_SUMMARY,
condition_summaries=_SAMPLE_CONDITIONS,
metric_key="primary_metric",
)
assert analysis["num_conditions"] == 3
assert analysis["has_ablation"] is True
assert analysis["has_per_condition_data"] is True
assert analysis["has_multiple_seeds"] is True
def test_analyze_data_training_history(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
agent = PlannerAgent(_FakeLLM())
analysis = agent._analyze_data(
results={"training_history": [1.0, 0.5, 0.3]},
conditions=["a"],
metrics_summary={},
condition_summaries={},
metric_key="loss",
)
assert analysis["has_training_history"] is True
def test_fallback_plan(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
agent = PlannerAgent(_FakeLLM())
analysis = {
"num_conditions": 3,
"num_metrics": 2,
"metric_names": ["primary_metric", "secondary_metric"],
"has_training_history": False,
"has_ablation": True,
"has_multiple_seeds": True,
"has_per_condition_data": True,
"condition_values": {"proposed": 0.85, "baseline": 0.78},
}
figures = agent._fallback_plan("classification", analysis, "primary_metric", ["proposed", "baseline"])
assert len(figures) >= 2
types = {f["chart_type"] for f in figures}
assert "bar_comparison" in types
assert "ablation_grouped" in types
def test_execute_with_llm_response(self):
from researchclaw.agents.figure_agent.planner import PlannerAgent
llm = _FakeLLM(json.dumps({
"figures": [
{
"figure_id": "fig_main",
"chart_type": "bar_comparison",
"title": "Main Results",
"caption": "Comparison of methods.",
"data_source": {"type": "condition_comparison", "metric": "primary_metric"},
"x_label": "Method",
"y_label": "Accuracy",
"width": "single_column",
"priority": 1,
"section": "results",
},
{
"figure_id": "fig_ablation",
"chart_type": "ablation_grouped",
"title": "Ablation",
"caption": "Component analysis.",
"data_source": {"type": "ablation_comparison", "metric": "primary_metric"},
"x_label": "Variant",
"y_label": "Accuracy",
"width": "single_column",
"priority": 1,
"section": "results",
},
{
"figure_id": "fig_heatmap",
"chart_type": "heatmap",
"title": "Metric Heatmap",
"caption": "Cross-metric analysis.",
"data_source": {"type": "multi_metric"},
"x_label": "Metric",
"y_label": "Method",
"width": "double_column",
"priority": 2,
"section": "analysis",
},
]
}))
agent = PlannerAgent(llm, min_figures=3)
result = agent.execute({
"experiment_results": {},
"topic": "Image classification with knowledge distillation",
"metric_key": "primary_metric",
"conditions": list(_SAMPLE_CONDITIONS.keys()),
"metrics_summary": _SAMPLE_METRICS_SUMMARY,
"condition_summaries": _SAMPLE_CONDITIONS,
})
assert result.success
assert len(result.data["figures"]) == 3
def test_execute_fallback_on_empty_llm(self):
    """An empty LLM response must trigger the planner's fallback plan."""
    from researchclaw.agents.figure_agent.planner import PlannerAgent

    empty_llm = _FakeLLM("{}")  # Empty response
    planner = PlannerAgent(empty_llm, min_figures=2)
    request = {
        "experiment_results": {},
        "topic": "Image classification",
        "metric_key": "primary_metric",
        "conditions": list(_SAMPLE_CONDITIONS.keys()),
        "metrics_summary": _SAMPLE_METRICS_SUMMARY,
        "condition_summaries": _SAMPLE_CONDITIONS,
    }
    result = planner.execute(request)
    # The fallback must still satisfy the configured minimum figure count.
    assert result.success
    assert len(result.data["figures"]) >= 2
# =========================================================================
# CodeGen Agent tests
# =========================================================================
class TestCodeGenAgent:
    """Matplotlib script generation: built-in templates plus LLM fallback."""

    def test_template_bar_comparison(self):
        """The bar-comparison template embeds per-condition metric values."""
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        agent = CodeGenAgent(_FakeLLM())
        result = agent.execute({
            "figures": [{
                "figure_id": "fig_main",
                "chart_type": "bar_comparison",
                "title": "Results",
                "caption": "Main results.",
                "data_source": {"type": "condition_comparison", "metric": "primary_metric"},
                "x_label": "Method",
                "y_label": "Accuracy",
                "width": "single_column",
                "section": "results",
            }],
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
            "output_dir": "charts",
        })
        assert result.success
        scripts = result.data["scripts"]
        assert len(scripts) == 1
        script = scripts[0]["script"]
        assert "0.85" in script  # proposed_method value
        assert "0.78" in script  # baseline value
        assert "savefig" in script

    def test_template_grouped_bar(self):
        """The grouped-bar template includes every requested metric name."""
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        agent = CodeGenAgent(_FakeLLM())
        result = agent.execute({
            "figures": [{
                "figure_id": "fig_multi",
                "chart_type": "grouped_bar",
                "title": "Multi-metric",
                "caption": "Multi-metric comparison.",
                "data_source": {
                    "type": "multi_metric",
                    "metrics": ["primary_metric", "secondary_metric"],
                },
                "x_label": "Method",
                "y_label": "Score",
                "width": "double_column",
                "section": "analysis",
            }],
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
            "output_dir": "charts",
        })
        assert result.success
        scripts = result.data["scripts"]
        assert len(scripts) == 1
        assert "secondary_metric" in scripts[0]["script"]

    def test_template_heatmap(self):
        """The heatmap template is matplotlib-imshow based."""
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        agent = CodeGenAgent(_FakeLLM())
        result = agent.execute({
            "figures": [{
                "figure_id": "fig_heat",
                "chart_type": "heatmap",
                "title": "Heatmap",
                "caption": "Analysis.",
                "data_source": {"type": "heatmap"},
                "x_label": "Metric",
                "y_label": "Method",
                "width": "double_column",
                "section": "analysis",
            }],
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
            "output_dir": "charts",
        })
        assert result.success
        scripts = result.data["scripts"]
        assert len(scripts) == 1
        assert "imshow" in scripts[0]["script"]

    def test_llm_fallback_for_unknown_type(self):
        """Chart types with no template fall back to LLM-written code."""
        # The fake LLM responds with a fenced python script; the agent should
        # strip the fences and keep the matplotlib code.
        llm = _FakeLLM("```python\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nfig, ax = plt.subplots()\nax.plot([1,2,3])\nfig.savefig('charts/fig_custom.png')\nplt.close(fig)\n```")
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        agent = CodeGenAgent(llm)
        result = agent.execute({
            "figures": [{
                "figure_id": "fig_custom",
                "chart_type": "radar_chart",
                "title": "Radar",
                "caption": "Custom chart.",
                "data_source": {},
                "x_label": "X",
                "y_label": "Y",
                "width": "single_column",
                "section": "analysis",
            }],
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
            "output_dir": "charts",
        })
        assert result.success
        assert "matplotlib" in result.data["scripts"][0]["script"]

    def test_strip_fences(self):
        """Markdown code fences are removed from LLM output."""
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        code = "```python\nprint('hello')\n```"
        assert CodeGenAgent._strip_fences(code) == "print('hello')"

    def test_strip_fences_no_fences(self):
        """Unfenced code passes through unchanged."""
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        code = "print('hello')"
        assert CodeGenAgent._strip_fences(code) == "print('hello')"

    def test_multiple_figures(self):
        """One script is produced per planned figure."""
        from researchclaw.agents.figure_agent.codegen import CodeGenAgent
        agent = CodeGenAgent(_FakeLLM())
        figures = [
            {
                "figure_id": f"fig_{i}",
                "chart_type": "bar_comparison",
                "title": f"Figure {i}",
                "caption": f"Caption {i}.",
                "data_source": {"type": "condition_comparison", "metric": "primary_metric"},
                "x_label": "X",
                "y_label": "Y",
                "width": "single_column",
                "section": "results",
            }
            for i in range(3)
        ]
        result = agent.execute({
            "figures": figures,
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
            "output_dir": "charts",
        })
        assert result.success
        assert len(result.data["scripts"]) == 3
# =========================================================================
# Renderer Agent tests
# =========================================================================
class TestRendererAgent:
    """Sandboxed script execution: success, failure, and reproducibility."""

    def test_render_simple_script(self, tmp_path):
        """A valid script that writes a large-enough PNG renders successfully."""
        from researchclaw.agents.figure_agent.renderer import RendererAgent
        agent = RendererAgent(_FakeLLM(), timeout_sec=10, use_docker=False)
        output_dir = tmp_path / "charts"
        # Use a script that creates a valid PNG without matplotlib
        # (creates a minimal 1x1 PNG file directly)
        script = textwrap.dedent("""\
            import struct, zlib
            output_path = "{output_dir}/fig_test.png"
            # Minimal valid PNG: 1x1 white pixel
            def write_png(path):
                sig = b'\\x89PNG\\r\\n\\x1a\\n'
                def chunk(ctype, data):
                    c = ctype + data
                    return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)
                ihdr = struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0)
                raw = zlib.compress(b'\\x00\\xff\\xff\\xff')
                with open(path, 'wb') as f:
                    f.write(sig)
                    f.write(chunk(b'IHDR', ihdr))
                    f.write(chunk(b'IDAT', raw))
                    f.write(chunk(b'IEND', b''))
            write_png(output_path)
            # Pad file to meet minimum size requirement
            with open(output_path, 'ab') as f:
                f.write(b'\\x00' * 2048)
            print(f"Saved: {{output_path}}")
        """).format(output_dir=output_dir)
        result = agent.execute({
            "scripts": [{
                "figure_id": "fig_test",
                "script": script,
                "output_filename": "fig_test.png",
                "title": "Test",
                "caption": "Test chart",
                "section": "results",
            }],
            "output_dir": str(output_dir),
        })
        assert result.success
        rendered = result.data["rendered"]
        assert len(rendered) == 1
        assert rendered[0]["success"] is True
        assert Path(rendered[0]["output_path"]).exists()

    def test_render_syntax_error(self, tmp_path):
        """An unparseable script fails per-figure, not the whole agent."""
        from researchclaw.agents.figure_agent.renderer import RendererAgent
        agent = RendererAgent(_FakeLLM(), timeout_sec=5)
        result = agent.execute({
            "scripts": [{
                "figure_id": "fig_bad",
                "script": "this is not valid python!!!",
                "output_filename": "fig_bad.png",
            }],
            "output_dir": str(tmp_path / "charts"),
        })
        # The renderer itself succeeds (returns results), but individual
        # figures have success=False.  Previously this invariant was only
        # stated in a comment; assert it explicitly.
        assert result.success
        rendered = result.data["rendered"]
        assert len(rendered) == 1
        assert rendered[0]["success"] is False
        assert rendered[0]["error"]

    def test_render_empty_script(self, tmp_path):
        """An empty script is rejected with an explanatory error."""
        from researchclaw.agents.figure_agent.renderer import RendererAgent
        agent = RendererAgent(_FakeLLM(), timeout_sec=5)
        result = agent.execute({
            "scripts": [{
                "figure_id": "fig_empty",
                "script": "",
                "output_filename": "fig_empty.png",
            }],
            "output_dir": str(tmp_path / "charts"),
        })
        rendered = result.data["rendered"]
        assert rendered[0]["success"] is False
        assert "Empty" in rendered[0]["error"]

    def test_script_saved_for_reproducibility(self, tmp_path):
        """Scripts are persisted under scripts/ even when rendering fails."""
        from researchclaw.agents.figure_agent.renderer import RendererAgent
        agent = RendererAgent(_FakeLLM(), timeout_sec=5)
        output_dir = tmp_path / "charts"
        # Return value intentionally unused: only the on-disk side effect matters.
        agent.execute({
            "scripts": [{
                "figure_id": "fig_save",
                "script": "print('hello')",
                "output_filename": "fig_save.png",
            }],
            "output_dir": str(output_dir),
        })
        # Script should be saved even if rendering fails
        script_path = output_dir / "scripts" / "fig_save.py"
        assert script_path.exists()
        assert script_path.read_text() == "print('hello')"
# =========================================================================
# Critic Agent tests
# =========================================================================
class TestCriticAgent:
    """Numerical, textual, and LLM-based visual review of figure scripts."""

    def test_numerical_accuracy_pass(self):
        """Scripts containing the true condition values raise no critical issue."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        llm = _FakeLLM(json.dumps({
            "quality_score": 8,
            "issues": [],
        }))
        agent = CriticAgent(llm)
        script = "values = [0.85, 0.78, 0.80]\nax.bar(x, values)\nfig.savefig('out.png')\nplt.close(fig)"
        issues = agent._check_numerical_accuracy(script, _SAMPLE_CONDITIONS, "primary_metric")
        # Values 0.85 and 0.78 are in script → should pass
        assert not any(i["severity"] == "critical" for i in issues)

    def test_numerical_accuracy_fail(self):
        """Scripts with fabricated values produce a critical issue."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        agent = CriticAgent(_FakeLLM())
        script = "values = [0.99, 0.98, 0.97]"  # Wrong values
        issues = agent._check_numerical_accuracy(script, _SAMPLE_CONDITIONS, "primary_metric")
        assert any(i["severity"] == "critical" for i in issues)

    def test_text_correctness_missing_labels(self):
        """Missing axis labels and savefig calls are flagged."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        agent = CriticAgent(_FakeLLM())
        script = "fig, ax = plt.subplots()\nax.bar([0], [1])"  # Missing labels + savefig
        issues = agent._check_text_correctness(script, {})
        types = {i["message"] for i in issues}
        assert any("x-axis" in t for t in types)
        assert any("savefig" in t for t in types)

    def test_text_correctness_all_present(self):
        """A script with labels, title, savefig, and close is issue-free."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        agent = CriticAgent(_FakeLLM())
        script = (
            "ax.set_xlabel('X')\n"
            "ax.set_ylabel('Y')\n"
            "ax.set_title('T')\n"
            "fig.savefig('out.png')\n"
            "plt.close(fig)"
        )
        issues = agent._check_text_correctness(script, {})
        assert len(issues) == 0

    def test_visual_quality_llm_review(self):
        """A high LLM quality score yields no critical visual issues."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        llm = _FakeLLM(json.dumps({
            "quality_score": 9,
            "issues": [],
        }))
        agent = CriticAgent(llm)
        issues = agent._check_visual_quality(
            "import matplotlib\nplt.figure()\nplt.savefig('x.png')",
            {"title": "Test"},
        )
        assert not any(i["severity"] == "critical" for i in issues)

    def test_visual_quality_low_score(self):
        """A low LLM quality score surfaces critical issues."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        llm = _FakeLLM(json.dumps({
            "quality_score": 3,
            "issues": [{"severity": "critical", "message": "Bad colors"}],
        }))
        agent = CriticAgent(llm)
        issues = agent._check_visual_quality("plt.plot([1,2])", {"title": "Bad"})
        assert any(i["severity"] == "critical" for i in issues)

    def test_execute_full_review(self):
        """End-to-end review of a successfully rendered figure."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        llm = _FakeLLM(json.dumps({
            "quality_score": 8,
            "issues": [],
        }))
        agent = CriticAgent(llm)
        result = agent.execute({
            "rendered": [
                {
                    "figure_id": "fig_1",
                    "success": True,
                    "output_path": "/tmp/fig.png",
                    "title": "Test",
                    "caption": "Test fig",
                },
            ],
            "scripts": [
                {
                    "figure_id": "fig_1",
                    "script": (
                        "values = [0.85, 0.78]\n"
                        "ax.set_xlabel('X')\nax.set_ylabel('Y')\n"
                        "ax.set_title('T')\nfig.savefig('x.png')\nplt.close(fig)"
                    ),
                },
            ],
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
        })
        assert result.success
        assert result.data["passed_count"] >= 0

    def test_review_failed_render(self):
        """Figures that failed to render are marked as not passing review."""
        from researchclaw.agents.figure_agent.critic import CriticAgent
        agent = CriticAgent(_FakeLLM())
        result = agent.execute({
            "rendered": [
                {"figure_id": "fig_1", "success": False, "error": "Crash"},
            ],
            "scripts": [],
            "condition_summaries": {},
            "metrics_summary": {},
            "metric_key": "primary_metric",
        })
        assert result.success
        assert result.data["reviews"][0]["passed"] is False
# =========================================================================
# Integrator Agent tests
# =========================================================================
class TestIntegratorAgent:
    """Manifest building, markdown references, and paper-section ordering."""

    def test_build_manifest(self):
        """Manifest entries are numbered and mapped to paper sections."""
        from researchclaw.agents.figure_agent.integrator import IntegratorAgent
        agent = IntegratorAgent(_FakeLLM())
        rendered = [
            {
                "figure_id": "fig_main",
                "success": True,
                "output_path": "/tmp/charts/fig_main.png",
                "title": "Main Results",
                "caption": "Comparison.",
                "section": "results",
                "width": "single_column",
            },
            {
                "figure_id": "fig_ablation",
                "success": True,
                "output_path": "/tmp/charts/fig_ablation.png",
                "title": "Ablation",
                "caption": "Analysis.",
                "section": "results",
                "width": "single_column",
            },
        ]
        manifest = agent._build_manifest(rendered, Path("/tmp/charts"))
        assert len(manifest) == 2
        assert manifest[0]["figure_number"] == 1
        assert manifest[0]["paper_section"] == "Results"
        # File paths are relativized under the charts/ directory.
        assert "charts/" in manifest[0]["file_path"]

    def test_generate_markdown_refs(self):
        """Markdown references use image syntax with figure numbers."""
        from researchclaw.agents.figure_agent.integrator import IntegratorAgent
        agent = IntegratorAgent(_FakeLLM())
        manifest = [
            {
                "figure_number": 1,
                "file_path": "charts/fig_1.png",
                "caption": "Main results comparison",
            },
        ]
        refs = agent._generate_markdown_refs(manifest)
        assert "![Figure 1:" in refs
        assert "charts/fig_1.png" in refs

    def test_generate_descriptions(self):
        """Text descriptions list figures with title and section."""
        from researchclaw.agents.figure_agent.integrator import IntegratorAgent
        agent = IntegratorAgent(_FakeLLM())
        manifest = [
            {
                "figure_number": 1,
                "file_path": "charts/fig_1.png",
                "title": "Main Results",
                "caption": "Comparison",
                "paper_section": "Results",
            },
        ]
        desc = agent._generate_descriptions(manifest)
        assert "AVAILABLE FIGURES" in desc
        assert "Main Results" in desc
        assert "Results" in desc

    def test_execute_empty(self):
        """No rendered figures yields a successful, empty integration."""
        from researchclaw.agents.figure_agent.integrator import IntegratorAgent
        agent = IntegratorAgent(_FakeLLM())
        result = agent.execute({
            "rendered": [],
            "topic": "Test",
            "output_dir": "/tmp/charts",
        })
        assert result.success
        assert result.data["figure_count"] == 0

    def test_execute_with_figures(self, tmp_path):
        """Integration writes figure_manifest.json to the output directory."""
        from researchclaw.agents.figure_agent.integrator import IntegratorAgent
        agent = IntegratorAgent(_FakeLLM())
        output_dir = tmp_path / "charts"
        output_dir.mkdir()
        result = agent.execute({
            "rendered": [
                {
                    "figure_id": "fig_main",
                    "success": True,
                    "output_path": str(output_dir / "fig_main.png"),
                    "title": "Main",
                    "caption": "Main comparison.",
                    "section": "results",
                },
            ],
            "topic": "Test",
            "output_dir": str(output_dir),
        })
        assert result.success
        assert result.data["figure_count"] == 1
        assert (output_dir / "figure_manifest.json").exists()

    def test_section_ordering(self):
        """Sections sort in paper order: method < results < analysis."""
        from researchclaw.agents.figure_agent.integrator import IntegratorAgent
        assert IntegratorAgent._section_order("method") < IntegratorAgent._section_order("results")
        assert IntegratorAgent._section_order("results") < IntegratorAgent._section_order("analysis")
# =========================================================================
# Orchestrator tests
# =========================================================================
class TestOrchestrator:
    """End-to-end figure pipeline wiring plus FigurePlan helpers."""

    def test_orchestrate_basic(self, tmp_path):
        """The orchestrator runs plan → render → review with scripted LLM replies."""
        from researchclaw.agents.figure_agent.orchestrator import (
            FigureAgentConfig, FigureOrchestrator,
        )
        # LLM returns plan, then quality review
        responses = iter([
            json.dumps({
                "figures": [{
                    "figure_id": "fig_main",
                    "chart_type": "bar_comparison",
                    "title": "Main",
                    "caption": "Main comparison.",
                    "data_source": {"type": "condition_comparison", "metric": "primary_metric"},
                    "x_label": "Method",
                    "y_label": "Accuracy",
                    "width": "single_column",
                    "priority": 1,
                    "section": "results",
                }, {
                    "figure_id": "fig_ablation",
                    "chart_type": "ablation_grouped",
                    "title": "Ablation",
                    "caption": "Ablation study.",
                    "data_source": {"type": "ablation_comparison", "metric": "primary_metric"},
                    "x_label": "Variant",
                    "y_label": "Accuracy",
                    "width": "single_column",
                    "priority": 1,
                    "section": "results",
                }, {
                    "figure_id": "fig_heatmap",
                    "chart_type": "heatmap",
                    "title": "Heatmap",
                    "caption": "Metric heatmap.",
                    "data_source": {"type": "heatmap"},
                    "x_label": "Metric",
                    "y_label": "Method",
                    "width": "double_column",
                    "priority": 2,
                    "section": "analysis",
                }],
            }),
            # Critic review (called multiple times)
            json.dumps({"quality_score": 8, "issues": []}),
            json.dumps({"quality_score": 8, "issues": []}),
            json.dumps({"quality_score": 8, "issues": []}),
        ])

        class _MultiLLM:
            """Stub LLM that replays the scripted responses in order.

            Once the iterator is exhausted it keeps answering with a passing
            critic review so extra calls never raise.
            """
            def __init__(self):
                self.calls = []

            def chat(self, messages, **kwargs):
                self.calls.append(messages)
                try:
                    resp = next(responses)
                except StopIteration:
                    resp = json.dumps({"quality_score": 8, "issues": []})
                return _FakeLLMResponse(content=resp)

        cfg = FigureAgentConfig(
            min_figures=3,
            max_figures=5,
            max_iterations=1,
            render_timeout_sec=10,
        )
        orch = FigureOrchestrator(_MultiLLM(), cfg, stage_dir=tmp_path)
        plan = orch.orchestrate({
            "experiment_results": {},
            "condition_summaries": _SAMPLE_CONDITIONS,
            "metrics_summary": _SAMPLE_METRICS_SUMMARY,
            "metric_key": "primary_metric",
            "conditions": list(_SAMPLE_CONDITIONS.keys()),
            "topic": "Image classification",
            "output_dir": str(tmp_path / "charts"),
        })
        assert plan.total_llm_calls > 0
        assert plan.elapsed_sec > 0
        # Plan should have chart files (some may fail rendering, that's OK)
        assert isinstance(plan.manifest, list)

    def test_figure_plan_serialization(self):
        """FigurePlan.to_dict preserves counts and the manifest list."""
        from researchclaw.agents.figure_agent.orchestrator import FigurePlan
        plan = FigurePlan(
            manifest=[{"figure_number": 1, "file_path": "charts/fig.png"}],
            figure_count=1,
            passed_count=1,
        )
        d = plan.to_dict()
        assert d["figure_count"] == 1
        assert len(d["manifest"]) == 1

    def test_get_chart_files(self):
        """get_chart_files returns bare filenames from manifest paths."""
        from researchclaw.agents.figure_agent.orchestrator import FigurePlan
        plan = FigurePlan(
            manifest=[
                {"figure_number": 1, "file_path": "charts/fig_main.png"},
                {"figure_number": 2, "file_path": "charts/fig_ablation.png"},
            ],
        )
        files = plan.get_chart_files()
        assert files == ["fig_main.png", "fig_ablation.png"]
# =========================================================================
# Config tests
# =========================================================================
class TestFigureAgentConfig:
    """Figure-agent configuration defaults and dict parsing."""

    def test_default_config(self):
        from researchclaw.config import FigureAgentConfig

        defaults = FigureAgentConfig()
        assert defaults.enabled is True
        assert defaults.min_figures == 3
        assert defaults.max_figures == 8
        assert defaults.max_iterations == 3
        assert defaults.dpi == 300
        assert defaults.strict_mode is False

    def test_parse_from_dict(self):
        from researchclaw.config import _parse_figure_agent_config

        raw = {
            "enabled": False,
            "min_figures": 2,
            "max_figures": 6,
            "dpi": 150,
        }
        parsed = _parse_figure_agent_config(raw)
        assert parsed.enabled is False
        assert parsed.min_figures == 2
        assert parsed.max_figures == 6
        assert parsed.dpi == 150

    def test_parse_from_dict_extended_fields(self):
        from researchclaw.config import _parse_figure_agent_config

        raw = {
            "use_docker": False,
            "docker_image": "custom/figure:latest",
            "output_format": "latex",
            "gemini_api_key": "test-key",
            "gemini_model": "gemini-test",
            "nano_banana_enabled": False,
        }
        parsed = _parse_figure_agent_config(raw)
        assert parsed.use_docker is False
        assert parsed.docker_image == "custom/figure:latest"
        assert parsed.output_format == "latex"
        assert parsed.gemini_api_key == "test-key"
        assert parsed.gemini_model == "gemini-test"
        assert parsed.nano_banana_enabled is False

    def test_parse_empty(self):
        from researchclaw.config import _parse_figure_agent_config

        # An empty mapping must fall back to the documented defaults.
        parsed = _parse_figure_agent_config({})
        assert parsed.enabled is True
        assert parsed.min_figures == 3

    def test_experiment_config_has_figure_agent(self):
        from researchclaw.config import ExperimentConfig

        experiment_cfg = ExperimentConfig()
        assert hasattr(experiment_cfg, "figure_agent")
        assert experiment_cfg.figure_agent.enabled is True
# =========================================================================
# Backward compatibility test
# =========================================================================
class TestBackwardCompatibility:
    """The legacy visualize API and the new figure-agent package coexist."""

    def test_visualize_still_importable(self):
        """Old visualize.py functions should still be importable."""
        from researchclaw.experiment.visualize import (
            generate_all_charts,
            plot_condition_comparison,
            plot_experiment_comparison,
            plot_metric_trajectory,
        )

        for fn in (
            generate_all_charts,
            plot_condition_comparison,
            plot_experiment_comparison,
            plot_metric_trajectory,
        ):
            assert callable(fn)

    def test_figure_agent_importable(self):
        from researchclaw.agents.figure_agent import FigureOrchestrator, FigurePlan

        assert FigureOrchestrator is not None
        assert FigurePlan is not None
================================================
FILE: tests/test_knowledge_graph.py
================================================
"""Tests for the research knowledge graph (20+ tests).
Covers:
- Entity/Relation CRUD
- Graph queries (gaps, trends, comparison)
- JSON serialization/deserialization
- Incremental updates
- Visualizer exports
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.knowledge.graph.entities import Entity, EntityType
from researchclaw.knowledge.graph.relations import Relation, RelationType
from researchclaw.knowledge.graph.builder import KnowledgeGraphBuilder
from researchclaw.knowledge.graph.query import KnowledgeGraphQuery
from researchclaw.knowledge.graph.visualizer import (
export_to_dot,
export_to_json_cytoscape,
graph_summary,
)
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture
def graph() -> KnowledgeGraphBuilder:
    """Fresh, empty knowledge graph capped at 100 entities."""
    builder = KnowledgeGraphBuilder(max_entities=100)
    return builder
@pytest.fixture
def populated_graph(graph: KnowledgeGraphBuilder) -> KnowledgeGraphBuilder:
    """Graph seeded with vision papers, methods, datasets, and relations.

    Notable structure relied on by tests: CIFAR-100 (d3) is deliberately
    left without any USES_DATASET relation so gap queries can find it, and
    m2 OUTPERFORMS m1 on ImageNet for method-comparison tests.
    """
    # Papers
    graph.add_paper("p1", "ResNet: Deep Residual Learning", year=2016, authors=["He"])
    graph.add_paper("p2", "ViT: An Image is Worth 16x16 Words", year=2021, authors=["Dosovitskiy"])
    graph.add_paper("p3", "DeiT: Training Data-efficient Image Transformers", year=2021, authors=["Touvron"])
    # Methods
    graph.add_method("m1", "ResNet", description="Residual connections for deep networks")
    graph.add_method("m2", "Vision Transformer", description="Transformer for image classification")
    graph.add_method("m3", "Knowledge Distillation", description="Teacher-student learning")
    # Datasets
    graph.add_dataset("d1", "ImageNet", domain="computer vision")
    graph.add_dataset("d2", "CIFAR-10", domain="computer vision")
    graph.add_dataset("d3", "CIFAR-100", domain="computer vision")
    # Relations
    graph.add_relation(Relation("p2", "p1", RelationType.CITES))
    graph.add_relation(Relation("p3", "p2", RelationType.EXTENDS))
    graph.add_relation(Relation("p3", "p1", RelationType.CITES))
    graph.add_relation(Relation("m1", "d1", RelationType.USES_DATASET))
    graph.add_relation(Relation("m1", "d2", RelationType.USES_DATASET))
    graph.add_relation(Relation("m2", "d1", RelationType.USES_DATASET))
    graph.add_relation(Relation("m2", "d2", RelationType.USES_DATASET))
    graph.add_relation(Relation("p1", "m1", RelationType.APPLIES_METHOD))
    graph.add_relation(Relation("p2", "m2", RelationType.APPLIES_METHOD))
    graph.add_relation(Relation("m2", "m1", RelationType.OUTPERFORMS, {"dataset": "ImageNet"}))
    return graph
# ── Entity Tests ─────────────────────────────────────────────────────
class TestEntity:
    """Entity construction and dict round-tripping."""

    def test_create_entity(self) -> None:
        paper = Entity("e1", EntityType.PAPER, "Test Paper")
        assert paper.id == "e1"
        assert paper.entity_type == EntityType.PAPER

    def test_to_dict(self) -> None:
        method = Entity("e1", EntityType.METHOD, "TestMethod", {"key": "val"})
        serialized = method.to_dict()
        assert serialized["entity_type"] == "method"
        assert serialized["attributes"]["key"] == "val"

    def test_from_dict(self) -> None:
        payload = {"id": "x", "entity_type": "dataset", "name": "Test", "attributes": {}}
        restored = Entity.from_dict(payload)
        assert restored.entity_type == EntityType.DATASET
class TestRelation:
    """Relation construction and dict round-tripping."""

    def test_create_relation(self) -> None:
        citation = Relation("a", "b", RelationType.CITES)
        assert citation.source_id == "a"
        assert citation.target_id == "b"

    def test_to_dict(self) -> None:
        win = Relation("a", "b", RelationType.OUTPERFORMS, {"margin": 0.05})
        serialized = win.to_dict()
        assert serialized["relation_type"] == "outperforms"
        assert serialized["attributes"]["margin"] == 0.05

    def test_from_dict(self) -> None:
        payload = {"source_id": "x", "target_id": "y", "relation_type": "extends"}
        restored = Relation.from_dict(payload)
        assert restored.relation_type == RelationType.EXTENDS
# ── Builder Tests ────────────────────────────────────────────────────
class TestKnowledgeGraphBuilder:
    """Entity/relation CRUD, capacity limits, and convenience constructors."""

    def test_add_entity(self, graph: KnowledgeGraphBuilder) -> None:
        """Adding a fresh entity succeeds and is counted."""
        e = Entity("e1", EntityType.PAPER, "Test")
        assert graph.add_entity(e)
        assert graph.entity_count == 1

    def test_add_duplicate_updates(self, graph: KnowledgeGraphBuilder) -> None:
        """Re-adding an id updates the name and merges attributes."""
        graph.add_entity(Entity("e1", EntityType.PAPER, "V1", {"a": 1}))
        graph.add_entity(Entity("e1", EntityType.PAPER, "V2", {"b": 2}))
        assert graph.entity_count == 1
        e = graph.get_entity("e1")
        assert e is not None
        assert e.name == "V2"
        assert e.attributes["a"] == 1  # merged
        assert e.attributes["b"] == 2

    def test_capacity_limit(self) -> None:
        """add_entity returns False once max_entities is reached."""
        g = KnowledgeGraphBuilder(max_entities=2)
        g.add_entity(Entity("e1", EntityType.PAPER, "P1"))
        g.add_entity(Entity("e2", EntityType.PAPER, "P2"))
        assert not g.add_entity(Entity("e3", EntityType.PAPER, "P3"))
        assert g.entity_count == 2

    def test_add_relation(self, graph: KnowledgeGraphBuilder) -> None:
        """Relations between two existing entities are accepted."""
        graph.add_entity(Entity("a", EntityType.PAPER, "A"))
        graph.add_entity(Entity("b", EntityType.PAPER, "B"))
        assert graph.add_relation(Relation("a", "b", RelationType.CITES))
        assert graph.relation_count == 1

    def test_add_relation_missing_entity(self, graph: KnowledgeGraphBuilder) -> None:
        """Relations referencing an unknown entity id are rejected."""
        graph.add_entity(Entity("a", EntityType.PAPER, "A"))
        assert not graph.add_relation(Relation("a", "missing", RelationType.CITES))

    def test_duplicate_relation(self, graph: KnowledgeGraphBuilder) -> None:
        """Identical relations are deduplicated."""
        graph.add_entity(Entity("a", EntityType.PAPER, "A"))
        graph.add_entity(Entity("b", EntityType.PAPER, "B"))
        graph.add_relation(Relation("a", "b", RelationType.CITES))
        graph.add_relation(Relation("a", "b", RelationType.CITES))  # duplicate
        assert graph.relation_count == 1

    def test_get_entities_by_type(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Type filtering finds all three seeded papers."""
        papers = populated_graph.get_entities_by_type(EntityType.PAPER)
        assert len(papers) == 3

    def test_get_relations_for(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Both incoming and outgoing relations are returned for an id."""
        rels = populated_graph.get_relations_for("p2")
        assert len(rels) >= 2  # outgoing + incoming

    def test_remove_entity(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Removing an entity also removes its attached relations."""
        initial_rels = populated_graph.relation_count
        assert populated_graph.remove_entity("p1")
        assert populated_graph.get_entity("p1") is None
        assert populated_graph.relation_count < initial_rels

    def test_remove_nonexistent_entity(self, graph: KnowledgeGraphBuilder) -> None:
        """Removing an unknown id reports failure instead of raising."""
        assert not graph.remove_entity("nope")

    def test_convenience_methods(self, graph: KnowledgeGraphBuilder) -> None:
        """add_paper/add_method/add_dataset tag the right entity types."""
        paper = graph.add_paper("p1", "Test Paper", year=2024)
        method = graph.add_method("m1", "TestNet", description="A test")
        dataset = graph.add_dataset("d1", "TestSet", domain="cv")
        assert paper.entity_type == EntityType.PAPER
        assert method.entity_type == EntityType.METHOD
        assert dataset.entity_type == EntityType.DATASET
# ── Persistence ──────────────────────────────────────────────────────
class TestGraphPersistence:
    """Save/load round-trip behaviour of the graph JSON format."""

    def test_save_and_load(self, populated_graph: KnowledgeGraphBuilder, tmp_path: Path) -> None:
        target = tmp_path / "graph.json"
        populated_graph.save(target)
        assert target.exists()
        restored = KnowledgeGraphBuilder()
        entity_total = restored.load(target)
        # load() reports how many entities were restored.
        assert entity_total == populated_graph.entity_count
        assert restored.relation_count == populated_graph.relation_count

    def test_load_nonexistent(self, graph: KnowledgeGraphBuilder, tmp_path: Path) -> None:
        # A missing file is a no-op that reports zero entities loaded.
        assert graph.load(tmp_path / "nope.json") == 0

    def test_load_malformed(self, graph: KnowledgeGraphBuilder, tmp_path: Path) -> None:
        broken = tmp_path / "bad.json"
        broken.write_text("not json", encoding="utf-8")
        assert graph.load(broken) == 0
# ── Query Engine ─────────────────────────────────────────────────────
class TestKnowledgeGraphQuery:
    """Higher-level queries: gaps, trends, comparisons, topic suggestions."""

    def test_find_research_gaps(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Datasets with no methods attached surface as gaps."""
        query = KnowledgeGraphQuery(populated_graph)
        gaps = query.find_research_gaps()
        # CIFAR-100 has no methods using it
        assert any("CIFAR-100" in g for g in gaps)

    def test_find_research_gaps_with_domain(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Domain filtering still returns a (possibly empty) list."""
        query = KnowledgeGraphQuery(populated_graph)
        gaps = query.find_research_gaps(domain="computer vision")
        assert isinstance(gaps, list)

    def test_find_trending_methods(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """At least one method clears the min-citation threshold of 1."""
        query = KnowledgeGraphQuery(populated_graph)
        trending = query.find_trending_methods(min_citations=1)
        assert len(trending) > 0

    def test_get_method_comparison(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Comparing two known methods yields both sides and shared datasets."""
        query = KnowledgeGraphQuery(populated_graph)
        comparison = query.get_method_comparison("ResNet", "Vision Transformer")
        assert "method_a" in comparison
        assert "method_b" in comparison
        assert "shared_datasets" in comparison

    def test_get_method_comparison_not_found(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Unknown method names produce an error payload, not an exception."""
        query = KnowledgeGraphQuery(populated_graph)
        comparison = query.get_method_comparison("NonexistentA", "NonexistentB")
        assert "error" in comparison

    def test_suggest_topics(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """Interest keywords produce a list of suggested topics."""
        query = KnowledgeGraphQuery(populated_graph)
        topics = query.suggest_topics(["transformer", "vision"], top_k=3)
        assert isinstance(topics, list)

    def test_suggest_topics_empty_interests(self, populated_graph: KnowledgeGraphBuilder) -> None:
        """An empty interest list is handled gracefully."""
        query = KnowledgeGraphQuery(populated_graph)
        topics = query.suggest_topics([])
        assert isinstance(topics, list)
# ── Visualizer ───────────────────────────────────────────────────────
class TestVisualizer:
    """Export helpers: Graphviz DOT, Cytoscape JSON, and the text summary."""

    def test_export_dot(self, populated_graph: KnowledgeGraphBuilder, tmp_path: Path) -> None:
        dot_path = tmp_path / "graph.dot"
        export_to_dot(populated_graph, dot_path)
        assert dot_path.exists()
        dot_text = dot_path.read_text(encoding="utf-8")
        assert "digraph" in dot_text
        assert "ResNet" in dot_text

    def test_export_cytoscape(self, populated_graph: KnowledgeGraphBuilder, tmp_path: Path) -> None:
        json_path = tmp_path / "graph.json"
        export_to_json_cytoscape(populated_graph, json_path)
        assert json_path.exists()
        payload = json.loads(json_path.read_text(encoding="utf-8"))
        assert "elements" in payload
        assert len(payload["elements"]) > 0

    def test_graph_summary(self, populated_graph: KnowledgeGraphBuilder) -> None:
        summary = graph_summary(populated_graph)
        for key in ("entities", "relations", "paper"):
            assert key in summary
================================================
FILE: tests/test_mcp.py
================================================
"""Tests for MCP integration (C3): Server, Client, Tools, Transport, Registry."""
from __future__ import annotations
import asyncio
import pytest
from researchclaw.mcp.tools import TOOL_DEFINITIONS, get_tool_schema, list_tool_names
from researchclaw.mcp.server import ResearchClawMCPServer
from researchclaw.mcp.client import MCPClient
from researchclaw.mcp.registry import MCPServerRegistry
from researchclaw.mcp.transport import SSETransport
# ══════════════════════════════════════════════════════════════════
# MCP Tools tests
# ══════════════════════════════════════════════════════════════════
class TestMCPTools:
    """Static validation of the MCP tool schema definitions."""

    def test_tool_definitions_not_empty(self) -> None:
        assert len(TOOL_DEFINITIONS) >= 6

    def test_all_tools_have_required_fields(self) -> None:
        for definition in TOOL_DEFINITIONS:
            for field in ("name", "description", "inputSchema"):
                assert field in definition
            assert definition["inputSchema"]["type"] == "object"

    def test_get_tool_schema_exists(self) -> None:
        pipeline_schema = get_tool_schema("run_pipeline")
        assert pipeline_schema is not None
        assert pipeline_schema["name"] == "run_pipeline"

    def test_get_tool_schema_missing(self) -> None:
        assert get_tool_schema("nonexistent") is None

    def test_list_tool_names(self) -> None:
        available = list_tool_names()
        for expected in ("run_pipeline", "get_pipeline_status", "search_literature"):
            assert expected in available

    def test_run_pipeline_requires_topic(self) -> None:
        pipeline_schema = get_tool_schema("run_pipeline")
        assert pipeline_schema is not None
        assert "topic" in pipeline_schema["inputSchema"]["required"]

    def test_get_paper_has_format_enum(self) -> None:
        paper_schema = get_tool_schema("get_paper")
        assert paper_schema is not None
        properties = paper_schema["inputSchema"]["properties"]
        assert "format" in properties
        assert "enum" in properties["format"]
# ══════════════════════════════════════════════════════════════════
# MCP Server tests
# ══════════════════════════════════════════════════════════════════
class TestMCPServer:
    """Behavioral tests for ResearchClawMCPServer dispatch and lifecycle."""

    @staticmethod
    def _dispatch(name: str, arguments: dict) -> dict:
        # Create a fresh server and drive one tool call to completion.
        return asyncio.run(ResearchClawMCPServer().handle_tool_call(name, arguments))

    def test_get_tools(self) -> None:
        advertised = ResearchClawMCPServer().get_tools()
        assert len(advertised) >= 6
        assert "run_pipeline" in [tool["name"] for tool in advertised]

    def test_handle_unknown_tool(self) -> None:
        outcome = self._dispatch("nonexistent", {})
        assert outcome["success"] is False
        assert "Unknown tool" in outcome["error"]

    def test_handle_run_pipeline(self) -> None:
        outcome = self._dispatch("run_pipeline", {"topic": "GNN"})
        assert outcome["success"] is True
        assert "GNN" in outcome["message"]

    def test_handle_get_status_missing_run(self) -> None:
        outcome = self._dispatch("get_pipeline_status", {"run_id": "nonexistent"})
        assert outcome["success"] is False

    def test_handle_search_literature(self) -> None:
        outcome = self._dispatch("search_literature", {"query": "transformers"})
        assert outcome["success"] is True

    def test_handle_review_paper(self) -> None:
        outcome = self._dispatch("review_paper", {"paper_path": "/tmp/paper.md"})
        assert outcome["success"] is True

    def test_start_stop(self) -> None:
        srv = ResearchClawMCPServer()
        assert not srv.is_running

        async def _lifecycle() -> None:
            await srv.start()
            assert srv.is_running
            await srv.stop()
            assert not srv.is_running

        asyncio.run(_lifecycle())

    def test_handle_get_results_missing(self) -> None:
        outcome = self._dispatch("get_experiment_results", {"run_id": "missing"})
        assert outcome["success"] is False

    def test_handle_get_paper_missing(self) -> None:
        outcome = self._dispatch("get_paper", {"run_id": "missing"})
        assert outcome["success"] is False
# ══════════════════════════════════════════════════════════════════
# MCP Client tests
# ══════════════════════════════════════════════════════════════════
class TestMCPClient:
    """Tests for MCPClient connection state and guarded operations."""

    @staticmethod
    def _unconnected() -> MCPClient:
        # A client on which connect() has never been awaited.
        return MCPClient("http://localhost:3000")

    def test_init(self) -> None:
        fresh = self._unconnected()
        assert fresh.uri == "http://localhost:3000"
        assert not fresh.is_connected

    def test_connect_disconnect(self) -> None:
        fresh = self._unconnected()

        async def _cycle() -> None:
            await fresh.connect()
            assert fresh.is_connected
            await fresh.disconnect()
            assert not fresh.is_connected

        asyncio.run(_cycle())

    def test_list_tools_not_connected(self) -> None:
        with pytest.raises(ConnectionError):
            asyncio.run(self._unconnected().list_tools())

    def test_call_tool_not_connected(self) -> None:
        with pytest.raises(ConnectionError):
            asyncio.run(self._unconnected().call_tool("test", {}))

    def test_list_resources_not_connected(self) -> None:
        with pytest.raises(ConnectionError):
            asyncio.run(self._unconnected().list_resources())

    def test_read_resource_not_connected(self) -> None:
        with pytest.raises(ConnectionError):
            asyncio.run(self._unconnected().read_resource("test://resource"))

    def test_list_tools_connected(self) -> None:
        fresh = self._unconnected()

        async def _fetch() -> list:
            await fresh.connect()
            return await fresh.list_tools()

        assert isinstance(asyncio.run(_fetch()), list)

    def test_tools_cached(self) -> None:
        fresh = self._unconnected()

        async def _fetch_twice() -> tuple:
            await fresh.connect()
            first = await fresh.list_tools()
            second = await fresh.list_tools()
            return first, second

        first, second = asyncio.run(_fetch_twice())
        # The second call must return the very same cached object.
        assert first is second
# ══════════════════════════════════════════════════════════════════
# MCP Server Registry tests
# ══════════════════════════════════════════════════════════════════
class TestMCPServerRegistry:
    """Tests for registering, querying, and tearing down MCP servers."""

    def test_register_and_list(self) -> None:
        async def _populate() -> list:
            registry = MCPServerRegistry()
            await registry.register("test", "http://localhost:3000")
            return registry.list_all()

        listed = asyncio.run(_populate())
        assert len(listed) == 1
        assert listed[0]["name"] == "test"
        assert listed[0]["connected"] is True

    def test_unregister(self) -> None:
        async def _cycle() -> int:
            registry = MCPServerRegistry()
            await registry.register("test", "http://localhost:3000")
            await registry.unregister("test")
            return registry.count

        assert asyncio.run(_cycle()) == 0

    def test_get(self) -> None:
        async def _lookup() -> MCPClient | None:
            registry = MCPServerRegistry()
            await registry.register("test", "http://localhost:3000")
            return registry.get("test")

        found = asyncio.run(_lookup())
        assert found is not None
        assert found.is_connected

    def test_get_missing(self) -> None:
        assert MCPServerRegistry().get("nonexistent") is None

    def test_close_all(self) -> None:
        async def _teardown() -> int:
            registry = MCPServerRegistry()
            await registry.register("a", "http://a:3000")
            await registry.register("b", "http://b:3000")
            await registry.close_all()
            return registry.count

        assert asyncio.run(_teardown()) == 0
# ══════════════════════════════════════════════════════════════════
# Transport tests
# ══════════════════════════════════════════════════════════════════
class TestSSETransport:
    """Tests for the SSE transport lifecycle and unimplemented receive()."""

    def test_start_stop(self) -> None:
        sse = SSETransport(port=9999)

        async def _cycle() -> None:
            await sse.start()
            assert sse._running is True
            await sse.close()
            assert sse._running is False

        asyncio.run(_cycle())

    def test_receive_not_implemented(self) -> None:
        sse = SSETransport()
        # receive() is a declared-but-unimplemented part of the transport API.
        with pytest.raises(NotImplementedError):
            asyncio.run(sse.receive())
================================================
FILE: tests/test_memory_system.py
================================================
"""Tests for the persistent memory system (40+ tests).
Covers:
- MemoryStore CRUD operations
- Vector embedding generation (mocked)
- Similarity retrieval
- Time decay computation
- Confidence updates
- Persistence (JSONL read/write)
- IdeationMemory, ExperimentMemory, WritingMemory
"""
from __future__ import annotations
import json
import math
from datetime import datetime, timezone, timedelta
from pathlib import Path
import pytest
from researchclaw.memory.store import MemoryEntry, MemoryStore, VALID_CATEGORIES
from researchclaw.memory.decay import time_decay_weight, confidence_update
from researchclaw.memory.embeddings import EmbeddingProvider, _tokenize, _hash_token
from researchclaw.memory.retriever import MemoryRetriever, cosine_similarity
from researchclaw.memory.ideation_memory import IdeationMemory
from researchclaw.memory.experiment_memory import ExperimentMemory
from researchclaw.memory.writing_memory import WritingMemory
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture
def tmp_store_dir(tmp_path: Path) -> Path:
    """A fresh, empty directory for a MemoryStore to live in."""
    store_dir = tmp_path / "memory_store"
    store_dir.mkdir()
    return store_dir


@pytest.fixture
def store(tmp_store_dir: Path) -> MemoryStore:
    """An empty MemoryStore backed by a temporary directory."""
    return MemoryStore(tmp_store_dir)


@pytest.fixture
def populated_store(store: MemoryStore) -> MemoryStore:
    """A store pre-seeded with 2 ideation, 2 experiment, and 1 writing entry."""
    seeds = [
        ("ideation", "Topic: RL for robotics\nOutcome: success", {"run_id": "r1"}),
        ("ideation", "Topic: Meta-learning\nOutcome: failure", {"run_id": "r2"}),
        ("experiment", "Task: classification\nHP: lr=0.001", {"run_id": "r1"}),
        ("experiment", "Trick: mixed precision\nImprovement: 5%", {"run_id": "r2"}),
        ("writing", "Feedback: clarity\nResolution: rewrite", {"run_id": "r1"}),
    ]
    for category, content, metadata in seeds:
        store.add(category, content, metadata)
    return store


@pytest.fixture
def embedding_fn() -> object:
    """Deterministic toy embedding: first 16 chars -> unit-normalized vector."""
    def _embed(text: str) -> list[float]:
        raw = [ord(ch) / 256.0 for ch in text[:16]]
        raw += [0.0] * (16 - len(raw))
        length = math.sqrt(sum(component * component for component in raw)) or 1.0
        return [component / length for component in raw]
    return _embed
# ── MemoryStore CRUD ─────────────────────────────────────────────────
class TestMemoryStoreCRUD:
    """Add/get/update/count behavior of MemoryStore."""

    def test_add_entry(self, store: MemoryStore) -> None:
        new_id = store.add("ideation", "test content", {"key": "value"})
        assert new_id
        assert store.count("ideation") == 1

    def test_add_invalid_category(self, store: MemoryStore) -> None:
        with pytest.raises(ValueError, match="Invalid category"):
            store.add("invalid_cat", "content")

    def test_add_all_categories(self, store: MemoryStore) -> None:
        # One entry per valid category; VALID_CATEGORIES has exactly three.
        for category in VALID_CATEGORIES:
            store.add(category, f"content for {category}")
        assert store.count() == 3

    def test_get_entry(self, store: MemoryStore) -> None:
        new_id = store.add("ideation", "findme")
        fetched = store.get(new_id)
        assert fetched is not None
        assert fetched.content == "findme"
        assert fetched.category == "ideation"

    def test_get_nonexistent(self, store: MemoryStore) -> None:
        assert store.get("nonexistent_id") is None

    def test_get_all_no_filter(self, populated_store: MemoryStore) -> None:
        assert len(populated_store.get_all()) == 5

    def test_get_all_with_filter(self, populated_store: MemoryStore) -> None:
        assert len(populated_store.get_all("ideation")) == 2

    def test_update_confidence_success(self, store: MemoryStore) -> None:
        new_id = store.add("ideation", "conf test", confidence=0.5)
        assert store.update_confidence(new_id, 0.1)
        fetched = store.get(new_id)
        assert fetched is not None
        assert abs(fetched.confidence - 0.6) < 1e-6

    def test_update_confidence_clamp_high(self, store: MemoryStore) -> None:
        new_id = store.add("ideation", "test", confidence=0.95)
        store.update_confidence(new_id, 0.2)
        fetched = store.get(new_id)
        assert fetched is not None
        assert fetched.confidence == 1.0

    def test_update_confidence_clamp_low(self, store: MemoryStore) -> None:
        new_id = store.add("ideation", "test", confidence=0.1)
        store.update_confidence(new_id, -0.5)
        fetched = store.get(new_id)
        assert fetched is not None
        assert fetched.confidence == 0.0

    def test_update_confidence_nonexistent(self, store: MemoryStore) -> None:
        assert not store.update_confidence("nope", 0.1)

    def test_mark_accessed(self, store: MemoryStore) -> None:
        new_id = store.add("ideation", "access test")
        before = store.get(new_id)
        assert before is not None
        assert before.access_count == 0
        store.mark_accessed(new_id)
        after = store.get(new_id)
        assert after is not None
        assert after.access_count == 1

    def test_capacity_enforcement(self, tmp_store_dir: Path) -> None:
        capped = MemoryStore(tmp_store_dir, max_entries_per_category=3)
        for i in range(5):
            capped.add("ideation", f"entry {i}", confidence=i * 0.2)
        assert capped.count("ideation") == 3
        # The lowest-confidence entries (0.0 and 0.2) must have been evicted.
        assert min(e.confidence for e in capped.get_all("ideation")) >= 0.4

    def test_count_empty(self, store: MemoryStore) -> None:
        assert store.count() == 0
        assert store.count("ideation") == 0
# ── Persistence ──────────────────────────────────────────────────────
class TestMemoryPersistence:
    """JSONL save/load round-trips for MemoryStore."""

    def test_save_and_load(self, tmp_store_dir: Path) -> None:
        original = MemoryStore(tmp_store_dir)
        original.add("ideation", "persistent content", {"key": "val"})
        original.add("experiment", "exp content")
        original.save()
        reopened = MemoryStore(tmp_store_dir)
        assert reopened.load() == 2
        assert reopened.count() == 2

    def test_save_creates_directory(self, tmp_path: Path) -> None:
        # save() must create missing parent directories on demand.
        nested = tmp_path / "new" / "nested" / "dir"
        fresh = MemoryStore(nested)
        fresh.add("ideation", "test")
        fresh.save()
        assert (nested / "ideation.jsonl").exists()

    def test_load_empty_dir(self, tmp_store_dir: Path) -> None:
        assert MemoryStore(tmp_store_dir).load() == 0

    def test_load_malformed_jsonl(self, tmp_store_dir: Path) -> None:
        (tmp_store_dir / "ideation.jsonl").write_text(
            '{"id": "a", "category": "ideation"}\nnot json\n',
            encoding="utf-8",
        )
        # The bad line is skipped; only the valid record loads.
        assert MemoryStore(tmp_store_dir).load() == 1

    def test_roundtrip_preserves_data(self, tmp_store_dir: Path) -> None:
        original = MemoryStore(tmp_store_dir)
        saved_id = original.add(
            "experiment", "test content",
            metadata={"key": "value"},
            embedding=[0.1, 0.2, 0.3],
            confidence=0.7,
        )
        original.save()
        reopened = MemoryStore(tmp_store_dir)
        reopened.load()
        restored = reopened.get(saved_id)
        assert restored is not None
        assert restored.content == "test content"
        assert restored.metadata == {"key": "value"}
        assert restored.embedding == [0.1, 0.2, 0.3]
        assert abs(restored.confidence - 0.7) < 1e-6
# ── Prune ────────────────────────────────────────────────────────────
class TestMemoryPrune:
    """Confidence-threshold pruning of stored entries."""

    def test_prune_low_confidence(self, store: MemoryStore) -> None:
        store.add("ideation", "low conf", confidence=0.1)
        store.add("ideation", "high conf", confidence=0.8)
        # Only the 0.1-confidence entry falls below the 0.5 threshold.
        assert store.prune(confidence_threshold=0.5) == 1
        assert store.count("ideation") == 1

    def test_prune_nothing_to_remove(self, store: MemoryStore) -> None:
        store.add("ideation", "good", confidence=0.9)
        assert store.prune() == 0
# ── MemoryEntry ──────────────────────────────────────────────────────
class TestMemoryEntry:
    """Serialization of MemoryEntry to and from plain dicts."""

    def test_to_dict(self) -> None:
        stamp = "2024-01-01T00:00:00+00:00"
        entry = MemoryEntry(
            id="abc", category="ideation", content="test",
            metadata={}, embedding=[], confidence=0.5,
            created_at=stamp,
            last_accessed=stamp,
            access_count=0,
        )
        serialized = entry.to_dict()
        assert serialized["id"] == "abc"
        assert serialized["category"] == "ideation"

    def test_from_dict(self) -> None:
        payload = {
            "id": "xyz", "category": "experiment", "content": "hp test",
            "metadata": {"run": "1"}, "embedding": [0.1], "confidence": 0.6,
            "created_at": "2024-06-01T00:00:00+00:00",
            "last_accessed": "2024-06-01T00:00:00+00:00",
            "access_count": 3,
        }
        restored = MemoryEntry.from_dict(payload)
        assert restored.id == "xyz"
        assert restored.access_count == 3

    def test_from_dict_defaults(self) -> None:
        # An empty payload falls back to the documented defaults.
        blank = MemoryEntry.from_dict({})
        assert blank.id == ""
        assert blank.confidence == 0.5
        assert blank.access_count == 0
# ── Time Decay ───────────────────────────────────────────────────────
class TestTimeDecay:
    """Exponential time-decay weighting of memory entries."""

    def test_fresh_entry(self) -> None:
        moment = datetime.now(timezone.utc)
        weight = time_decay_weight(moment, half_life_days=90.0, now=moment)
        assert abs(weight - 1.0) < 1e-6

    def test_half_life(self) -> None:
        moment = datetime.now(timezone.utc)
        aged = moment - timedelta(days=90)
        # One half-life back in time should weigh roughly 0.5.
        weight = time_decay_weight(aged, half_life_days=90.0, now=moment)
        assert abs(weight - 0.5) < 0.01

    def test_expired(self) -> None:
        moment = datetime.now(timezone.utc)
        ancient = moment - timedelta(days=400)
        weight = time_decay_weight(
            ancient, half_life_days=90.0, max_age_days=365.0, now=moment
        )
        assert weight == 0.0

    def test_future_timestamp(self) -> None:
        moment = datetime.now(timezone.utc)
        upcoming = moment + timedelta(days=10)
        # Timestamps in the future are treated as brand new.
        assert time_decay_weight(upcoming, now=moment) == 1.0

    def test_naive_datetime(self) -> None:
        moment = datetime.now(timezone.utc)
        tz_stripped = moment.replace(tzinfo=None)
        assert time_decay_weight(tz_stripped, now=moment) > 0.0
class TestConfidenceUpdate:
    """Clamped additive confidence updates."""

    def test_increase(self) -> None:
        # 0.5 + 0.1 stays inside [0, 1] unchanged.
        assert confidence_update(0.5, 0.1) == 0.6

    def test_decrease(self) -> None:
        # Float subtraction needs approx comparison.
        assert confidence_update(0.5, -0.2) == pytest.approx(0.3)

    def test_clamp_ceiling(self) -> None:
        # 0.95 + 0.2 is clamped down to 1.0.
        assert confidence_update(0.95, 0.2) == 1.0

    def test_clamp_floor(self) -> None:
        # 0.1 - 0.5 is clamped up to 0.0.
        assert confidence_update(0.1, -0.5) == 0.0
# ── Embeddings ───────────────────────────────────────────────────────
class TestEmbeddings:
    """EmbeddingProvider backends and the hashing TF-IDF fallback."""

    def test_tfidf_fallback(self) -> None:
        vector = EmbeddingProvider().embed("hello world test")
        assert len(vector) > 0
        assert isinstance(vector[0], float)

    def test_tfidf_normalized(self) -> None:
        vector = EmbeddingProvider().embed("deep learning neural network")
        magnitude = math.sqrt(sum(component * component for component in vector))
        assert abs(magnitude - 1.0) < 0.01

    def test_tfidf_empty(self) -> None:
        provider = EmbeddingProvider()
        # Force the TF-IDF backend to exercise the zero-vector path.
        provider._backend = "tfidf"
        provider._dim = 256
        assert all(component == 0.0 for component in provider.embed(""))

    def test_tokenize(self) -> None:
        tokens = _tokenize("Hello, World! 123")
        for expected in ("hello", "world", "123"):
            assert expected in tokens

    def test_hash_token_deterministic(self) -> None:
        # The same token and dimension must always hash identically.
        assert _hash_token("test", 256) == _hash_token("test", 256)

    def test_embed_batch(self) -> None:
        assert len(EmbeddingProvider().embed_batch(["hello", "world"])) == 2

    def test_backend_detection(self) -> None:
        assert EmbeddingProvider().backend in ("api", "sentence_transformers", "tfidf")
# ── Retriever ────────────────────────────────────────────────────────
class TestRetriever:
    """Cosine similarity and MemoryRetriever recall/formatting."""

    def test_cosine_similarity_identical(self) -> None:
        unit = [1.0, 0.0, 0.0]
        assert abs(cosine_similarity(unit, unit) - 1.0) < 1e-6

    def test_cosine_similarity_orthogonal(self) -> None:
        assert abs(cosine_similarity([1.0, 0.0], [0.0, 1.0])) < 1e-6

    def test_cosine_similarity_opposite(self) -> None:
        assert abs(cosine_similarity([1.0, 0.0], [-1.0, 0.0]) + 1.0) < 1e-6

    def test_cosine_similarity_empty(self) -> None:
        assert cosine_similarity([], []) == 0.0

    def test_cosine_similarity_mismatched_length(self) -> None:
        # Length mismatch is treated as zero similarity, not an error.
        assert cosine_similarity([1.0], [1.0, 2.0]) == 0.0

    def test_recall_empty_store(self, store: MemoryStore) -> None:
        assert MemoryRetriever(store).recall([0.1, 0.2], category="ideation") == []

    def test_recall_returns_results(self, store: MemoryStore) -> None:
        store.add("ideation", "RL research", embedding=[1.0, 0.0, 0.0])
        store.add("ideation", "NLP research", embedding=[0.0, 1.0, 0.0])
        hits = MemoryRetriever(store).recall(
            [0.9, 0.1, 0.0], category="ideation", top_k=1
        )
        assert len(hits) == 1
        # The query vector is closest to the RL entry's embedding.
        assert "RL" in hits[0][0].content

    def test_recall_respects_top_k(self, store: MemoryStore) -> None:
        for i in range(10):
            store.add("ideation", f"entry {i}", embedding=[float(i)] * 3)
        assert len(MemoryRetriever(store).recall([5.0, 5.0, 5.0], top_k=3)) == 3

    def test_format_for_prompt(self, store: MemoryStore) -> None:
        store.add("ideation", "Topic: RL", embedding=[1.0])
        retriever = MemoryRetriever(store)
        rendered = retriever.format_for_prompt(retriever.recall([1.0]))
        assert "ideation" in rendered

    def test_format_for_prompt_empty(self, store: MemoryStore) -> None:
        assert MemoryRetriever(store).format_for_prompt([]) == ""
# ── Ideation Memory ──────────────────────────────────────────────────
class TestIdeationMemory:
    """High-level ideation memory recording and recall."""

    def test_record_topic_success(self, store: MemoryStore, embedding_fn: object) -> None:
        memory = IdeationMemory(store, MemoryRetriever(store), embed_fn=embedding_fn)
        assert memory.record_topic_outcome("RL for robotics", "success", 8.0)
        assert store.count("ideation") == 1

    def test_record_topic_failure(self, store: MemoryStore) -> None:
        memory = IdeationMemory(store, MemoryRetriever(store))
        memory.record_topic_outcome("Bad topic", "failure", 2.0, run_id="r1")
        recorded = store.get_all("ideation")
        assert recorded[0].metadata["outcome"] == "failure"

    def test_record_hypothesis(self, store: MemoryStore) -> None:
        memory = IdeationMemory(store, MemoryRetriever(store))
        memory.record_hypothesis("H1: X is better than Y", True, "Validated")
        assert store.count("ideation") == 1

    def test_get_anti_patterns(self, store: MemoryStore) -> None:
        memory = IdeationMemory(store, MemoryRetriever(store))
        memory.record_topic_outcome("Bad direction", "failure", 1.0)
        memory.record_topic_outcome("Good direction", "success", 9.0)
        anti = memory.get_anti_patterns()
        # Only the failed topic should surface as an anti-pattern.
        assert len(anti) == 1
        assert "Bad" in anti[0]

    def test_recall_similar_topics_empty(self, store: MemoryStore) -> None:
        memory = IdeationMemory(store, MemoryRetriever(store))
        assert memory.recall_similar_topics("test query") == ""
# ── Experiment Memory ────────────────────────────────────────────────
class TestExperimentMemory:
    """Experiment memory: hyperparameters, architectures, and tricks."""

    def test_record_hyperparams(self, store: MemoryStore) -> None:
        memory = ExperimentMemory(store, MemoryRetriever(store))
        memory.record_hyperparams("image_cls", {"lr": 0.001, "bs": 32}, 0.95)
        assert store.count("experiment") == 1

    def test_record_architecture(self, store: MemoryStore) -> None:
        memory = ExperimentMemory(store, MemoryRetriever(store))
        memory.record_architecture("image_cls", "ResNet-18", 0.96)
        assert "ResNet" in store.get_all("experiment")[0].content

    def test_record_training_trick(self, store: MemoryStore) -> None:
        memory = ExperimentMemory(store, MemoryRetriever(store))
        memory.record_training_trick("CosineAnnealing", 0.03, "CIFAR-10 training")
        assert "CosineAnnealing" in store.get_all("experiment")[0].content

    def test_recall_best_configs_empty(self, store: MemoryStore) -> None:
        memory = ExperimentMemory(store, MemoryRetriever(store))
        assert memory.recall_best_configs("anything") == ""
# ── Writing Memory ───────────────────────────────────────────────────
class TestWritingMemory:
    """Writing memory: reviewer feedback and successful structures."""

    def test_record_review_feedback(self, store: MemoryStore) -> None:
        memory = WritingMemory(store, MemoryRetriever(store))
        memory.record_review_feedback("clarity", "Section 3 is unclear", "Rewrote S3")
        assert store.count("writing") == 1

    def test_record_successful_structure(self, store: MemoryStore) -> None:
        memory = WritingMemory(store, MemoryRetriever(store))
        memory.record_successful_structure("intro", "Problem-Gap-Contribution", 8.5)
        assert store.get_all("writing")[0].metadata["section"] == "intro"

    def test_recall_writing_tips_empty(self, store: MemoryStore) -> None:
        memory = WritingMemory(store, MemoryRetriever(store))
        assert memory.recall_writing_tips("method", "RL paper") == ""
================================================
FILE: tests/test_metaclaw_bridge/__init__.py
================================================
================================================
FILE: tests/test_metaclaw_bridge/test_config.py
================================================
"""Tests for MetaClaw bridge configuration parsing."""
from researchclaw.config import RCConfig
def _minimal_config_data(**overrides):
"""Return minimal valid config data with metaclaw_bridge overrides."""
base = {
"project": {"name": "test", "mode": "full-auto"},
"research": {"topic": "test topic", "domains": ["ml"]},
"runtime": {"timezone": "UTC"},
"notifications": {"channel": "console"},
"knowledge_base": {"backend": "markdown", "root": "docs/kb"},
"llm": {
"provider": "openai-compatible",
"base_url": "http://localhost:8080",
"api_key_env": "TEST_KEY",
"api_key": "sk-test",
"primary_model": "gpt-4o",
},
}
base.update(overrides)
return base
def test_metaclaw_bridge_defaults():
    """MetaClaw bridge should have sensible defaults when not configured."""
    cfg = RCConfig.from_dict(_minimal_config_data(), check_paths=False)
    assert cfg.metaclaw_bridge.enabled is False
    assert cfg.metaclaw_bridge.proxy_url == "http://localhost:30000"
    assert cfg.metaclaw_bridge.prm.enabled is False
    assert cfg.metaclaw_bridge.lesson_to_skill.enabled is True


def test_metaclaw_bridge_enabled():
    """MetaClaw bridge config should be parsed when provided."""
    bridge_section = {
        "enabled": True,
        "proxy_url": "http://localhost:31000",
        "skills_dir": "/tmp/skills",
        "prm": {
            "enabled": True,
            "api_base": "http://localhost:8080",
            "api_key": "test-key",
            "model": "gpt-5.4",
            "votes": 5,
            "gate_stages": [5, 20],
        },
        "lesson_to_skill": {
            "enabled": True,
            "min_severity": "warning",
            "max_skills_per_run": 5,
        },
    }
    cfg = RCConfig.from_dict(
        _minimal_config_data(metaclaw_bridge=bridge_section), check_paths=False
    )
    bridge = cfg.metaclaw_bridge
    assert bridge.enabled is True
    assert bridge.proxy_url == "http://localhost:31000"
    assert bridge.prm.enabled is True
    assert bridge.prm.votes == 5
    # The YAML list is normalized into a tuple on parse.
    assert bridge.prm.gate_stages == (5, 20)
    assert bridge.lesson_to_skill.min_severity == "warning"
    assert bridge.lesson_to_skill.max_skills_per_run == 5


def test_metaclaw_bridge_none_is_default():
    """When metaclaw_bridge is None/missing, defaults should apply."""
    cfg = RCConfig.from_dict(
        _minimal_config_data(metaclaw_bridge=None), check_paths=False
    )
    assert cfg.metaclaw_bridge.enabled is False
================================================
FILE: tests/test_metaclaw_bridge/test_lesson_to_skill.py
================================================
"""Tests for lesson-to-skill conversion module."""
import json
import tempfile
from pathlib import Path
from researchclaw.metaclaw_bridge.lesson_to_skill import (
_format_lessons,
_list_existing_skill_names,
_parse_skills_response,
_write_skill,
)
from researchclaw.evolution import LessonEntry
def _make_lesson(stage: str = "experiment_run", severity: str = "error") -> LessonEntry:
    """Build a canned LessonEntry for tests, varying only stage and severity.

    All other fields are fixed, deterministic values so assertions can match
    on the description ("NaN") and stage names without extra setup.
    """
    return LessonEntry(
        stage_name=stage,
        stage_num=12,
        category="experiment",
        severity=severity,
        description="Metric NaN detected in loss computation",
        timestamp="2026-03-15T00:00:00+00:00",
        run_id="test-001",
    )
def test_format_lessons():
    """Formatted text mentions every lesson's stage and description."""
    rendered = _format_lessons([_make_lesson(), _make_lesson("code_generation")])
    assert "experiment_run" in rendered
    assert "code_generation" in rendered
    assert "NaN" in rendered


def test_list_existing_skills(tmp_path):
    """Only directories count as skills; loose files are ignored."""
    for skill_dir in ("skill-a", "skill-b"):
        (tmp_path / skill_dir).mkdir()
    (tmp_path / "not-a-skill.txt").write_text("x")
    names = _list_existing_skill_names(tmp_path)
    assert "skill-a" in names
    assert "skill-b" in names
    assert "not-a-skill.txt" not in names


def test_list_existing_skills_missing_dir():
    """A nonexistent skills directory yields an empty list."""
    assert _list_existing_skill_names(Path("/nonexistent/dir")) == []
def test_parse_skills_response_valid():
    """A plain JSON array of skill dicts parses intact."""
    payload = [
        {
            "name": "arc-fix-nan",
            "description": "Prevent NaN in loss",
            "category": "coding",
            "content": "# Fix NaN\n1. Check inputs\n2. Use grad clipping",
        }
    ]
    parsed = _parse_skills_response(json.dumps(payload))
    assert len(parsed) == 1
    assert parsed[0]["name"] == "arc-fix-nan"


def test_parse_skills_response_with_code_fence():
    """A ```json fenced response is unwrapped before parsing."""
    payload = [
        {
            "name": "arc-test",
            "description": "test",
            "category": "coding",
            "content": "test content",
        }
    ]
    fenced = "```json\n" + json.dumps(payload) + "\n```"
    assert len(_parse_skills_response(fenced)) == 1


def test_parse_skills_response_invalid():
    """Garbage and empty arrays both produce no skills."""
    assert _parse_skills_response("not json") == []
    assert _parse_skills_response("[]") == []
def test_write_skill(tmp_path):
    """Writing a skill creates a file with front-matter and the markdown body."""
    written = _write_skill(tmp_path, {
        "name": "arc-test-skill",
        "description": "A test skill",
        "category": "coding",
        "content": "# Test\n1. Do something",
    })
    assert written is not None
    assert written.exists()
    body = written.read_text()
    # Front-matter fields plus the original markdown content must survive.
    assert "name: arc-test-skill" in body
    assert "category: coding" in body
    assert "# Test" in body
================================================
FILE: tests/test_metaclaw_bridge/test_prm_gate.py
================================================
"""Tests for PRM quality gate module."""
from unittest.mock import patch, MagicMock
from researchclaw.metaclaw_bridge.prm_gate import (
ResearchPRMGate,
_GATE_INSTRUCTIONS,
)
def test_gate_instructions_cover_expected_stages():
    """PRM gate instructions should cover key gate stages."""
    for stage_num in (5, 9, 15, 20):
        assert stage_num in _GATE_INSTRUCTIONS


def test_should_gate():
    """Gating applies exactly at the instructed stages."""
    gate = ResearchPRMGate(
        api_base="http://test",
        api_key="test",
    )
    for gated_stage in (5, 9, 15, 20):
        assert gate.should_gate(gated_stage) is True
    for ungated_stage in (1, 10):
        assert gate.should_gate(ungated_stage) is False
def test_from_bridge_config_disabled():
    """Should return None when PRM is not enabled."""
    disabled = MagicMock()
    disabled.enabled = False
    assert ResearchPRMGate.from_bridge_config(disabled) is None


def test_from_bridge_config_enabled():
    """Should create a gate when properly configured."""
    enabled = MagicMock()
    enabled.enabled = True
    enabled.api_base = "http://test"
    enabled.api_key = "test-key"
    enabled.api_key_env = ""
    enabled.model = "gpt-5.4"
    enabled.votes = 3
    enabled.temperature = 0.6
    gate = ResearchPRMGate.from_bridge_config(enabled)
    assert gate is not None
    assert gate.api_base == "http://test"
    assert gate.votes == 3
@patch("researchclaw.metaclaw_bridge.prm_gate._single_judge_call")
def test_evaluate_stage_majority_pass(mock_call):
    """Should return 1.0 when majority votes pass."""
    mock_call.side_effect = [1.0, 1.0, -1.0]
    gate = ResearchPRMGate(api_base="http://test", api_key="test", votes=3)
    assert gate.evaluate_stage(20, "This is a good paper.") == 1.0


@patch("researchclaw.metaclaw_bridge.prm_gate._single_judge_call")
def test_evaluate_stage_majority_fail(mock_call):
    """Should return -1.0 when majority votes fail."""
    mock_call.side_effect = [-1.0, -1.0, 1.0]
    gate = ResearchPRMGate(api_base="http://test", api_key="test", votes=3)
    assert gate.evaluate_stage(20, "This paper has critical issues.") == -1.0


@patch("researchclaw.metaclaw_bridge.prm_gate._single_judge_call")
def test_evaluate_stage_all_failed(mock_call):
    """Should return 0.0 when all judge calls fail."""
    # None simulates a judge call that errored out entirely.
    mock_call.side_effect = [None, None, None]
    gate = ResearchPRMGate(api_base="http://test", api_key="test", votes=3)
    assert gate.evaluate_stage(20, "test") == 0.0
================================================
FILE: tests/test_metaclaw_bridge/test_session.py
================================================
"""Tests for MetaClaw session management module."""
from researchclaw.metaclaw_bridge.session import MetaClawSession
def test_session_creation():
    """A new session gets an arc- prefixed id and starts active."""
    fresh = MetaClawSession("test-run-001")
    assert fresh.session_id == "arc-test-run-001"
    assert fresh.is_active is True


def test_session_headers():
    """Stage-aware headers carry session id, turn type, and stage name."""
    headers = MetaClawSession("run-123").get_headers("hypothesis_gen")
    assert headers["X-Session-Id"] == "arc-run-123"
    assert headers["X-Turn-Type"] == "main"
    assert headers["X-AutoRC-Stage"] == "hypothesis_gen"


def test_session_headers_no_stage():
    """Without a stage argument, no stage header is emitted."""
    assert "X-AutoRC-Stage" not in MetaClawSession("run-123").get_headers()


def test_session_end():
    """Ending a session emits done headers and deactivates it."""
    closing = MetaClawSession("run-456")
    final_headers = closing.end()
    assert final_headers["X-Session-Done"] == "true"
    assert final_headers["X-Session-Id"] == "arc-run-456"
    assert closing.is_active is False
================================================
FILE: tests/test_metaclaw_bridge/test_skill_feedback.py
================================================
"""Tests for skill feedback tracking module."""
from pathlib import Path
from researchclaw.metaclaw_bridge.skill_feedback import (
SkillEffectivenessRecord,
SkillFeedbackStore,
record_stage_skills,
)
def test_append_and_load(tmp_path):
store = SkillFeedbackStore(tmp_path / "feedback.jsonl")
rec = SkillEffectivenessRecord(
skill_name="hypothesis-formulation",
stage_name="hypothesis_gen",
run_id="test-001",
stage_success=True,
timestamp="2026-03-15T00:00:00+00:00",
)
store.append(rec)
loaded = store.load_all()
assert len(loaded) == 1
assert loaded[0].skill_name == "hypothesis-formulation"
assert loaded[0].stage_success is True
def test_append_many(tmp_path):
store = SkillFeedbackStore(tmp_path / "feedback.jsonl")
records = [
SkillEffectivenessRecord("skill-a", "stage-1", "run-1", True, "2026-01-01"),
SkillEffectivenessRecord("skill-b", "stage-2", "run-1", False, "2026-01-01"),
]
store.append_many(records)
assert len(store.load_all()) == 2
def test_compute_stats(tmp_path):
store = SkillFeedbackStore(tmp_path / "feedback.jsonl")
records = [
SkillEffectivenessRecord("skill-a", "s1", "r1", True, "t1"),
SkillEffectivenessRecord("skill-a", "s2", "r1", False, "t1"),
SkillEffectivenessRecord("skill-a", "s3", "r2", True, "t2"),
SkillEffectivenessRecord("skill-b", "s1", "r1", False, "t1"),
]
store.append_many(records)
stats = store.compute_skill_stats()
assert stats["skill-a"]["total"] == 3
assert stats["skill-a"]["successes"] == 2
assert abs(stats["skill-a"]["success_rate"] - 2 / 3) < 0.01
assert stats["skill-b"]["total"] == 1
assert stats["skill-b"]["success_rate"] == 0.0
def test_record_stage_skills(tmp_path):
    """record_stage_skills writes one record per active skill for the stage."""
    feedback_store = SkillFeedbackStore(tmp_path / "feedback.jsonl")
    active = ["hypothesis-formulation", "research-gap-identification"]
    record_stage_skills(
        feedback_store,
        stage_name="hypothesis_gen",
        run_id="test-002",
        stage_success=True,
        active_skills=active,
    )
    rows = feedback_store.load_all()
    assert len(rows) == 2
    assert {r.skill_name for r in rows} == set(active)
def test_empty_store(tmp_path):
    """A store backed by a missing file yields empty results rather than erroring."""
    feedback_store = SkillFeedbackStore(tmp_path / "nonexistent.jsonl")
    assert feedback_store.load_all() == []
    assert feedback_store.compute_skill_stats() == {}
================================================
FILE: tests/test_metaclaw_bridge/test_stage_skill_map.py
================================================
"""Tests for stage-skill mapping module."""
from researchclaw.metaclaw_bridge.stage_skill_map import (
STAGE_SKILL_MAP,
LESSON_CATEGORY_TO_SKILL_CATEGORY,
get_stage_config,
)
def test_all_23_stages_mapped():
    """All 23 pipeline stages should have a mapping entry."""
    expected_stages = (
        "topic_init", "problem_decompose", "search_strategy",
        "literature_collect", "literature_screen", "knowledge_extract",
        "synthesis", "hypothesis_gen", "experiment_design",
        "code_generation", "resource_planning", "experiment_run",
        "iterative_refine", "result_analysis", "research_decision",
        "paper_outline", "paper_draft", "peer_review",
        "paper_revision", "quality_gate", "knowledge_archive",
        "export_publish", "citation_verify",
    )
    for stage_name in expected_stages:
        assert stage_name in STAGE_SKILL_MAP, f"Missing mapping for {stage_name}"
def test_stage_config_has_required_keys():
    """Each stage config should have task_type, skills, and top_k."""
    for stage_name, config in STAGE_SKILL_MAP.items():
        # Required keys, checked uniformly; the failure message names the stage.
        assert "task_type" in config, f"{stage_name} missing task_type"
        assert "skills" in config, f"{stage_name} missing skills"
        assert "top_k" in config, f"{stage_name} missing top_k"
        # Type and range invariants for the retrieval parameters.
        assert isinstance(config["skills"], list)
        assert isinstance(config["top_k"], int)
        assert config["top_k"] > 0
def test_get_stage_config_known():
    """A known stage returns its configured task type and skill list."""
    hypothesis_cfg = get_stage_config("hypothesis_gen")
    assert hypothesis_cfg["task_type"] == "research"
    assert "hypothesis-formulation" in hypothesis_cfg["skills"]
def test_get_stage_config_unknown_returns_default():
    """Unknown stage names fall back to the default research config."""
    fallback_cfg = get_stage_config("nonexistent_stage")
    assert fallback_cfg["task_type"] == "research"
    assert fallback_cfg["top_k"] == 4
def test_lesson_category_mapping_complete():
    """All lesson categories should map to a skill category."""
    expected = {"system", "experiment", "writing", "analysis", "literature", "pipeline"}
    assert expected <= set(LESSON_CATEGORY_TO_SKILL_CATEGORY)
================================================
FILE: tests/test_metric_parser.py
================================================
"""Tests for the universal metric parser."""
from __future__ import annotations
import json
import math
import pytest
from pathlib import Path
from researchclaw.experiment.metrics import (
ExperimentResults,
MetricType,
UniversalMetricParser,
)
@pytest.fixture
def parser():
    """Fresh UniversalMetricParser instance for each test."""
    return UniversalMetricParser()
@pytest.fixture
def tmp_run_dir(tmp_path):
    """Alias for pytest's tmp_path, used as the fake experiment run directory."""
    return tmp_path
# ---------------------------------------------------------------------------
# JSON parsing tests
# ---------------------------------------------------------------------------
class TestJSONParsing:
    """Parsing of results.json files found in the run directory."""

    def test_parse_comparison_results(self, parser, tmp_run_dir):
        """A comparison-type payload populates conditions, metadata, and flat metrics."""
        data = {
            "experiment_type": "comparison",
            "conditions": {
                "proposed_method": {
                    "seed_42": {"accuracy": 0.95, "f1": 0.93},
                    "seed_123": {"accuracy": 0.94, "f1": 0.92},
                },
                "baseline": {
                    "seed_42": {"accuracy": 0.88, "f1": 0.85},
                },
            },
            "metadata": {
                "domain": "ml_vision",
                "total_runtime_sec": 120.5,
            },
        }
        (tmp_run_dir / "results.json").write_text(json.dumps(data))
        result = parser.parse(tmp_run_dir)
        assert result.source == "json"
        assert result.experiment_type == "comparison"
        assert result.domain == "ml_vision"
        assert "proposed_method" in result.conditions
        flat = result.to_flat_metrics()
        # Flattened keys are namespaced as "<condition>/<metric>".
        assert "proposed_method/accuracy" in flat

    def test_parse_convergence_results(self, parser, tmp_run_dir):
        """Convergence payloads are grouped per method with their full point list."""
        data = {
            "experiment_type": "convergence",
            "convergence": {
                "euler": [
                    {"h": 0.1, "error": 0.05},
                    {"h": 0.05, "error": 0.012},
                    {"h": 0.025, "error": 0.003},
                ],
                "rk4": [
                    {"h": 0.1, "error": 0.001},
                    {"h": 0.05, "error": 6.25e-5},
                    {"h": 0.025, "error": 3.9e-6},
                ],
            },
        }
        (tmp_run_dir / "results.json").write_text(json.dumps(data))
        result = parser.parse(tmp_run_dir)
        assert result.source == "json"
        assert "euler" in result.convergence
        assert len(result.convergence["euler"]) == 3
        flat = result.to_flat_metrics()
        assert "euler/error" in flat  # last point

    def test_parse_regression_table(self, parser, tmp_run_dir):
        """Regression-table specs flatten to "<spec>/<column>" scalar keys."""
        data = {
            "experiment_type": "progressive_spec",
            "regression_table": {
                "spec_1_ols": {"coeff": 0.15, "se": 0.03, "p": 0.001, "n": 5000, "r2": 0.12},
                "spec_2_fe": {"coeff": 0.11, "se": 0.02, "p": 0.001, "n": 5000, "r2": 0.35},
            },
        }
        (tmp_run_dir / "results.json").write_text(json.dumps(data))
        result = parser.parse(tmp_run_dir)
        assert result.source == "json"
        assert "spec_1_ols" in result.regression_table
        flat = result.to_flat_metrics()
        assert "spec_1_ols/coeff" in flat
        assert flat["spec_1_ols/coeff"] == 0.15

    def test_parse_top_level_scalars(self, parser, tmp_run_dir):
        """Bare numeric fields at the top level are captured as scalars."""
        data = {"accuracy": 0.95, "loss": 0.32}
        (tmp_run_dir / "results.json").write_text(json.dumps(data))
        result = parser.parse(tmp_run_dir)
        assert result.scalars["accuracy"] == 0.95
        assert result.scalars["loss"] == 0.32

    def test_skip_nan_inf(self, parser, tmp_run_dir):
        """Non-finite metric values never appear in the flattened output."""
        data = {
            "conditions": {
                "method": {
                    "seed_1": {"accuracy": float("nan"), "f1": 0.9},
                },
            },
        }
        (tmp_run_dir / "results.json").write_text(json.dumps(data))
        result = parser.parse(tmp_run_dir)
        flat = result.to_flat_metrics()
        # NaN should be excluded
        for k, v in flat.items():
            assert math.isfinite(v), f"Non-finite value: {k}={v}"

    def test_invalid_json_falls_through(self, parser, tmp_run_dir):
        """Malformed JSON does not raise; the parser falls back to stdout text."""
        (tmp_run_dir / "results.json").write_text("not valid json{{{")
        result = parser.parse(tmp_run_dir, stdout="metric_a: 0.5")
        # Should fallback to stdout
        assert result.source == "stdout"
# ---------------------------------------------------------------------------
# CSV parsing tests
# ---------------------------------------------------------------------------
class TestCSVParsing:
    """Parsing of results.csv files found in the run directory."""

    def test_parse_condition_csv(self, parser, tmp_run_dir):
        """Long-format condition/seed/metric/value rows become namespaced scalars."""
        csv_content = "condition,seed,metric,value\nmethod_a,42,accuracy,0.95\nmethod_b,42,accuracy,0.88\n"
        (tmp_run_dir / "results.csv").write_text(csv_content)
        result = parser.parse(tmp_run_dir)
        assert result.source == "csv"
        assert "method_a/accuracy" in result.scalars
        assert result.scalars["method_a/accuracy"] == 0.95

    def test_parse_convergence_csv(self, parser, tmp_run_dir):
        """method/h/error rows are grouped into per-method convergence series."""
        csv_content = "method,h,error\neuler,0.1,0.05\neuler,0.05,0.012\nrk4,0.1,0.001\n"
        (tmp_run_dir / "results.csv").write_text(csv_content)
        result = parser.parse(tmp_run_dir)
        assert result.source == "csv"
        assert "euler" in result.convergence
        assert len(result.convergence["euler"]) == 2

    def test_csv_skip_invalid(self, parser, tmp_run_dir):
        """Rows whose value is not numeric are skipped rather than raising."""
        csv_content = "condition,metric,value\nmethod,accuracy,not_a_number\n"
        (tmp_run_dir / "results.csv").write_text(csv_content)
        result = parser.parse(tmp_run_dir)
        assert result.source == "csv"
        assert len(result.scalars) == 0
# ---------------------------------------------------------------------------
# stdout fallback tests
# ---------------------------------------------------------------------------
class TestStdoutParsing:
    """Fallback parsing of metrics printed to stdout."""

    def test_parse_plain_metrics(self, parser, tmp_run_dir):
        """'name: value' lines become scalar metrics."""
        result = parser.parse(tmp_run_dir, stdout="accuracy: 0.95\nloss: 0.32\n")
        assert result.source == "stdout"
        assert result.scalars["accuracy"] == 0.95
        assert result.scalars["loss"] == 0.32

    def test_parse_condition_metrics(self, parser, tmp_run_dir):
        """A 'condition=<name>' prefix namespaces the metric on that line."""
        stdout = "condition=method_a accuracy: 0.95\ncondition=method_b accuracy: 0.88\n"
        result = parser.parse(tmp_run_dir, stdout=stdout)
        assert result.source == "stdout"
        assert "method_a/accuracy" in result.scalars

    def test_fallback_to_stdout_log(self, parser, tmp_run_dir):
        """Without an explicit stdout argument, a stdout.log file in the run dir is used."""
        (tmp_run_dir / "stdout.log").write_text("metric_x: 1.5\n")
        result = parser.parse(tmp_run_dir)
        assert result.source == "stdout"
        assert result.scalars.get("metric_x") == 1.5
# ---------------------------------------------------------------------------
# ExperimentResults tests
# ---------------------------------------------------------------------------
class TestExperimentResults:
    """Flattening behavior of the ExperimentResults container."""

    def test_to_flat_metrics_empty(self):
        """A default-constructed result flattens to an empty dict."""
        result = ExperimentResults()
        assert result.to_flat_metrics() == {}

    def test_to_flat_metrics_scalars(self):
        """Plain scalars pass through under their own names."""
        result = ExperimentResults(scalars={"a": 1.0, "b": 2.0})
        flat = result.to_flat_metrics()
        assert flat["a"] == 1.0
        assert flat["b"] == 2.0

    def test_to_flat_metrics_conditions(self):
        """Per-seed condition metrics surface under '<condition>/<metric>' keys."""
        result = ExperimentResults(
            conditions={
                "method": {"seed_1": {"acc": 0.9}, "seed_2": {"acc": 0.91}},
            }
        )
        flat = result.to_flat_metrics()
        assert "method/acc" in flat

    def test_to_flat_metrics_convergence(self):
        """Convergence series flatten to the final point's error value."""
        result = ExperimentResults(
            convergence={
                "euler": [
                    {"h": 0.1, "error": 0.05},
                    {"h": 0.05, "error": 0.01},
                ],
            }
        )
        flat = result.to_flat_metrics()
        assert "euler/error" in flat
        assert flat["euler/error"] == 0.01  # last point

    def test_to_flat_metrics_regression(self):
        """Regression-table cells flatten to '<spec>/<column>' keys."""
        result = ExperimentResults(
            regression_table={
                "ols": {"coeff": 0.5, "se": 0.1},
            }
        )
        flat = result.to_flat_metrics()
        assert flat["ols/coeff"] == 0.5
# ---------------------------------------------------------------------------
# Priority tests (JSON > CSV > stdout)
# ---------------------------------------------------------------------------
class TestParsePriority:
    """Source priority: JSON > CSV > stdout, with empty-source fallthrough."""

    def test_json_takes_priority_over_csv(self, parser, tmp_run_dir):
        """When both files exist, the JSON file wins."""
        (tmp_run_dir / "results.json").write_text('{"from_json": 1.0}')
        (tmp_run_dir / "results.csv").write_text("condition,metric,value\ncsv,m,2.0\n")
        result = parser.parse(tmp_run_dir)
        assert result.source == "json"

    def test_csv_takes_priority_over_stdout(self, parser, tmp_run_dir):
        """A CSV file beats stdout text."""
        (tmp_run_dir / "results.csv").write_text("condition,metric,value\ncsv,m,2.0\n")
        result = parser.parse(tmp_run_dir, stdout="stdout_metric: 3.0")
        assert result.source == "csv"

    def test_empty_json_falls_to_csv(self, parser, tmp_run_dir):
        """A JSON file with no usable metrics falls through to CSV."""
        (tmp_run_dir / "results.json").write_text("{}")
        (tmp_run_dir / "results.csv").write_text("condition,metric,value\ncsv,m,2.0\n")
        result = parser.parse(tmp_run_dir)
        assert result.source == "csv"
# ---------------------------------------------------------------------------
# MetricType enum tests
# ---------------------------------------------------------------------------
class TestMetricType:
    """String values of the MetricType enum."""

    def test_values(self):
        """Each member's value matches its expected lowercase string."""
        expected = {
            MetricType.SCALAR: "scalar",
            MetricType.TABLE: "table",
            MetricType.CONVERGENCE: "convergence",
            MetricType.STRUCTURED: "structured",
        }
        for member, value in expected.items():
            assert member.value == value
================================================
FILE: tests/test_minimax_provider.py
================================================
"""Tests for MiniMax provider integration.
Covers: provider preset, CLI registration, factory wiring,
temperature clamping, and live API integration.
"""
from __future__ import annotations
import json
import os
import urllib.request
from types import SimpleNamespace
from typing import Any, Mapping
import pytest
from researchclaw.llm import PROVIDER_PRESETS, create_llm_client
from researchclaw.llm.client import LLMClient, LLMConfig, LLMResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _DummyHTTPResponse:
"""Minimal stub for ``urllib.request.urlopen`` results."""
def __init__(self, payload: Mapping[str, Any]):
self._payload = payload
def read(self) -> bytes:
return json.dumps(self._payload).encode("utf-8")
def __enter__(self) -> _DummyHTTPResponse:
return self
def __exit__(self, *a: object) -> None:
return None
def _make_minimax_client(
    *,
    api_key: str = "test-minimax-key",
    primary_model: str = "MiniMax-M2.5",
    fallback_models: list[str] | None = None,
) -> LLMClient:
    """Build an LLMClient pointed at the MiniMax endpoint for unit tests.

    Any falsy ``fallback_models`` value is replaced with the default
    highspeed fallback, matching the production configuration.
    """
    resolved_fallbacks = fallback_models or ["MiniMax-M2.5-highspeed"]
    return LLMClient(
        LLMConfig(
            base_url="https://api.minimax.io/v1",
            api_key=api_key,
            primary_model=primary_model,
            fallback_models=resolved_fallbacks,
        )
    )
# ---------------------------------------------------------------------------
# Unit tests — provider preset
# ---------------------------------------------------------------------------
class TestMiniMaxPreset:
    """Verify MiniMax is registered in PROVIDER_PRESETS."""

    def test_minimax_in_provider_presets(self):
        """The provider key exists in the preset registry."""
        assert "minimax" in PROVIDER_PRESETS

    def test_minimax_base_url(self):
        """The preset points at the official MiniMax API endpoint."""
        preset = PROVIDER_PRESETS["minimax"]
        assert preset["base_url"] == "https://api.minimax.io/v1"
# ---------------------------------------------------------------------------
# Unit tests — from_rc_config wiring
# ---------------------------------------------------------------------------
class TestMiniMaxFromRCConfig:
    """Verify that LLMClient.from_rc_config resolves MiniMax preset."""

    def test_from_rc_config_sets_minimax_base_url(self):
        """An empty base_url resolves to the preset URL; other fields pass through."""
        rc_config = SimpleNamespace(
            llm=SimpleNamespace(
                provider="minimax",
                base_url="",
                api_key="mk-test",
                api_key_env="",
                primary_model="MiniMax-M2.5",
                fallback_models=("MiniMax-M2.5-highspeed",),
            ),
        )
        client = LLMClient.from_rc_config(rc_config)
        assert client.config.base_url == "https://api.minimax.io/v1"
        assert client.config.api_key == "mk-test"
        assert client.config.primary_model == "MiniMax-M2.5"
        # Tuple input is normalized to a list on the resulting config.
        assert client.config.fallback_models == ["MiniMax-M2.5-highspeed"]

    def test_from_rc_config_reads_minimax_api_key_from_env(self, monkeypatch):
        """With api_key empty, the key is read from the env var named by api_key_env."""
        monkeypatch.setenv("MINIMAX_API_KEY", "env-minimax-key")
        rc_config = SimpleNamespace(
            llm=SimpleNamespace(
                provider="minimax",
                base_url="",
                api_key="",
                api_key_env="MINIMAX_API_KEY",
                primary_model="MiniMax-M2.5",
                fallback_models=(),
            ),
        )
        client = LLMClient.from_rc_config(rc_config)
        assert client.config.api_key == "env-minimax-key"

    def test_from_rc_config_custom_base_url_overrides_preset(self):
        """An explicit base_url wins over the provider preset URL."""
        rc_config = SimpleNamespace(
            llm=SimpleNamespace(
                provider="minimax",
                base_url="https://custom-proxy.example/v1",
                api_key="mk-test",
                api_key_env="",
                primary_model="MiniMax-M2.5",
                fallback_models=(),
            ),
        )
        client = LLMClient.from_rc_config(rc_config)
        assert client.config.base_url == "https://custom-proxy.example/v1"
# ---------------------------------------------------------------------------
# Unit tests — temperature clamping
# ---------------------------------------------------------------------------
class TestMiniMaxTemperatureClamping:
    """MiniMax API requires temperature in [0, 1.0]."""

    def _capture_body(
        self,
        monkeypatch: pytest.MonkeyPatch,
        client: LLMClient,
        temperature: float,
    ) -> dict[str, Any]:
        """Issue one _raw_call with urlopen stubbed out; return the JSON request body."""
        captured: dict[str, Any] = {}

        def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
            # Record the serialized body instead of hitting the network.
            captured["body"] = json.loads(req.data.decode("utf-8"))
            return _DummyHTTPResponse(
                {"choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}]}
            )

        monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
        client._raw_call(
            "MiniMax-M2.5",
            [{"role": "user", "content": "hi"}],
            1024,
            temperature,
            False,
        )
        return captured["body"]

    def test_temperature_above_one_clamped(self, monkeypatch):
        """Values above 1.0 are clamped down to the MiniMax maximum."""
        client = _make_minimax_client()
        body = self._capture_body(monkeypatch, client, 1.5)
        assert body["temperature"] == 1.0

    def test_temperature_within_range_unchanged(self, monkeypatch):
        """In-range values are forwarded untouched."""
        client = _make_minimax_client()
        body = self._capture_body(monkeypatch, client, 0.7)
        assert body["temperature"] == 0.7

    def test_temperature_zero_allowed(self, monkeypatch):
        """Zero is a valid lower bound and must not be altered."""
        client = _make_minimax_client()
        body = self._capture_body(monkeypatch, client, 0.0)
        assert body["temperature"] == 0.0

    def test_temperature_negative_clamped_to_zero(self, monkeypatch):
        """Negative values are clamped up to zero."""
        client = _make_minimax_client()
        body = self._capture_body(monkeypatch, client, -0.1)
        assert body["temperature"] == 0.0

    def test_non_minimax_url_no_clamping(self, monkeypatch):
        """Non-MiniMax URLs should not clamp temperature."""
        config = LLMConfig(
            base_url="https://api.openai.com/v1",
            api_key="test-key",
            primary_model="gpt-4o",
        )
        client = LLMClient(config)
        captured: dict[str, Any] = {}

        def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
            captured["body"] = json.loads(req.data.decode("utf-8"))
            return _DummyHTTPResponse(
                {"choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}]}
            )

        monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
        client._raw_call("gpt-4o", [{"role": "user", "content": "hi"}], 1024, 1.5, False)
        assert captured["body"]["temperature"] == 1.5  # no clamping
# ---------------------------------------------------------------------------
# Unit tests — model chain
# ---------------------------------------------------------------------------
class TestMiniMaxModelChain:
    """Model fallback chain for MiniMax."""

    def test_model_chain_default(self):
        """Primary model comes first, followed by the default fallback."""
        chain = _make_minimax_client()._model_chain
        assert chain == ["MiniMax-M2.5", "MiniMax-M2.5-highspeed"]

    def test_model_chain_custom_fallbacks(self):
        """Custom fallbacks keep their configured order after the primary."""
        chain = _make_minimax_client(
            primary_model="MiniMax-M2.7",
            fallback_models=["MiniMax-M2.5", "MiniMax-M2.5-highspeed"],
        )._model_chain
        assert chain == ["MiniMax-M2.7", "MiniMax-M2.5", "MiniMax-M2.5-highspeed"]
# ---------------------------------------------------------------------------
# Unit tests — raw call body structure
# ---------------------------------------------------------------------------
class TestMiniMaxRawCall:
    """Verify request body sent to MiniMax API."""

    def test_request_body_structure(self, monkeypatch):
        """The POST targets /chat/completions with model, temperature, and bearer auth."""
        client = _make_minimax_client()
        captured: dict[str, Any] = {}

        def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
            # Capture URL, body, and headers for inspection; never hit the network.
            captured["url"] = req.full_url
            captured["body"] = json.loads(req.data.decode("utf-8"))
            captured["headers"] = {k.lower(): v for k, v in req.headers.items()}
            return _DummyHTTPResponse(
                {
                    "model": "MiniMax-M2.5",
                    "choices": [{"message": {"content": "pong"}, "finish_reason": "stop"}],
                    "usage": {"prompt_tokens": 5, "completion_tokens": 1, "total_tokens": 6},
                }
            )

        monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
        resp = client._raw_call(
            "MiniMax-M2.5",
            [{"role": "user", "content": "ping"}],
            1024,
            0.5,
            False,
        )
        assert captured["url"] == "https://api.minimax.io/v1/chat/completions"
        assert captured["body"]["model"] == "MiniMax-M2.5"
        assert captured["body"]["temperature"] == 0.5
        assert captured["headers"]["authorization"] == "Bearer test-minimax-key"
        assert resp.content == "pong"
        assert resp.model == "MiniMax-M2.5"

    def test_json_mode_adds_response_format(self, monkeypatch):
        """json_mode=True adds the json_object response_format field to the body."""
        client = _make_minimax_client()
        captured: dict[str, Any] = {}

        def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
            captured["body"] = json.loads(req.data.decode("utf-8"))
            return _DummyHTTPResponse(
                {"choices": [{"message": {"content": "{}"}, "finish_reason": "stop"}]}
            )

        monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
        client._raw_call(
            "MiniMax-M2.5",
            [{"role": "user", "content": "json"}],
            1024,
            0.5,
            True,
        )
        assert captured["body"]["response_format"] == {"type": "json_object"}
# ---------------------------------------------------------------------------
# Unit tests — CLI provider registration
# ---------------------------------------------------------------------------
class TestMiniMaxCLI:
    """Verify MiniMax is in the CLI interactive provider menu."""

    def test_minimax_in_provider_choices(self):
        """At least one menu entry maps to the minimax provider."""
        from researchclaw.cli import _PROVIDER_CHOICES
        providers = [choice[0] for choice in _PROVIDER_CHOICES.values()]
        assert "minimax" in providers, "minimax not found in _PROVIDER_CHOICES"

    def test_minimax_in_provider_urls(self):
        """The CLI URL table matches the official endpoint."""
        from researchclaw.cli import _PROVIDER_URLS
        assert _PROVIDER_URLS["minimax"] == "https://api.minimax.io/v1"

    def test_minimax_in_provider_models(self):
        """Default primary/fallback models are registered for minimax."""
        from researchclaw.cli import _PROVIDER_MODELS
        default_primary, default_fallbacks = _PROVIDER_MODELS["minimax"]
        assert default_primary == "MiniMax-M2.5"
        assert "MiniMax-M2.5-highspeed" in default_fallbacks
# ---------------------------------------------------------------------------
# Unit tests — factory function
# ---------------------------------------------------------------------------
class TestMiniMaxFactory:
    """Verify create_llm_client dispatches correctly for MiniMax."""

    def test_create_llm_client_returns_llm_client(self):
        """The factory resolves the minimax preset and returns a plain LLMClient."""
        # NOTE: the previous unused `from researchclaw.config import LlmConfig,
        # RCConfig` import was removed — neither name was referenced.
        rc_config = SimpleNamespace(
            llm=SimpleNamespace(
                provider="minimax",
                base_url="",
                api_key="mk-factory-test",
                api_key_env="",
                primary_model="MiniMax-M2.5",
                fallback_models=(),
            ),
        )
        client = create_llm_client(rc_config)
        assert isinstance(client, LLMClient)
        assert client.config.base_url == "https://api.minimax.io/v1"
        assert client._anthropic is None  # Not anthropic
# ---------------------------------------------------------------------------
# Unit tests — chat fallback with MiniMax models
# ---------------------------------------------------------------------------
class TestMiniMaxChatFallback:
    """Verify fallback works with MiniMax models."""

    def test_fallback_to_highspeed_on_primary_failure(self, monkeypatch):
        """When the primary model raises, chat() retries with the fallback model."""
        client = _make_minimax_client()
        calls: list[str] = []

        def fake_call_with_retry(
            self,
            model: str,
            messages: list[dict[str, str]],
            max_tokens: int,
            temperature: float,
            json_mode: bool,
        ) -> LLMResponse:
            # Fail only for the primary model to force the fallback path.
            calls.append(model)
            if model == "MiniMax-M2.5":
                raise RuntimeError("rate limited")
            return LLMResponse(content="ok", model=model)

        monkeypatch.setattr(LLMClient, "_call_with_retry", fake_call_with_retry)
        resp = client.chat([{"role": "user", "content": "test"}])
        assert calls == ["MiniMax-M2.5", "MiniMax-M2.5-highspeed"]
        assert resp.model == "MiniMax-M2.5-highspeed"
# ---------------------------------------------------------------------------
# Integration tests — live MiniMax API (skipped without key)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(
    not os.environ.get("MINIMAX_API_KEY"),
    reason="MINIMAX_API_KEY not set",
)
class TestMiniMaxLiveAPI:
    """Integration tests against the real MiniMax API."""

    def _live_client(self) -> LLMClient:
        """Build a client from the MINIMAX_API_KEY env var with small token limits."""
        return LLMClient(
            LLMConfig(
                base_url="https://api.minimax.io/v1",
                api_key=os.environ["MINIMAX_API_KEY"],
                primary_model="MiniMax-M2.5",
                fallback_models=["MiniMax-M2.5-highspeed"],
                max_tokens=64,
                timeout_sec=60,
            )
        )

    def test_simple_chat_completion(self):
        """A trivial prompt returns non-empty text containing 'hello'."""
        client = self._live_client()
        resp = client.chat(
            [{"role": "user", "content": "Say 'hello' and nothing else."}],
            max_tokens=16,
            temperature=0.1,
        )
        assert resp.content.strip(), "empty response"
        assert "hello" in resp.content.lower()

    def test_json_mode(self):
        """json_mode responses parse as JSON, optionally after stripping code fences."""
        client = self._live_client()
        resp = client.chat(
            [
                {"role": "system", "content": "You are a helpful assistant that responds in JSON."},
                {"role": "user", "content": 'Return a JSON object with key "status" set to "ok".'},
            ],
            max_tokens=128,
            temperature=0.1,
            json_mode=True,
            strip_thinking=True,
        )
        # MiniMax M2.5 may wrap JSON in markdown code fences
        import re
        text = resp.content.strip()
        fence_match = re.search(r"```(?:json)?\s*\n(.*?)```", text, re.DOTALL)
        if fence_match:
            text = fence_match.group(1).strip()
        parsed = json.loads(text)
        assert "status" in parsed

    def test_preflight_check(self):
        """preflight() reports success against the live endpoint."""
        client = self._live_client()
        ok, msg = client.preflight()
        assert ok, f"preflight failed: {msg}"
================================================
FILE: tests/test_neuroscience_domain.py
================================================
"""Tests for computational neuroscience domain support.
Covers profile loading, keyword detection, adapter dispatch, and
prompt block generation for neuroscience_computational and
neuroscience_imaging domains.
"""
from __future__ import annotations
import pytest
from researchclaw.domains.detector import (
DomainProfile,
detect_domain,
detect_domain_id,
get_profile,
_keyword_detect,
_profile_cache,
)
from researchclaw.domains.prompt_adapter import (
MLPromptAdapter,
PromptBlocks,
get_adapter,
)
# ---------------------------------------------------------------------------
# Profile loading
# ---------------------------------------------------------------------------
class TestNeuroscienceProfiles:
    """Loading and field checks for the two neuroscience domain profiles."""

    def setup_method(self):
        # Profiles are cached module-wide; clear before each test for isolation.
        _profile_cache.clear()

    def test_computational_profile_exists(self):
        """The computational profile loads with its id and display name."""
        profile = get_profile("neuroscience_computational")
        assert profile is not None
        assert profile.domain_id == "neuroscience_computational"
        assert profile.display_name == "Computational Neuroscience"

    def test_computational_profile_fields(self):
        """Paradigm, core libraries, and GPU flag match the profile definition."""
        profile = get_profile("neuroscience_computational")
        assert profile is not None
        assert profile.experiment_paradigm == "simulation"
        assert "brian2" in profile.core_libraries
        assert "numpy" in profile.core_libraries
        assert profile.gpu_required is False

    def test_computational_profile_baselines(self):
        """At least two standard baselines exist, including an integrate-and-fire model."""
        profile = get_profile("neuroscience_computational")
        assert profile is not None
        assert len(profile.standard_baselines) >= 2
        assert any("LIF" in b or "Integrate-and-Fire" in b
                   for b in profile.standard_baselines)

    def test_imaging_profile_exists(self):
        """The imaging profile loads with its id and display name."""
        profile = get_profile("neuroscience_imaging")
        assert profile is not None
        assert profile.domain_id == "neuroscience_imaging"
        assert profile.display_name == "Brain Imaging Analysis"

    def test_imaging_profile_fields(self):
        """Imaging uses the comparison paradigm with nilearn and mne libraries."""
        profile = get_profile("neuroscience_imaging")
        assert profile is not None
        assert profile.experiment_paradigm == "comparison"
        assert "nilearn" in profile.core_libraries
        assert "mne" in profile.core_libraries
# ---------------------------------------------------------------------------
# Keyword detection
# ---------------------------------------------------------------------------
class TestNeuroscienceKeywordDetection:
    """Keyword routing of free-text topics to neuroscience domain ids."""

    # --- computational neuroscience keywords -------------------------------
    def test_spiking_network(self):
        assert _keyword_detect("spiking neural model of cortical columns") == "neuroscience_computational"

    def test_brian2(self):
        assert _keyword_detect("network model implemented in brian2") == "neuroscience_computational"

    def test_hodgkin_huxley(self):
        assert _keyword_detect("Hodgkin-Huxley neuron model") == "neuroscience_computational"

    def test_integrate_and_fire(self):
        assert _keyword_detect("leaky integrate-and-fire model") == "neuroscience_computational"

    def test_izhikevich(self):
        assert _keyword_detect("Izhikevich neuron dynamics") == "neuroscience_computational"

    def test_neural_decoding(self):
        assert _keyword_detect("neural decoding of population coding in cortex") == "neuroscience_computational"

    def test_firing_rate(self):
        assert _keyword_detect("firing rate analysis of cortical neurons") == "neuroscience_computational"

    # --- brain imaging keywords --------------------------------------------
    def test_fmri(self):
        assert _keyword_detect("fmri resting state analysis") == "neuroscience_imaging"

    def test_eeg(self):
        assert _keyword_detect("EEG classification for BCI") == "neuroscience_imaging"

    def test_nilearn(self):
        assert _keyword_detect("brain parcellation with nilearn") == "neuroscience_imaging"

    def test_mne_python(self):
        assert _keyword_detect("ERP analysis using mne-python") == "neuroscience_imaging"

    # --- fallback and public entry points -----------------------------------
    def test_generic_neuroscience(self):
        # Generic neuroscience phrasing defaults to the computational domain.
        result = _keyword_detect("neuroscience of learning and memory")
        assert result == "neuroscience_computational"

    def test_detect_domain_integration(self):
        # detect_domain returns the full profile object, not just the id.
        profile = detect_domain("brian2 spiking neural model of cortical microcircuits")
        assert profile.domain_id == "neuroscience_computational"

    def test_detect_domain_id_shortcut(self):
        # detect_domain_id is the id-only shortcut over the same detection path.
        domain_id = detect_domain_id("brian2 leaky integrate-and-fire cortical model")
        assert domain_id == "neuroscience_computational"
# ---------------------------------------------------------------------------
# Adapter dispatch
# ---------------------------------------------------------------------------
class TestNeuroscienceAdapter:
    """Adapter dispatch and prompt-block generation for neuroscience domains."""

    def test_computational_gets_neuroscience_adapter(self):
        """The computational profile dispatches to NeurosciencePromptAdapter, not the ML default."""
        profile = get_profile("neuroscience_computational")
        if profile is None:
            pytest.skip("neuroscience_computational profile not found")
        adapter = get_adapter(profile)
        assert not isinstance(adapter, MLPromptAdapter)
        from researchclaw.domains.adapters.neuroscience import (
            NeurosciencePromptAdapter,
        )
        assert isinstance(adapter, NeurosciencePromptAdapter)

    def test_imaging_gets_neuroscience_adapter(self):
        """The imaging profile also avoids the default ML adapter."""
        profile = get_profile("neuroscience_imaging")
        if profile is None:
            pytest.skip("neuroscience_imaging profile not found")
        adapter = get_adapter(profile)
        assert not isinstance(adapter, MLPromptAdapter)

    def test_code_generation_blocks_nonempty(self):
        """Code-generation prompt blocks carry hints, dataset, and output guidance."""
        profile = get_profile("neuroscience_computational")
        if profile is None:
            pytest.skip("neuroscience_computational profile not found")
        adapter = get_adapter(profile)
        blocks = adapter.get_code_generation_blocks({})
        assert blocks.code_generation_hints
        assert blocks.dataset_guidance
        assert blocks.output_format_guidance

    def test_experiment_design_blocks(self):
        """Experiment-design context mentions the domain and includes stats guidance."""
        profile = get_profile("neuroscience_computational")
        if profile is None:
            pytest.skip("neuroscience_computational profile not found")
        adapter = get_adapter(profile)
        blocks = adapter.get_experiment_design_blocks({})
        assert "neuroscience" in blocks.experiment_design_context.lower() or \
            "Computational Neuroscience" in blocks.experiment_design_context
        assert blocks.statistical_test_guidance

    def test_result_analysis_blocks(self):
        """Result-analysis hints reference firing-rate analysis."""
        profile = get_profile("neuroscience_computational")
        if profile is None:
            pytest.skip("neuroscience_computational profile not found")
        adapter = get_adapter(profile)
        blocks = adapter.get_result_analysis_blocks({})
        assert "firing rate" in blocks.result_analysis_hints.lower()

    def test_blueprint_context(self):
        """Blueprint context reflects the profile's file structure and libraries."""
        profile = get_profile("neuroscience_computational")
        if profile is None:
            pytest.skip("neuroscience_computational profile not found")
        adapter = get_adapter(profile)
        ctx = adapter.get_blueprint_context()
        # Should include file structure and libraries from the profile
        if profile.typical_file_structure:
            assert "network.py" in ctx or "neuron.py" in ctx
        if profile.core_libraries:
            assert "brian2" in ctx or "numpy" in ctx
================================================
FILE: tests/test_opencode_bridge.py
================================================
"""Tests for OpenCode Beast Mode bridge."""
from __future__ import annotations
import json
import subprocess
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.config import OpenCodeConfig, _parse_opencode_config
from researchclaw.pipeline.opencode_bridge import (
ComplexityScore,
OpenCodeBridge,
OpenCodeResult,
count_historical_failures,
score_complexity,
)
# ============================================================
# TestComplexityScorer
# ============================================================
class TestComplexityScorer:
    """Tests for complexity scoring logic."""

    def test_low_complexity_simple_classification(self):
        """A short single-model plan scores low and routes to code_agent."""
        plan = (
            "Train a ResNet-18 on CIFAR-10 with SGD optimizer.\n"
            "Report test accuracy as the primary metric.\n"
            "condition_0: baseline (lr=0.1)\n"
            "condition_1: ablation (lr=0.01)\n"
        )
        result = score_complexity(plan, topic="Image classification on CIFAR-10")
        assert result.score < 0.4
        assert result.recommendation == "code_agent"

    def test_high_complexity_multimodal_gan(self):
        """A multi-component, multi-file, multi-condition plan routes to beast_mode."""
        plan = (
            "Implement a vision-language GAN with the following components:\n"
            "- Encoder: ViT-based image encoder\n"
            "- Decoder: Transformer text decoder\n"
            "- Generator: produces synthetic image-text pairs\n"
            "- Discriminator: classifies real vs fake\n"
            "- Critic: provides auxiliary reward signal\n"
            "Multiple files needed: model.py, trainer.py, dataset.py\n"
            "condition_0: baseline\n"
            "condition_1: ablation without critic\n"
            "condition_2: ablation without encoder pretraining\n"
            "condition_3: ablation with reduced generator\n"
            "Custom loss function and custom layer for cross-modal attention.\n"
        )
        result = score_complexity(
            plan, topic="Multi-modal GAN for vision-language synthesis"
        )
        assert result.score > 0.6
        assert result.recommendation == "beast_mode"

    def test_historical_failures_boost_score(self):
        """Past failures raise the score via the historical_failure signal."""
        plan = (
            "Train a simple model with encoder and decoder.\n"
            "condition_0: baseline\n"
        )
        score_without = score_complexity(plan, topic="test", historical_failures=0)
        score_with = score_complexity(plan, topic="test", historical_failures=3)
        assert score_with.score > score_without.score
        assert score_with.signals["historical_failure"] > 0

    def test_empty_plan_returns_zero(self):
        """An empty plan short-circuits to score 0 with the legacy recommendation."""
        result = score_complexity("", topic="")
        assert result.score == 0.0
        assert result.recommendation == "legacy"
        assert result.reason == "Empty plan"

    def test_threshold_boundary(self):
        """A plan scoring exactly at threshold should recommend beast_mode."""
        plan = (
            "Multi-modal diffusion model with encoder, decoder, discriminator.\n"
            "Custom loss, custom layer, wrapper pattern.\n"
            "model.py, trainer.py needed.\n"
        )
        # Use a low threshold to ensure it triggers
        result = score_complexity(plan, topic="Diffusion model", threshold=0.2)
        assert result.recommendation == "beast_mode"
        # Use a very high threshold to ensure it doesn't trigger
        result2 = score_complexity(plan, topic="Diffusion model", threshold=0.99)
        assert result2.recommendation == "code_agent"

    def test_signals_all_present(self):
        """Every scoring signal key is reported even for trivial input."""
        result = score_complexity("some plan", topic="some topic")
        expected_keys = {
            "component_count",
            "file_count_hint",
            "domain_complexity",
            "condition_count",
            "historical_failure",
            "dependency_depth",
        }
        assert set(result.signals.keys()) == expected_keys

    def test_score_clamped_to_unit_interval(self):
        """Score should never exceed 1.0 even with extreme inputs."""
        plan = " ".join(
            ["encoder decoder discriminator generator critic actor teacher student"] * 10
            + ["model.py trainer.py dataset.py multiple files modular"] * 10
            + ["multi-modal distributed GAN diffusion NeRF MoE meta-learning"] * 10
            + ["condition_1 condition_2 condition_3 ablation_4 variant_5 baseline"] * 10
            + ["custom layer custom loss wrapper registry hook callback"] * 10
        )
        result = score_complexity(plan, topic="everything", historical_failures=100)
        assert 0.0 <= result.score <= 1.0

    def test_domain_complexity_keywords(self):
        """Specialized domain keywords contribute a positive domain_complexity signal."""
        plan = "Implement a physics-informed neural network (PINN) with neural ODE solver."
        result = score_complexity(plan, topic="PINN for fluid dynamics")
        assert result.signals["domain_complexity"] > 0
# ============================================================
# TestOpenCodeBridge
# ============================================================
class TestOpenCodeBridge:
    """Tests for the OpenCode bridge class.

    Covers: CLI availability probing, workspace preparation and
    ``opencode.json`` provider-config generation, model-name resolution,
    output-file collection, and the top-level ``generate`` flow with the
    actual CLI invocation mocked out.
    """

    def test_check_available_returns_false_when_not_installed(self):
        # shutil.which returning None means the opencode CLI is absent.
        with patch(
            "researchclaw.pipeline.opencode_bridge.shutil.which",
            return_value=None,
        ):
            assert OpenCodeBridge.check_available() is False

    def test_check_available_returns_false_on_timeout(self):
        # A hanging version probe (TimeoutExpired) is treated as unavailable
        # rather than propagating the exception.
        with patch(
            "researchclaw.pipeline.opencode_bridge.shutil.which",
            return_value=r"C:\Users\tester\AppData\Roaming\npm\opencode.cmd",
        ), patch(
            "researchclaw.pipeline.opencode_bridge.subprocess.run",
            side_effect=subprocess.TimeoutExpired(cmd="opencode", timeout=15),
        ):
            assert OpenCodeBridge.check_available() is False

    def test_check_available_returns_true(self):
        mock_result = MagicMock()
        mock_result.returncode = 0
        with patch(
            "researchclaw.pipeline.opencode_bridge.shutil.which",
            return_value=r"C:\Users\tester\AppData\Roaming\npm\opencode.cmd",
        ), patch(
            "researchclaw.pipeline.opencode_bridge.subprocess.run",
            return_value=mock_result,
        ) as run_mock:
            assert OpenCodeBridge.check_available() is True
            # The resolved absolute path (not the bare "opencode" name) must
            # be what is actually executed — matters on Windows where the
            # entry point is a .cmd shim.
            assert run_mock.call_args.args[0][0].endswith("opencode.cmd")

    def test_workspace_creates_correct_files(self, tmp_path):
        bridge = OpenCodeBridge(
            model="gpt-5.2",
            llm_base_url="https://example.com",
            api_key_env="TEST_KEY",
        )
        ws = bridge._prepare_workspace(
            stage_dir=tmp_path,
            topic="Test topic",
            exp_plan="plan: test",
            metric="accuracy",
            pkg_hint="torch available",
            extra_guidance="Be careful",
            time_budget_sec=300,
        )
        # The workspace must contain the plan, the guidance doc, and the
        # generated opencode.json provider config.
        assert (ws / "EXPERIMENT_PLAN.yaml").exists()
        assert (ws / "GUIDANCE.md").exists()
        assert (ws / "opencode.json").exists()
        # Topic and metric are surfaced verbatim in the guidance document.
        guidance = (ws / "GUIDANCE.md").read_text()
        assert "Test topic" in guidance
        assert "accuracy" in guidance

    def test_opencode_config_azure_format(self, tmp_path):
        bridge = OpenCodeBridge(
            model="gpt-5.2",
            llm_base_url="https://huaxi.openai.azure.com/openai/v1",
            api_key_env="AZURE_OPENAI_API_KEY",
            llm_provider="azure",
        )
        ws = bridge._prepare_workspace(
            stage_dir=tmp_path,
            topic="t",
            exp_plan="p",
            metric="m",
            pkg_hint="",
            extra_guidance="",
            time_budget_sec=300,
        )
        cfg = json.loads((ws / "opencode.json").read_text())
        # Azure now uses the unified "openai" provider (Bearer token auth
        # works on Azure endpoints and Responses API is supported)
        assert cfg["model"] == "openai/gpt-5.2"
        assert "provider" in cfg
        assert "openai" in cfg["provider"]
        # The Azure base URL is passed through untouched, and the API key is
        # referenced via OpenCode's {env:...} substitution syntax.
        assert cfg["provider"]["openai"]["options"]["baseURL"] == "https://huaxi.openai.azure.com/openai/v1"
        assert "{env:AZURE_OPENAI_API_KEY}" in cfg["provider"]["openai"]["options"]["apiKey"]

    def test_opencode_config_openai_format(self, tmp_path):
        bridge = OpenCodeBridge(
            model="gpt-4o",
            llm_base_url="https://api.openai.com/v1",
            api_key_env="OPENAI_API_KEY",
        )
        ws = bridge._prepare_workspace(
            stage_dir=tmp_path,
            topic="t",
            exp_plan="p",
            metric="m",
            pkg_hint="",
            extra_guidance="",
            time_budget_sec=300,
        )
        cfg = json.loads((ws / "opencode.json").read_text())
        # Plain OpenAI endpoint: model gets the "openai/" prefix.
        assert cfg["model"] == "openai/gpt-4o"
        assert "openai" in cfg["provider"]

    def test_opencode_config_preserves_prefixed_model(self, tmp_path):
        """Model with '/' prefix (e.g. anthropic/...) should NOT get double-prefixed (BUG-C fix)."""
        bridge = OpenCodeBridge(
            model="anthropic/claude-sonnet-4-6",
            llm_base_url="https://huaxi.openai.azure.com/openai/v1",
            api_key_env="AZURE_API_KEY",
            llm_provider="azure",
        )
        ws = bridge._prepare_workspace(
            stage_dir=tmp_path,
            topic="t",
            exp_plan="p",
            metric="m",
            pkg_hint="",
            extra_guidance="",
            time_budget_sec=300,
        )
        cfg = json.loads((ws / "opencode.json").read_text())
        # Should be "anthropic/claude-sonnet-4-6", NOT "azure/anthropic/claude-sonnet-4-6"
        assert cfg["model"] == "anthropic/claude-sonnet-4-6"

    def test_resolve_model_azure_uses_openai_prefix(self):
        """Azure endpoint → uses openai/ prefix (Azure supports Responses API now)."""
        bridge = OpenCodeBridge(
            model="gpt-5.2",
            llm_base_url="https://huaxi.openai.azure.com/openai/v1",
            llm_provider="azure",
        )
        resolved = bridge._resolve_opencode_model()
        assert resolved == "openai/gpt-5.2"

    def test_resolve_model_preserves_explicit_prefix(self):
        """Model with '/' prefix should be used as-is regardless of provider."""
        bridge = OpenCodeBridge(
            model="anthropic/claude-sonnet-4-6",
            llm_base_url="https://huaxi.openai.azure.com/openai/v1",
            llm_provider="azure",
        )
        resolved = bridge._resolve_opencode_model()
        assert resolved == "anthropic/claude-sonnet-4-6"

    def test_resolve_model_no_model_default(self):
        """Empty model string → default Anthropic model."""
        bridge = OpenCodeBridge()
        assert bridge._resolve_opencode_model() == "anthropic/claude-sonnet-4-6"

    def test_collect_files_ignores_pycache(self, tmp_path):
        (tmp_path / "main.py").write_text("print('hello')")
        pycache = tmp_path / "__pycache__"
        pycache.mkdir()
        (pycache / "main.cpython-311.pyc").write_text("bytecode")
        # Also write a .py in pycache to test filtering
        (pycache / "cached.py").write_text("cached")
        files = OpenCodeBridge._collect_files(tmp_path)
        assert "main.py" in files
        # Nothing under __pycache__ may leak into the collected set, even
        # files with an otherwise-allowed .py extension.
        assert not any("__pycache__" in k for k in files)

    def test_collect_files_includes_requirements(self, tmp_path):
        (tmp_path / "main.py").write_text("import torch")
        (tmp_path / "requirements.txt").write_text("torch>=2.0")
        files = OpenCodeBridge._collect_files(tmp_path)
        assert "requirements.txt" in files
        assert "main.py" in files

    def test_collect_files_flattens_subdirectories(self, tmp_path):
        """Files in subdirs should be flattened to basenames (BUG-D fix)."""
        src = tmp_path / "src"
        src.mkdir()
        (src / "model.py").write_text("class Model: pass")
        (src / "utils.py").write_text("def helper(): pass")
        (tmp_path / "main.py").write_text("from model import Model")
        files = OpenCodeBridge._collect_files(tmp_path)
        # Keys should be flat basenames, not paths like "src/model.py"
        assert "model.py" in files
        assert "utils.py" in files
        assert "main.py" in files
        assert not any("/" in k for k in files)

    def test_collect_files_root_takes_priority_over_subdir(self, tmp_path):
        """Root-level file wins when basename collides with subdir file."""
        (tmp_path / "main.py").write_text("root version")
        sub = tmp_path / "src"
        sub.mkdir()
        (sub / "main.py").write_text("subdir version")
        files = OpenCodeBridge._collect_files(tmp_path)
        assert files["main.py"] == "root version"

    def test_generate_returns_error_on_not_installed(self, tmp_path):
        # generate() must fail gracefully (no exception) when the CLI is
        # missing, reporting the cause in result.error.
        bridge = OpenCodeBridge()
        with patch.object(OpenCodeBridge, "check_available", return_value=False):
            result = bridge.generate(
                stage_dir=tmp_path,
                topic="test",
                exp_plan="plan",
                metric="acc",
            )
            assert not result.success
            assert "not installed" in result.error

    def test_generate_returns_error_on_cli_failure(self, tmp_path):
        # With zero retries, a single failed CLI invocation surfaces as an
        # unsuccessful result whose error mentions the failure.
        bridge = OpenCodeBridge(max_retries=0, workspace_cleanup=True)
        with patch.object(OpenCodeBridge, "check_available", return_value=True), \
                patch.object(
                    bridge,
                    "_invoke_opencode",
                    return_value=(False, "CLI error", 1.5),
                ):
            result = bridge.generate(
                stage_dir=tmp_path,
                topic="test",
                exp_plan="plan",
                metric="acc",
            )
            assert not result.success
            assert "failed" in result.error.lower()

    def test_generate_success(self, tmp_path):
        bridge = OpenCodeBridge(max_retries=0, workspace_cleanup=False)

        def fake_invoke(workspace, prompt):
            # Write main.py into the workspace to simulate OpenCode output
            (workspace / "main.py").write_text("print('acc: 0.95')")
            (workspace / "requirements.txt").write_text("torch")
            return True, "success", 5.0

        with patch.object(OpenCodeBridge, "check_available", return_value=True), \
                patch.object(bridge, "_invoke_opencode", side_effect=fake_invoke):
            result = bridge.generate(
                stage_dir=tmp_path,
                topic="test",
                exp_plan="plan",
                metric="acc",
            )
            # Files written by the (mocked) CLI are collected into the
            # result, and the invocation's elapsed time is passed through.
            assert result.success
            assert "main.py" in result.files
            assert result.elapsed_sec == 5.0

    def test_invoke_opencode_uses_resolved_path(self, tmp_path):
        bridge = OpenCodeBridge(model="gpt-5.2", timeout_sec=10)
        mock_result = MagicMock()
        mock_result.returncode = 0
        mock_result.stdout = "{}"
        mock_result.stderr = ""
        with patch(
            "researchclaw.pipeline.opencode_bridge.shutil.which",
            return_value=r"C:\Users\tester\AppData\Roaming\npm\opencode.cmd",
        ), patch(
            "researchclaw.pipeline.opencode_bridge.subprocess.run",
            return_value=mock_result,
        ) as run_mock:
            success, _log, _elapsed = bridge._invoke_opencode(tmp_path, "test prompt")
            assert success is True
            # Same Windows concern as check_available: the resolved .cmd
            # path must be argv[0] of the spawned process.
            assert run_mock.call_args.args[0][0].endswith("opencode.cmd")
# ============================================================
# TestEnsureMainEntryPoint (BUG-R52-01)
# ============================================================
class TestHasMainGuard:
    """Behavioural checks for the ``_has_main_guard`` static helper."""

    def test_empty(self):
        assert OpenCodeBridge._has_main_guard("") is False

    def test_syntax_error(self):
        # Unparsable source must be reported as guard-less, not raise.
        assert OpenCodeBridge._has_main_guard("def broken(") is False

    def test_without_guard(self):
        snippet = "def main():\n pass\n"
        assert OpenCodeBridge._has_main_guard(snippet) is False

    def test_with_guard(self):
        # Canonical double-quoted guard form is recognised.
        snippet = 'def main():\n pass\n\nif __name__ == "__main__":\n main()\n'
        assert OpenCodeBridge._has_main_guard(snippet) is True

    def test_single_quote_guard(self):
        # Single-quoted "__main__" is equally valid Python and must match.
        snippet = "if __name__ == '__main__':\n print('hi')\n"
        assert OpenCodeBridge._has_main_guard(snippet) is True
class TestEnsureMainEntryPoint:
    """Tests for _ensure_main_entry_point — BUG-R52-01 fix.

    The helper guarantees the generated file set has a runnable main.py:
    it may swap main.py's contents with another .py file that already has
    a ``__main__`` guard, or inject a guard that calls a detected entry
    function (main(), run*(), etc.). Non-.py files are never inspected.
    """

    def test_already_has_guard_unchanged(self):
        files = {
            "main.py": 'def run():\n pass\n\nif __name__ == "__main__":\n run()\n',
            "utils.py": "def helper(): pass\n",
        }
        result = OpenCodeBridge._ensure_main_entry_point(files)
        assert result is files  # Same object, unchanged

    def test_no_main_py_unchanged(self):
        # Without a main.py there is nothing to repair — input passes through.
        files = {"utils.py": "def helper(): pass\n"}
        result = OpenCodeBridge._ensure_main_entry_point(files)
        assert result is files

    def test_swap_entry_point_from_other_file(self):
        """When main.py is library-only and another file has __main__, swap."""
        lib_code = "class Model:\n pass\n\ndef train(model):\n pass\n"
        entry_code = (
            'from main import Model, train\n\n'
            'if __name__ == "__main__":\n'
            ' m = Model()\n'
            ' train(m)\n'
        )
        files = {
            "main.py": lib_code,
            "run_experiment.py": entry_code,
        }
        result = OpenCodeBridge._ensure_main_entry_point(files)
        # main.py should now contain the entry point code
        assert '__main__' in result["main.py"]
        # The old main.py content should be in run_experiment.py
        assert result["run_experiment.py"] == lib_code

    def test_inject_entry_for_main_function(self):
        """When main.py defines main() but no guard, inject one."""
        code = "import torch\n\ndef main():\n print('training')\n"
        files = {"main.py": code}
        result = OpenCodeBridge._ensure_main_entry_point(files)
        assert '__main__' in result["main.py"]
        assert "main()" in result["main.py"]

    def test_inject_entry_for_run_function(self):
        """Should also detect run(), train(), etc."""
        code = "def run_experiment():\n print('running')\n"
        files = {"main.py": code}
        result = OpenCodeBridge._ensure_main_entry_point(files)
        assert '__main__' in result["main.py"]
        assert "run_experiment()" in result["main.py"]

    def test_no_known_entry_function_warns(self):
        """When no known entry function exists, return unchanged with warning."""
        code = "class Config:\n x = 1\n\nclass Trainer:\n pass\n"
        files = {"main.py": code}
        result = OpenCodeBridge._ensure_main_entry_point(files)
        # Should return unchanged since no entry function found
        assert result["main.py"] == code

    def test_non_py_files_not_checked(self):
        """requirements.txt and setup.py should not be checked for __main__."""
        lib_code = "class Model:\n pass\n"
        files = {
            "main.py": lib_code,
            "requirements.txt": "torch>=2.0\n",
            "setup.py": "# setup\n",
        }
        result = OpenCodeBridge._ensure_main_entry_point(files)
        # No swap should occur — only .py files are checked
        assert result["main.py"] == lib_code

    def test_swap_preserves_other_files(self):
        """Swapping should not lose any files from the dict."""
        files = {
            "main.py": "class Lib: pass\n",
            "run.py": 'if __name__ == "__main__":\n print("go")\n',
            "utils.py": "def helper(): pass\n",
            "requirements.txt": "numpy\n",
        }
        result = OpenCodeBridge._ensure_main_entry_point(files)
        assert len(result) == len(files)
        assert "utils.py" in result
        assert "requirements.txt" in result
# ============================================================
# TestOpenCodeConfig
# ============================================================
class TestOpenCodeConfig:
    """Checks for the OpenCodeConfig dataclass defaults and its dict parser."""

    def test_default_values(self):
        defaults = OpenCodeConfig()
        assert defaults.enabled is True
        assert defaults.auto is True
        assert defaults.complexity_threshold == 0.2
        assert defaults.model == ""
        assert defaults.timeout_sec == 600
        assert defaults.max_retries == 1
        assert defaults.workspace_cleanup is True

    def test_parse_from_dict(self):
        raw = {
            "enabled": True,
            "auto": True,
            "complexity_threshold": 0.5,
            "model": "gpt-5.2",
            "timeout_sec": 900,
            "max_retries": 2,
            "workspace_cleanup": False,
        }
        parsed = _parse_opencode_config(raw)
        # Every field must mirror the raw mapping exactly.
        assert parsed.enabled is True
        assert parsed.auto is True
        assert parsed.complexity_threshold == 0.5
        assert parsed.model == "gpt-5.2"
        assert parsed.timeout_sec == 900
        assert parsed.max_retries == 2
        assert parsed.workspace_cleanup is False

    def test_empty_dict_returns_default(self):
        # An empty mapping parses to the same value as the default config.
        assert _parse_opencode_config({}) == OpenCodeConfig()
# ============================================================
# TestCountHistoricalFailures
# ============================================================
class TestCountHistoricalFailures:
    """count_historical_failures scans stage directories for failure markers.

    It recognises beast_mode_log.json with ``success: false``,
    validation_report.md containing FAILED text, and stage_health.json —
    but multiple indicators in the same directory count as ONE failure.
    """

    def test_no_failures(self, tmp_path):
        # An empty artifacts dir yields zero failures.
        assert count_historical_failures(tmp_path) == 0

    def test_counts_beast_mode_failures(self, tmp_path):
        d = tmp_path / "stage-10_001"
        d.mkdir()
        (d / "beast_mode_log.json").write_text(json.dumps({"success": False}))
        assert count_historical_failures(tmp_path) >= 1

    def test_counts_validation_failures(self, tmp_path):
        d = tmp_path / "stage-10_002"
        d.mkdir()
        (d / "validation_report.md").write_text("**Status**: FAILED after 5 repairs")
        assert count_historical_failures(tmp_path) >= 1

    def test_deduplicates_multiple_failure_indicators(self, tmp_path):
        """Same dir with beast_mode_log + stage_health + validation_report = 1 failure (BUG-E fix)."""
        d = tmp_path / "stage-10_003"
        d.mkdir()
        (d / "beast_mode_log.json").write_text(json.dumps({"success": False}))
        (d / "stage_health.json").write_text(json.dumps({"status": "FAILED"}))
        (d / "validation_report.md").write_text("FAILED after 3 repairs")
        assert count_historical_failures(tmp_path) == 1
================================================
FILE: tests/test_overleaf.py
================================================
"""Tests for Overleaf sync (C4): Sync engine, Conflict resolver, Watcher, Formatter."""
from __future__ import annotations
import textwrap
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.overleaf.sync import OverleafSync
from researchclaw.overleaf.conflict import ConflictResolver, _extract_conflicts, _resolve_content
from researchclaw.overleaf.watcher import FileWatcher
from researchclaw.overleaf.formatter import LatexFormatter
# ══════════════════════════════════════════════════════════════════
# ConflictResolver tests
# ══════════════════════════════════════════════════════════════════
class TestConflictResolver:
    """Detection and resolution of git merge-conflict markers in .tex files."""

    def test_no_conflicts(self, tmp_path: Path) -> None:
        (tmp_path / "paper.tex").write_text("\\section{Intro}\nHello world\n")
        resolver = ConflictResolver()
        assert not resolver.has_conflicts(tmp_path)

    def test_has_conflicts(self, tmp_path: Path) -> None:
        content = textwrap.dedent("""\
            \\section{Intro}
            <<<<<<< HEAD
            Our method is great.
            =======
            Our method is good.
            >>>>>>> remote
        """)
        (tmp_path / "paper.tex").write_text(content)
        resolver = ConflictResolver()
        assert resolver.has_conflicts(tmp_path)

    def test_detect_conflicts(self, tmp_path: Path) -> None:
        content = textwrap.dedent("""\
            <<<<<<< HEAD
            line A
            =======
            line B
            >>>>>>> remote
        """)
        (tmp_path / "main.tex").write_text(content)
        resolver = ConflictResolver()
        conflicts = resolver.detect(tmp_path)
        # Each conflict dict exposes the local ("ours") and remote
        # ("theirs") sides of the marker block.
        assert len(conflicts) == 1
        assert conflicts[0]["ours"] == "line A"
        assert conflicts[0]["theirs"] == "line B"

    def test_resolve_ours(self, tmp_path: Path) -> None:
        content = textwrap.dedent("""\
            \\section{Intro}
            <<<<<<< HEAD
            AI version
            =======
            Human version
            >>>>>>> remote
            \\section{End}
        """)
        (tmp_path / "paper.tex").write_text(content)
        resolver = ConflictResolver()
        resolved = resolver.resolve(tmp_path, strategy="ours")
        assert len(resolved) == 1
        text = (tmp_path / "paper.tex").read_text()
        # "ours" keeps the local (HEAD) side and strips all markers.
        assert "AI version" in text
        assert "Human version" not in text
        assert "<<<<<<" not in text

    def test_resolve_theirs(self, tmp_path: Path) -> None:
        content = textwrap.dedent("""\
            <<<<<<< HEAD
            AI text
            =======
            Human text
            >>>>>>> remote
        """)
        (tmp_path / "paper.tex").write_text(content)
        resolver = ConflictResolver()
        resolver.resolve(tmp_path, strategy="theirs")
        text = (tmp_path / "paper.tex").read_text()
        # "theirs" keeps the remote side instead.
        assert "Human text" in text
        assert "AI text" not in text

    def test_multiple_conflicts(self, tmp_path: Path) -> None:
        # Two independent conflict blocks separated by clean text.
        content = textwrap.dedent("""\
            <<<<<<< HEAD
            A1
            =======
            B1
            >>>>>>> remote
            middle
            <<<<<<< HEAD
            A2
            =======
            B2
            >>>>>>> remote
        """)
        (tmp_path / "paper.tex").write_text(content)
        resolver = ConflictResolver()
        conflicts = resolver.detect(tmp_path)
        assert len(conflicts) == 2
class TestConflictHelpers:
    """Unit tests for the module-level conflict parsing/resolution helpers."""

    # One canonical conflict block shared by the resolution tests.
    CONFLICT = "<<<<<<< HEAD\nours\n=======\ntheirs\n>>>>>>> remote\n"

    def test_extract_conflicts_empty(self) -> None:
        assert _extract_conflicts("no conflicts here") == []

    def test_resolve_content_ours(self) -> None:
        merged = _resolve_content(self.CONFLICT, "ours")
        assert "ours" in merged
        assert "theirs" not in merged

    def test_resolve_content_theirs(self) -> None:
        merged = _resolve_content(self.CONFLICT, "theirs")
        assert "theirs" in merged
        assert "ours" not in merged
# ══════════════════════════════════════════════════════════════════
# FileWatcher tests
# ══════════════════════════════════════════════════════════════════
class TestFileWatcher:
    """Polling-style change detection (new / modified / deleted files)."""

    def test_no_changes_initially(self, tmp_path: Path) -> None:
        # Files present at construction time form the baseline snapshot,
        # so the first poll reports nothing.
        (tmp_path / "paper.tex").write_text("content")
        watcher = FileWatcher(tmp_path)
        assert watcher.check_changes() == []

    def test_detect_new_file(self, tmp_path: Path) -> None:
        watcher = FileWatcher(tmp_path)
        (tmp_path / "new.tex").write_text("new content")
        changes = watcher.check_changes()
        assert "new.tex" in changes

    def test_detect_modified_file(self, tmp_path: Path) -> None:
        f = tmp_path / "paper.tex"
        f.write_text("v1")
        watcher = FileWatcher(tmp_path)
        # Modify
        import time
        # Brief sleep so the rewrite lands on a later mtime tick; coarse
        # filesystem timestamps could otherwise hide the modification.
        time.sleep(0.05)
        f.write_text("v2")
        changes = watcher.check_changes()
        assert "paper.tex" in changes

    def test_detect_deleted_file(self, tmp_path: Path) -> None:
        f = tmp_path / "paper.tex"
        f.write_text("content")
        watcher = FileWatcher(tmp_path)
        f.unlink()
        changes = watcher.check_changes()
        # Deletions are reported as changes too.
        assert "paper.tex" in changes

    def test_only_watches_extensions(self, tmp_path: Path) -> None:
        # Files with non-matching extensions are invisible to the watcher.
        watcher = FileWatcher(tmp_path, extensions=(".tex",))
        (tmp_path / "readme.md").write_text("markdown")
        changes = watcher.check_changes()
        assert changes == []

    def test_nonexistent_dir(self, tmp_path: Path) -> None:
        # Watching a missing directory must not raise.
        watcher = FileWatcher(tmp_path / "nonexistent")
        assert watcher.check_changes() == []
# ══════════════════════════════════════════════════════════════════
# LatexFormatter tests
# ══════════════════════════════════════════════════════════════════
class TestLatexFormatter:
    """Helpers that sanitise LaTeX sources before they are pushed to Overleaf."""

    def test_normalize_paths(self) -> None:
        # Absolute artifact paths are rewritten to relative figure paths.
        content = r"\includegraphics[width=0.5\textwidth]{/home/user/artifacts/rc-123/figures/plot.png}"
        result = LatexFormatter.normalize_paths(content)
        assert "figures/plot.png" in result
        assert "/home/user" not in result

    def test_ensure_document_class_adds(self) -> None:
        content = "\\begin{document}\nHello\n\\end{document}"
        result = LatexFormatter.ensure_document_class(content)
        assert "\\documentclass" in result

    def test_ensure_document_class_noop(self) -> None:
        # Idempotent: an existing \documentclass is not duplicated.
        content = "\\documentclass{article}\n\\begin{document}\nHello\n\\end{document}"
        result = LatexFormatter.ensure_document_class(content)
        assert result.count("\\documentclass") == 1

    def test_strip_local_comments(self) -> None:
        # "% RESEARCHCLAW:" marker comments are internal and must not leak,
        # while ordinary content lines are preserved.
        content = "Normal line\n% RESEARCHCLAW: internal note\nAnother line\n"
        result = LatexFormatter.strip_local_comments(content)
        assert "RESEARCHCLAW" not in result
        assert "Normal line" in result
        assert "Another line" in result

    def test_fix_encoding(self) -> None:
        content = "\\documentclass{article}\n\\begin{document}\n"
        result = LatexFormatter.fix_encoding(content)
        assert "\\usepackage[utf8]{inputenc}" in result

    def test_fix_encoding_noop(self) -> None:
        # inputenc is only added once, even if already declared.
        content = "\\documentclass{article}\n\\usepackage[utf8]{inputenc}\n\\begin{document}\n"
        result = LatexFormatter.fix_encoding(content)
        assert result.count("inputenc") == 1

    def test_format_for_overleaf(self, tmp_path: Path) -> None:
        # End-to-end formatting: strips internal comments and adds inputenc.
        tex = tmp_path / "paper.tex"
        tex.write_text("\\documentclass{article}\n% RESEARCHCLAW: test\n\\begin{document}\nHello\n\\end{document}\n")
        formatter = LatexFormatter()
        result = formatter.format_for_overleaf(tex)
        assert "RESEARCHCLAW" not in result
        assert "inputenc" in result
# ══════════════════════════════════════════════════════════════════
# OverleafSync tests (mock git)
# ══════════════════════════════════════════════════════════════════
class TestOverleafSync:
    """OverleafSync lifecycle: unconfigured-state guards and mocked git clone."""

    def test_init(self) -> None:
        sync = OverleafSync(git_url="https://git.overleaf.com/abc123")
        assert sync.git_url == "https://git.overleaf.com/abc123"
        assert sync.branch == "main"
        # local_dir stays None until setup() clones the repository.
        assert sync.local_dir is None

    def test_get_status_before_setup(self) -> None:
        sync = OverleafSync(git_url="https://git.overleaf.com/abc123")
        status = sync.get_status()
        assert status["local_dir"] is None
        assert status["last_sync"] is None

    def test_push_before_setup_raises(self, tmp_path: Path) -> None:
        # All mutating operations require setup() to have run first; the
        # error message mentions "setup" so callers know the remedy.
        sync = OverleafSync(git_url="https://git.overleaf.com/abc123")
        with pytest.raises(RuntimeError, match="setup"):
            sync.push_paper(tmp_path / "paper.tex")

    def test_pull_before_setup_raises(self) -> None:
        sync = OverleafSync(git_url="https://git.overleaf.com/abc123")
        with pytest.raises(RuntimeError, match="setup"):
            sync.pull_changes()

    def test_resolve_before_setup_raises(self) -> None:
        sync = OverleafSync(git_url="https://git.overleaf.com/abc123")
        with pytest.raises(RuntimeError, match="setup"):
            sync.resolve_conflicts()

    @patch("researchclaw.overleaf.sync.subprocess.run")
    def test_setup_clones(self, mock_run: MagicMock, tmp_path: Path) -> None:
        mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
        sync = OverleafSync(git_url="https://git.overleaf.com/abc123")
        local = sync.setup(tmp_path)
        # setup() clones into a fixed "overleaf_repo" subdirectory.
        assert local == tmp_path / "overleaf_repo"
        # git clone was called
        mock_run.assert_called_once()
        args = mock_run.call_args[0][0]
        assert "clone" in args
================================================
FILE: tests/test_paper_verifier.py
================================================
"""Tests for paper_verifier — post-generation fabrication detection."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.pipeline.paper_verifier import (
VerificationResult,
verify_paper,
)
from researchclaw.pipeline.verified_registry import VerifiedRegistry
# Repository-level artifacts directory (tests/ -> repo root -> artifacts/),
# used by the integration tests below; they skip when it is absent.
ARTIFACTS = Path(__file__).resolve().parent.parent / "artifacts"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_registry(
    conditions: dict[str, dict[int, float]] | None = None,
    primary_metric: float | None = None,
) -> VerifiedRegistry:
    """Build a VerifiedRegistry from a compact test specification.

    Args:
        conditions: Maps condition name -> {seed index -> metric value}.
            Per-seed values land under ``best_run.metrics`` as
            ``"<cond>/<seed>/metric"``; the per-condition mean lands under
            ``condition_summaries``. Empty/None conditions are skipped.
        primary_metric: Optional value stored as
            ``best_run.metrics["primary_metric"]``.

    Returns:
        A registry built via ``VerifiedRegistry.from_experiment``.

    Note: previously accepted opaque ``**kwargs``; explicit keyword
    parameters are backward-compatible (all call sites pass keywords) and
    self-documenting. A guard also skips empty seed dicts, which would
    otherwise raise ZeroDivisionError when computing the mean.
    """
    summary = {"best_run": {"metrics": {}}, "condition_summaries": {}, "metrics_summary": {}}
    for cond_name, seeds in (conditions or {}).items():
        if not seeds:
            # No seeds recorded for this condition — nothing to summarise.
            continue
        for seed_idx, value in seeds.items():
            summary["best_run"]["metrics"][f"{cond_name}/{seed_idx}/metric"] = value
        mean_val = sum(seeds.values()) / len(seeds)
        summary["condition_summaries"][cond_name] = {"metrics": {"metric": mean_val}}
    if primary_metric is not None:
        summary["best_run"]["metrics"]["primary_metric"] = primary_metric
    return VerifiedRegistry.from_experiment(summary)
# ---------------------------------------------------------------------------
# Unit tests — clean paper
# ---------------------------------------------------------------------------
class TestCleanPaper:
    """Papers whose numbers all trace back to the registry should PASS."""

    def test_all_numbers_verified_passes(self):
        reg = _make_registry(
            conditions={"Baseline": {0: 80.0, 1: 82.0}, "Proposed": {0: 90.0, 1: 92.0}},
            primary_metric=91.0,
        )
        # 81.0000 / 91.0000 are the per-condition means and 1.4142 matches
        # the seed spread (presumably the std-dev — derived, not fabricated),
        # so nothing in prose or table should be flagged.
        tex = r"""
\section{Results}
Our proposed method achieves 91.0000 on the primary metric,
compared to 81.0000 for the baseline.
\begin{table}[htbp]
\centering
\begin{tabular}{lcc}
\toprule
Method & Metric & $n$ \\
\midrule
Baseline & 81.0000 $\pm$ 1.4142 & 2 \\
Proposed & 91.0000 $\pm$ 1.4142 & 2 \\
\bottomrule
\end{tabular}
\end{table}
"""
        result = verify_paper(tex, reg)
        assert result.severity == "PASS"
        assert result.strict_violations == 0

    def test_common_constants_allowed(self):
        # Hyperparameter-style constants (batch size, epochs, learning rate)
        # must not be treated as fabricated metric values.
        reg = _make_registry(conditions={"A": {0: 80.0}})
        tex = r"""
\section{Experimental Setup}
We use a batch size of 64 and train for 100 epochs
with a learning rate of 0.001.
"""
        result = verify_paper(tex, reg)
        assert result.severity == "PASS"

    def test_year_numbers_allowed(self):
        # Four-digit years in citation text are not metrics.
        reg = _make_registry(conditions={"A": {0: 80.0}})
        tex = r"""
\section{Introduction}
Following the work of Smith et al. (2025), we propose...
"""
        result = verify_paper(tex, reg)
        assert result.severity == "PASS"
# ---------------------------------------------------------------------------
# Unit tests — fabricated numbers
# ---------------------------------------------------------------------------
class TestFabricatedNumbers:
    """Numbers absent from the registry are flagged by section context:
    Results/tables → REJECT (strict), Discussion → WARN (lenient),
    and numbers inside \\cite keys or comments are skipped entirely."""

    def test_fabricated_in_results_rejects(self):
        reg = _make_registry(
            conditions={"Baseline": {0: 80.0}, "Proposed": {0: 90.0}},
        )
        tex = r"""
\section{Results}
Our method achieves 95.5 accuracy.
"""
        result = verify_paper(tex, reg)
        assert result.severity == "REJECT"
        assert result.strict_violations >= 1
        # The offending value is surfaced in unverified_numbers.
        assert any(abs(u.value - 95.5) < 0.01 for u in result.unverified_numbers)

    def test_fabricated_in_table_rejects(self):
        reg = _make_registry(conditions={"A": {0: 80.0}})
        # 85.3 does not match condition A's recorded 80.0.
        tex = r"""
\section{Results}
\begin{table}[h]
\begin{tabular}{lc}
A & 85.3 \\
\end{tabular}
\end{table}
"""
        result = verify_paper(tex, reg)
        assert result.severity == "REJECT"

    def test_fabricated_in_discussion_warns(self):
        reg = _make_registry(conditions={"A": {0: 80.0}})
        tex = r"""
\section{Discussion}
Compared to prior work reporting 95.5 accuracy, our result is lower.
"""
        result = verify_paper(tex, reg)
        # In Discussion → warning, not reject
        assert result.severity == "WARN"
        assert result.lenient_violations >= 1

    def test_numbers_in_cite_skipped(self):
        # Digits inside \cite keys must not be parsed as metric values.
        reg = _make_registry(conditions={"A": {0: 80.0}})
        tex = r"""
\section{Results}
As shown by \cite{smith2025deep}, our method works.
"""
        result = verify_paper(tex, reg)
        assert result.severity == "PASS"

    def test_numbers_in_comments_skipped(self):
        # LaTeX comment lines are ignored; only the real 80.0 is checked.
        reg = _make_registry(conditions={"A": {0: 80.0}})
        tex = r"""
\section{Results}
% This is a comment with fake number 99.99
Our method achieves 80.0.
"""
        result = verify_paper(tex, reg)
        assert result.severity == "PASS"
# ---------------------------------------------------------------------------
# Unit tests — fabricated conditions
# ---------------------------------------------------------------------------
class TestFabricatedConditions:
    """Table rows naming conditions the registry never ran are fabricated."""

    def test_unknown_condition_in_table(self):
        reg = _make_registry(conditions={"DQN": {0: 80.0}, "DQN+Abstraction": {0: 90.0}})
        # "PPO" was never run — its row must be flagged even though the
        # other two rows match registry values.
        tex = r"""
\section{Results}
\begin{table}[h]
\begin{tabular}{lc}
DQN & 80.0 \\
DQN+Abstraction & 90.0 \\
PPO & 75.0 \\
\end{tabular}
\end{table}
"""
        result = verify_paper(tex, reg)
        assert len(result.fabricated_conditions) >= 1
        assert any(fc.name == "PPO" for fc in result.fabricated_conditions)
        assert result.severity == "REJECT"
# ---------------------------------------------------------------------------
# Unit tests — fabrication rate
# ---------------------------------------------------------------------------
class TestFabricationRate:
    """fabrication_rate is 0.0 for clean papers and positive when any
    reported number cannot be traced to the registry."""

    def test_rate_zero_for_clean_paper(self):
        registry = _make_registry(conditions={"A": {0: 80.0}})
        tex = r"""
\section{Results}
Accuracy is 80.0.
"""
        report = verify_paper(tex, registry)
        assert report.fabrication_rate == 0.0

    def test_rate_nonzero_for_fabricated(self):
        registry = _make_registry(conditions={"A": {0: 80.0}})
        # Neither 99.99 nor 45.67 exists in the registry.
        tex = r"""
\section{Results}
Accuracy is 99.99 and loss is 45.67.
"""
        report = verify_paper(tex, registry)
        assert report.fabrication_rate > 0.0
# ---------------------------------------------------------------------------
# Integration — real fabricated papers
# ---------------------------------------------------------------------------
class TestRealPapers:
    """Regression tests against real pipeline artifacts (skipped if absent)."""

    def _load(self, run_id: str) -> tuple[str, VerifiedRegistry]:
        # Locate the artifact directory for this run id and load its paper
        # plus the experiment summary (and optional refinement log) into a
        # registry. Skips — rather than fails — when artifacts are missing,
        # so environments without the artifacts/ tree still pass.
        pattern = f"rc-*-{run_id}"
        matches = sorted(ARTIFACTS.glob(pattern))
        if not matches:
            pytest.skip(f"Artifact {run_id} not found")
        base = matches[0]
        tex_path = base / "stage-22" / "paper.tex"
        summary_path = base / "stage-14" / "experiment_summary.json"
        ref_path = base / "stage-13" / "refinement_log.json"
        if not tex_path.exists() or not summary_path.exists():
            pytest.skip(f"Missing files for {run_id}")
        tex = tex_path.read_text(encoding="utf-8")
        summary = json.loads(summary_path.read_text())
        ref_log = None
        if ref_path.exists():
            ref_log = json.loads(ref_path.read_text())
        reg = VerifiedRegistry.from_experiment(summary, ref_log)
        return tex, reg

    def test_run_e57360_severe_fabrication_detected(self):
        """Run 38 (LACE) — audit found SEVERE fabrication.
        The verifier should REJECT this paper."""
        tex, reg = self._load("e57360")
        result = verify_paper(tex, reg)
        assert result.severity == "REJECT", (
            f"Expected REJECT for severely fabricated paper, got {result.severity}. "
            f"Unverified: {len(result.unverified_numbers)}, "
            f"Fabricated conditions: {[fc.name for fc in result.fabricated_conditions]}"
        )

    def test_run_6a1ec9_severe_fabrication_detected(self):
        """Run 6a1ec9 (FAME) — audit found SEVERE fabrication."""
        tex, reg = self._load("6a1ec9")
        result = verify_paper(tex, reg)
        assert result.severity == "REJECT"

    def test_run_85fefc_fabrication_detected(self):
        """Run 85fefc (CRAFT) — audit found SEVERE fabrication."""
        tex, reg = self._load("85fefc")
        result = verify_paper(tex, reg)
        # Should detect at least some issues
        assert len(result.unverified_numbers) > 0 or len(result.fabricated_conditions) > 0

    def test_run_acbdfa_moderate_fabrication(self):
        """Run acbdfa (CTS) — audit found MODERATE fabrication."""
        tex, reg = self._load("acbdfa")
        result = verify_paper(tex, reg)
        # May or may not reject (moderate case), but should find issues
        assert len(result.unverified_numbers) > 0 or result.lenient_violations > 0
================================================
FILE: tests/test_project_manager.py
================================================
"""Tests for multi-project management (C1): ProjectManager, ProjectScheduler, IdeaPool."""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
import pytest
from researchclaw.project.models import Idea, Project
from researchclaw.project.manager import ProjectManager
from researchclaw.project.scheduler import ProjectScheduler
from researchclaw.project.idea_pool import IdeaPool
# ── fixtures ──────────────────────────────────────────────────────
@pytest.fixture
def tmp_projects(tmp_path: Path) -> Path:
    """Per-test registry directory for ProjectManager (not created here)."""
    return tmp_path / "projects"
@pytest.fixture
def manager(tmp_projects: Path) -> ProjectManager:
    """Fresh ProjectManager rooted at the per-test projects directory."""
    return ProjectManager(tmp_projects)
@pytest.fixture
def config_yaml(tmp_path: Path) -> Path:
    """Minimal on-disk YAML config that create()/to_project() can copy."""
    cfg = tmp_path / "config.yaml"
    cfg.write_text("project:\n name: test\nresearch:\n topic: test\n")
    return cfg
@pytest.fixture
def pool_path(tmp_path: Path) -> Path:
    """Path of the JSON file backing the IdeaPool under test."""
    return tmp_path / "ideas.json"
# ══════════════════════════════════════════════════════════════════
# Project model tests
# ══════════════════════════════════════════════════════════════════
class TestProjectModel:
    """Serialization behaviour of the Project dataclass."""

    def test_to_dict_roundtrip(self) -> None:
        # A serialize/deserialize cycle must preserve the identity fields.
        original = Project(name="test", config_path="/a/b", run_dir="/c/d", topic="ml")
        restored = Project.from_dict(original.to_dict())
        assert restored.name == original.name
        assert restored.topic == original.topic
        assert restored.status == "idle"

    def test_from_dict_defaults(self) -> None:
        # Fields absent from the payload fall back to their defaults.
        payload = {"name": "x", "config_path": "/a", "run_dir": "/b"}
        project = Project.from_dict(payload)
        assert project.status == "idle"
        assert project.last_run_id is None

    def test_from_dict_with_iso_date(self) -> None:
        # ISO-8601 timestamps are parsed back into datetime objects.
        payload = {
            "name": "x",
            "config_path": "/a",
            "run_dir": "/b",
            "created_at": "2024-01-01T00:00:00+00:00",
        }
        project = Project.from_dict(payload)
        assert project.created_at.year == 2024
# ══════════════════════════════════════════════════════════════════
# Idea model tests
# ══════════════════════════════════════════════════════════════════
class TestIdeaModel:
    """Score computation and serialization of the Idea dataclass."""

    def test_score_calculation(self) -> None:
        perfect = Idea(id="1", title="t", description="d", feasibility=1.0, novelty=1.0)
        assert perfect.score == pytest.approx(1.0)

    def test_score_weighted(self) -> None:
        middling = Idea(id="1", title="t", description="d", feasibility=0.5, novelty=0.5)
        assert middling.score == pytest.approx(0.5)

    def test_to_dict_roundtrip(self) -> None:
        source = Idea(id="abc", title="GNN", description="graph stuff", domains=["ml"])
        restored = Idea.from_dict(source.to_dict())
        assert restored.id == "abc"
        assert restored.domains == ["ml"]
# ══════════════════════════════════════════════════════════════════
# ProjectManager tests
# ══════════════════════════════════════════════════════════════════
class TestProjectManager:
    """Lifecycle operations: create/delete/get/switch, runs, and persistence."""

    def test_create_project(self, manager: ProjectManager, config_yaml: Path) -> None:
        created = manager.create("my_project", str(config_yaml), topic="RL")
        assert created.name == "my_project"
        assert created.topic == "RL"
        assert created.status == "idle"

    def test_create_sets_active(self, manager: ProjectManager, config_yaml: Path) -> None:
        # The first project created becomes the active one automatically.
        manager.create("first", str(config_yaml))
        assert manager.active is not None
        assert manager.active.name == "first"

    def test_create_duplicate_raises(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("dup", str(config_yaml))
        with pytest.raises(ValueError, match="already exists"):
            manager.create("dup", str(config_yaml))

    def test_delete_project(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("del_me", str(config_yaml))
        manager.delete("del_me")
        assert "del_me" not in manager.projects

    def test_delete_unknown_raises(self, manager: ProjectManager) -> None:
        with pytest.raises(KeyError):
            manager.delete("nonexistent")

    def test_get_project(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("proj1", str(config_yaml))
        assert manager.get("proj1").name == "proj1"

    def test_get_unknown_raises(self, manager: ProjectManager) -> None:
        with pytest.raises(KeyError):
            manager.get("nope")

    def test_list_all_sorted(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("b_proj", str(config_yaml))
        manager.create("a_proj", str(config_yaml))
        listing = manager.list_all()
        assert len(listing) == 2
        # Ordered by creation time, so "b_proj" (created first) leads.
        assert listing[0].name == "b_proj"

    def test_get_status(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("s1", str(config_yaml))
        manager.create("s2", str(config_yaml))
        snapshot = manager.get_status()
        assert snapshot["total"] == 2
        assert snapshot["active"] == "s1"

    def test_switch_project(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("a", str(config_yaml))
        manager.create("b", str(config_yaml))
        manager.switch("b")
        assert manager.active is not None
        assert manager.active.name == "b"

    def test_switch_unknown_raises(self, manager: ProjectManager) -> None:
        with pytest.raises(KeyError):
            manager.switch("ghost")

    def test_compare_projects(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("pa", str(config_yaml))
        manager.create("pb", str(config_yaml))
        manager.projects["pa"].metrics = {"acc": 0.9}
        manager.projects["pb"].metrics = {"acc": 0.95}
        comparison = manager.compare("pa", "pb")
        assert "metric_diff" in comparison
        assert comparison["metric_diff"]["acc"]["delta"] == pytest.approx(0.05)

    def test_start_run(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("run_proj", str(config_yaml))
        assert manager.start_run("run_proj", run_id="rc-123") == "rc-123"
        assert manager.get("run_proj").status == "running"

    def test_finish_run(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("fin_proj", str(config_yaml))
        manager.start_run("fin_proj", run_id="rc-456")
        manager.finish_run("fin_proj", "completed", {"acc": 0.88})
        finished = manager.get("fin_proj")
        assert finished.status == "completed"
        assert finished.metrics["acc"] == 0.88

    def test_registry_persistence(self, tmp_projects: Path, config_yaml: Path) -> None:
        # Two managers over the same directory share state via the registry file.
        writer = ProjectManager(tmp_projects)
        writer.create("persist", str(config_yaml), topic="persistence")
        reader = ProjectManager(tmp_projects)
        assert "persist" in reader.projects
        assert reader.projects["persist"].topic == "persistence"

    def test_delete_switches_active(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("first", str(config_yaml))
        manager.create("second", str(config_yaml))
        manager.switch("first")
        manager.delete("first")
        # Deleting the active project promotes the remaining one.
        assert manager.active is not None
        assert manager.active.name == "second"

    def test_config_copied_to_project_dir(self, manager: ProjectManager, config_yaml: Path) -> None:
        created = manager.create("copy_test", str(config_yaml))
        copied_cfg = Path(created.config_path)
        assert copied_cfg.exists()
        assert "test" in copied_cfg.read_text()
# ══════════════════════════════════════════════════════════════════
# ProjectScheduler tests
# ══════════════════════════════════════════════════════════════════
class TestProjectScheduler:
    """Queueing, concurrency limits, and priority ordering."""

    def test_enqueue_and_next(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("proj", str(config_yaml))
        sched = ProjectScheduler(manager, max_concurrent=1)
        sched.enqueue("proj")
        assert sched.next() == "proj"

    def test_concurrency_limit(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("a", str(config_yaml))
        manager.create("b", str(config_yaml))
        sched = ProjectScheduler(manager, max_concurrent=1)
        sched.enqueue("a")
        sched.enqueue("b")
        sched.next()  # starts "a"
        assert sched.next() is None  # the single slot is taken; "b" must wait

    def test_mark_done_frees_slot(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("a", str(config_yaml))
        manager.create("b", str(config_yaml))
        sched = ProjectScheduler(manager, max_concurrent=1)
        sched.enqueue("a")
        sched.enqueue("b")
        sched.next()  # starts "a"
        sched.mark_done("a")
        assert sched.next() == "b"

    def test_priority_order(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("low", str(config_yaml))
        manager.create("high", str(config_yaml))
        sched = ProjectScheduler(manager, max_concurrent=2)
        sched.enqueue("low", priority=10)
        sched.enqueue("high", priority=1)
        # Lower numeric priority is dispatched first.
        assert sched.next() == "high"
        assert sched.next() == "low"

    def test_enqueue_unknown_raises(self, manager: ProjectManager) -> None:
        with pytest.raises(KeyError):
            ProjectScheduler(manager).enqueue("ghost")

    def test_duplicate_enqueue_ignored(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("dup", str(config_yaml))
        sched = ProjectScheduler(manager)
        sched.enqueue("dup")
        sched.enqueue("dup")
        assert sched.queue_size == 1

    def test_get_status(self, manager: ProjectManager, config_yaml: Path) -> None:
        manager.create("s", str(config_yaml))
        sched = ProjectScheduler(manager, max_concurrent=3)
        sched.enqueue("s")
        snapshot = sched.get_status()
        assert snapshot["max_concurrent"] == 3
        assert snapshot["queue_size"] == 1

    def test_can_start_empty_queue(self, manager: ProjectManager) -> None:
        assert not ProjectScheduler(manager).can_start()
# ══════════════════════════════════════════════════════════════════
# IdeaPool tests
# ══════════════════════════════════════════════════════════════════
class TestIdeaPool:
    """CRUD, evaluation, ranking, persistence, and idea-to-project promotion."""

    def test_add_idea(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        added = pool.add("GNN for proteins", "Apply GNN to protein folding", ["bio", "ml"])
        assert added.title == "GNN for proteins"
        assert len(added.id) == 8

    def test_remove_idea(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        doomed = pool.add("remove me", "desc")
        pool.remove(doomed.id)
        assert doomed.id not in pool.ideas

    def test_remove_unknown_raises(self, pool_path: Path) -> None:
        with pytest.raises(KeyError):
            IdeaPool(pool_path).remove("nonexistent")

    def test_get_idea(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        stored = pool.add("get me", "desc")
        assert pool.get(stored.id).title == "get me"

    def test_evaluate(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        idea = pool.add("eval", "desc")
        outcome = pool.evaluate(idea.id, feasibility=0.8, novelty=0.9)
        assert outcome["feasibility"] == 0.8
        assert outcome["novelty"] == 0.9
        assert pool.get(idea.id).status == "evaluated"

    def test_evaluate_clamps_values(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        idea = pool.add("clamp", "desc")
        # Out-of-range scores are clamped into [0, 1].
        pool.evaluate(idea.id, feasibility=1.5, novelty=-0.5)
        assert pool.get(idea.id).feasibility == 1.0
        assert pool.get(idea.id).novelty == 0.0

    def test_rank(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        pool.add("low", "desc")
        pool.add("high", "desc")
        pool.evaluate(pool.list_all()[0].id, 0.1, 0.1)
        pool.evaluate(pool.list_all()[1].id, 0.9, 0.9)
        ranked = pool.rank()
        assert ranked[0].score > ranked[1].score

    def test_list_all(self, pool_path: Path) -> None:
        pool = IdeaPool(pool_path)
        pool.add("a", "desc")
        pool.add("b", "desc")
        assert len(pool.list_all()) == 2

    def test_persistence(self, pool_path: Path) -> None:
        # A second pool over the same file sees ideas added by the first.
        first = IdeaPool(pool_path)
        first.add("persist", "desc", ["ml"])
        second = IdeaPool(pool_path)
        assert len(second.ideas) == 1
        assert list(second.ideas.values())[0].title == "persist"

    def test_to_project(self, pool_path: Path, tmp_path: Path, config_yaml: Path) -> None:
        pool = IdeaPool(pool_path)
        idea = pool.add("my idea", "a nice description")
        projects_dir = tmp_path / "projects"
        promoted = pool.to_project(idea.id, str(config_yaml), projects_dir)
        assert promoted.topic == "a nice description"
        assert pool.get(idea.id).status == "planned"
================================================
FILE: tests/test_prompt_adapter.py
================================================
"""Tests for domain-aware prompt adapters."""
from __future__ import annotations
import pytest
from researchclaw.domains.detector import DomainProfile, get_profile, get_generic_profile
from researchclaw.domains.prompt_adapter import (
GenericPromptAdapter,
MLPromptAdapter,
PromptAdapter,
PromptBlocks,
get_adapter,
register_adapter,
)
# ---------------------------------------------------------------------------
# PromptBlocks tests
# ---------------------------------------------------------------------------
class TestPromptBlocks:
    """Default construction and full-field construction of PromptBlocks."""

    def test_default_empty(self):
        empty = PromptBlocks()
        assert empty.compute_budget == ""
        assert empty.dataset_guidance == ""
        assert empty.code_generation_hints == ""

    def test_all_fields(self):
        populated = PromptBlocks(
            compute_budget="budget info",
            dataset_guidance="data info",
            hp_reporting="hp info",
            code_generation_hints="code hints",
            result_analysis_hints="analysis hints",
            experiment_design_context="design context",
            statistical_test_guidance="stat guidance",
            output_format_guidance="output format",
        )
        assert populated.compute_budget == "budget info"
        assert populated.output_format_guidance == "output format"
# ---------------------------------------------------------------------------
# ML Adapter tests
# ---------------------------------------------------------------------------
class TestMLPromptAdapter:
    """The ML adapter contributes nothing: prompt text for ML lives in prompts.py."""

    def test_returns_empty_blocks(self):
        """ML adapter must return empty blocks (delegates to prompts.py)."""
        profile = get_profile("ml_vision") or DomainProfile(
            domain_id="ml_vision", display_name="CV"
        )
        blocks = MLPromptAdapter(profile).get_code_generation_blocks({})
        assert blocks.compute_budget == ""
        assert blocks.dataset_guidance == ""
        assert blocks.code_generation_hints == ""

    def test_all_methods_return_empty(self):
        adapter = MLPromptAdapter(DomainProfile(domain_id="ml_generic", display_name="ML"))
        checked_fields = (
            "compute_budget", "dataset_guidance", "hp_reporting",
            "code_generation_hints", "result_analysis_hints",
        )
        producers = (
            adapter.get_code_generation_blocks,
            adapter.get_experiment_design_blocks,
            adapter.get_result_analysis_blocks,
        )
        for produce in producers:
            blocks = produce({})
            for field_name in checked_fields:
                assert getattr(blocks, field_name) == ""
# ---------------------------------------------------------------------------
# Generic Adapter tests
# ---------------------------------------------------------------------------
class TestGenericPromptAdapter:
    """The generic adapter derives real guidance from the profile's metadata."""

    def test_provides_code_hints(self):
        profile = DomainProfile(
            domain_id="generic",
            display_name="Generic",
            core_libraries=["numpy", "scipy"],
        )
        blocks = GenericPromptAdapter(profile).get_code_generation_blocks({})
        assert blocks.code_generation_hints  # should not be empty

    def test_convergence_hints(self):
        profile = DomainProfile(
            domain_id="test_conv",
            display_name="Conv Test",
            experiment_paradigm="convergence",
        )
        blocks = GenericPromptAdapter(profile).get_code_generation_blocks({})
        assert "convergence" in blocks.code_generation_hints.lower()

    def test_progressive_spec_hints(self):
        profile = DomainProfile(
            domain_id="test_econ",
            display_name="Econ Test",
            experiment_paradigm="progressive_spec",
        )
        blocks = GenericPromptAdapter(profile).get_code_generation_blocks({})
        assert "progressive" in blocks.code_generation_hints.lower()

    def test_experiment_design_has_terminology(self):
        profile = DomainProfile(
            domain_id="test",
            display_name="Test Domain",
            condition_terminology={"baseline": "reference", "proposed": "our method"},
            standard_baselines=["Method A", "Method B"],
        )
        blocks = GenericPromptAdapter(profile).get_experiment_design_blocks({})
        assert "reference" in blocks.experiment_design_context
        assert "Method A" in blocks.experiment_design_context
# ---------------------------------------------------------------------------
# Physics Adapter tests
# ---------------------------------------------------------------------------
class TestPhysicsAdapter:
    """Physics profiles must dispatch to a non-ML adapter with real hints."""

    def test_physics_adapter_loaded(self):
        profile = get_profile("physics_simulation")
        if profile is None:
            pytest.skip("physics_simulation profile not found")
        assert not isinstance(get_adapter(profile), MLPromptAdapter)

    def test_physics_code_blocks_nonempty(self):
        profile = get_profile("physics_pde")
        if profile is None:
            pytest.skip("physics_pde profile not found")
        blocks = get_adapter(profile).get_code_generation_blocks({})
        assert blocks.code_generation_hints  # should have physics-specific hints
# ---------------------------------------------------------------------------
# Economics Adapter tests
# ---------------------------------------------------------------------------
class TestEconomicsAdapter:
    """Economics profiles must dispatch to a non-ML adapter with design context."""

    def test_economics_adapter_loaded(self):
        profile = get_profile("economics_empirical")
        if profile is None:
            pytest.skip("economics_empirical profile not found")
        assert not isinstance(get_adapter(profile), MLPromptAdapter)

    def test_economics_design_blocks(self):
        profile = get_profile("economics_empirical")
        if profile is None:
            pytest.skip("economics_empirical profile not found")
        blocks = get_adapter(profile).get_experiment_design_blocks({})
        assert "progressive" in blocks.experiment_design_context.lower()
# ---------------------------------------------------------------------------
# get_adapter dispatch tests
# ---------------------------------------------------------------------------
class TestGetAdapter:
    """Dispatch: ML ids get MLPromptAdapter; others get specific or generic."""

    def test_ml_domains_get_ml_adapter(self):
        for domain_id in ("ml_vision", "ml_nlp", "ml_rl", "ml_generic"):
            profile = get_profile(domain_id)
            if profile is None:
                continue
            assert isinstance(get_adapter(profile), MLPromptAdapter), (
                f"{domain_id} should use MLPromptAdapter"
            )

    def test_generic_domain_gets_generic_adapter(self):
        assert isinstance(get_adapter(get_generic_profile()), GenericPromptAdapter)

    def test_physics_uses_physics_adapter(self):
        profile = get_profile("physics_simulation")
        if profile is None:
            pytest.skip("physics_simulation profile not found")
        from researchclaw.domains.adapters.physics import PhysicsPromptAdapter
        assert isinstance(get_adapter(profile), PhysicsPromptAdapter)

    def test_unknown_domain_gets_generic(self):
        # Unregistered domain ids fall back to the generic adapter.
        profile = DomainProfile(domain_id="unknown_domain", display_name="Unknown")
        assert isinstance(get_adapter(profile), GenericPromptAdapter)
# ---------------------------------------------------------------------------
# Blueprint context tests
# ---------------------------------------------------------------------------
class TestBlueprintContext:
    """get_blueprint_context should surface file structure, libraries, and hints."""

    def test_blueprint_includes_file_structure(self):
        profile = DomainProfile(
            domain_id="test",
            display_name="Test",
            typical_file_structure={"config.py": "Config", "main.py": "Entry"},
            core_libraries=["numpy"],
        )
        ctx = GenericPromptAdapter(profile).get_blueprint_context()
        assert "config.py" in ctx
        assert "numpy" in ctx

    def test_blueprint_includes_hints(self):
        profile = DomainProfile(
            domain_id="test",
            display_name="Test",
            code_generation_hints="Use scipy.integrate for ODE solving",
        )
        ctx = GenericPromptAdapter(profile).get_blueprint_context()
        assert "scipy.integrate" in ctx

    def test_ml_adapter_blueprint_context(self):
        """ML adapter should also provide basic blueprint context."""
        profile = get_profile("ml_vision") or DomainProfile(
            domain_id="ml_vision",
            display_name="CV",
            typical_file_structure={"model.py": "Model", "train.py": "Training"},
        )
        ctx = MLPromptAdapter(profile).get_blueprint_context()
        # ML adapter inherits from base, should have file structure if profile has it
        if profile.typical_file_structure:
            assert "model.py" in ctx or ctx == ""  # acceptable either way
# ---------------------------------------------------------------------------
# Adapter registration tests
# ---------------------------------------------------------------------------
class TestAdapterRegistration:
    """register_adapter lets callers plug in a custom PromptAdapter subclass."""

    def test_register_custom_adapter(self):
        class CustomAdapter(PromptAdapter):
            # Minimal concrete subclass: only code-gen blocks carry content.
            def get_code_generation_blocks(self, ctx):
                return PromptBlocks(code_generation_hints="custom")

            def get_experiment_design_blocks(self, ctx):
                return PromptBlocks()

            def get_result_analysis_blocks(self, ctx):
                return PromptBlocks()

        register_adapter("custom_domain", CustomAdapter)
        dispatched = get_adapter(
            DomainProfile(domain_id="custom_domain", display_name="Custom")
        )
        assert isinstance(dispatched, CustomAdapter)
        assert dispatched.get_code_generation_blocks({}).code_generation_hints == "custom"
================================================
FILE: tests/test_rc_adapters.py
================================================
from __future__ import annotations
from researchclaw.adapters import (
AdapterBundle,
BrowserPage,
FetchResponse,
RecordingBrowserAdapter,
RecordingCronAdapter,
RecordingMemoryAdapter,
RecordingMessageAdapter,
RecordingSessionsAdapter,
RecordingWebFetchAdapter,
)
def test_adapter_bundle_defaults_are_recording_types():
    """A default-constructed bundle wires every slot to its recording stub."""
    bundle = AdapterBundle()
    expected = (
        ("cron", RecordingCronAdapter),
        ("message", RecordingMessageAdapter),
        ("memory", RecordingMemoryAdapter),
        ("sessions", RecordingSessionsAdapter),
        ("web_fetch", RecordingWebFetchAdapter),
        ("browser", RecordingBrowserAdapter),
    )
    for attr, adapter_cls in expected:
        assert isinstance(getattr(bundle, attr), adapter_cls)
def test_recording_cron_adapter_records_call_and_returns_id():
    """schedule_resume returns a synthetic id and logs its arguments."""
    cron = RecordingCronAdapter()
    assert cron.schedule_resume("run-1", 7, "gate opened") == "cron-1"
    assert cron.calls == [("run-1", 7, "gate opened")]
def test_recording_message_adapter_notify_records_call():
    """notify returns a synthetic id and logs its arguments."""
    messenger = RecordingMessageAdapter()
    assert messenger.notify("ops", "stage update", "stage 3 done") == "message-1"
    assert messenger.calls == [("ops", "stage update", "stage 3 done")]
def test_recording_memory_adapter_append_records_entries():
    """append returns a synthetic id and records the (topic, text) entry."""
    memory = RecordingMemoryAdapter()
    assert memory.append("runs", "run-1 started") == "memory-1"
    assert memory.entries == [("runs", "run-1 started")]
def test_recording_sessions_adapter_spawn_records_calls():
    """spawn returns a synthetic session id and records its arguments."""
    sessions = RecordingSessionsAdapter()
    assert sessions.spawn("worker", ("python", "train.py")) == "session-1"
    assert sessions.calls == [("worker", ("python", "train.py"))]
def test_recording_webfetch_fetch_returns_success_response():
    """fetch yields a FetchResponse echoing the URL with a 200 status."""
    fetcher = RecordingWebFetchAdapter()
    response = fetcher.fetch("https://example.com")
    assert isinstance(response, FetchResponse)
    assert response.url == "https://example.com"
    assert response.status_code == 200
    assert "stub fetch" in response.text
def test_recording_browser_open_returns_browser_page():
    """open yields a BrowserPage echoing the URL with a stub title."""
    browser = RecordingBrowserAdapter()
    page = browser.open("https://example.com")
    assert isinstance(page, BrowserPage)
    assert page.url == "https://example.com"
    assert "Stub browser page" in page.title
def test_fetch_response_dataclass_fields():
    """FetchResponse stores url/status_code/text verbatim."""
    response = FetchResponse(url="u", status_code=201, text="ok")
    assert (response.url, response.status_code, response.text) == ("u", 201, "ok")
def test_browser_page_dataclass_fields():
    """BrowserPage stores url/title verbatim."""
    page = BrowserPage(url="https://a", title="A")
    assert (page.url, page.title) == ("https://a", "A")
def test_all_adapters_start_with_empty_call_lists():
    """Freshly constructed adapters have recorded nothing yet."""
    assert RecordingCronAdapter().calls == []
    assert RecordingMessageAdapter().calls == []
    assert RecordingMemoryAdapter().entries == []
    assert RecordingSessionsAdapter().calls == []
    assert RecordingWebFetchAdapter().calls == []
    assert RecordingBrowserAdapter().calls == []
================================================
FILE: tests/test_rc_cache.py
================================================
"""Tests for literature query cache and degradation fallback."""
from __future__ import annotations
import importlib
from unittest.mock import patch
from researchclaw.literature.models import Author, Paper
from researchclaw.literature.search import search_papers
# Bind the cache helpers via importlib rather than a plain `from ... import`.
# NOTE(review): presumably the dynamic import exists to sidestep static import
# analysis or an import-order issue — confirm with the module author.
cache_mod = importlib.import_module("researchclaw.literature.cache")
cache_key = cache_mod.cache_key
cache_stats = cache_mod.cache_stats
clear_cache = cache_mod.clear_cache
get_cached = cache_mod.get_cached
put_cache = cache_mod.put_cache
class TestCacheKey:
    """cache_key must be a stable, case-insensitive, 16-character digest."""

    def test_deterministic(self, tmp_path):
        _ = tmp_path
        assert cache_key("transformer", "s2", 20) == cache_key("transformer", "s2", 20)

    def test_different_query(self):
        assert cache_key("transformer", "s2", 20) != cache_key("attention", "s2", 20)

    def test_case_insensitive(self):
        # Query and source casing must not change the key.
        assert cache_key("Transformer", "S2", 20) == cache_key("transformer", "s2", 20)

    def test_length_16(self):
        assert len(cache_key("test", "s2", 10)) == 16
class TestGetPut:
    """Round-trips, misses, TTL expiry, and corrupted cache entries."""

    def test_put_and_get(self, tmp_path):
        stored = [{"paper_id": "1", "title": "Test Paper"}]
        put_cache("q1", "s2", 20, stored, cache_base=tmp_path)
        fetched = get_cached("q1", "s2", 20, cache_base=tmp_path)
        assert fetched is not None
        assert len(fetched) == 1
        assert fetched[0]["title"] == "Test Paper"

    def test_cache_miss(self, tmp_path):
        assert get_cached("nonexistent", "s2", 20, cache_base=tmp_path) is None

    def test_cache_expired(self, tmp_path):
        put_cache("q1", "s2", 20, [{"paper_id": "1", "title": "Old"}], cache_base=tmp_path)
        # ttl=0 makes any existing entry immediately stale.
        assert get_cached("q1", "s2", 20, cache_base=tmp_path, ttl=0) is None

    def test_cache_not_expired(self, tmp_path):
        put_cache("q1", "s2", 20, [{"paper_id": "1", "title": "Fresh"}], cache_base=tmp_path)
        assert get_cached("q1", "s2", 20, cache_base=tmp_path, ttl=9999) is not None

    def test_corrupted_cache_returns_none(self, tmp_path):
        # An unreadable entry is treated as a miss, not an error.
        entry_key = cache_key("q1", "s2", 20)
        (tmp_path / f"{entry_key}.json").write_text("not json", encoding="utf-8")
        assert get_cached("q1", "s2", 20, cache_base=tmp_path) is None
class TestClear:
    """clear_cache removes every entry and reports how many were deleted."""

    def test_clear_removes_all(self, tmp_path):
        put_cache("q1", "s2", 20, [{"id": "1"}], cache_base=tmp_path)
        put_cache("q2", "arxiv", 10, [{"id": "2"}], cache_base=tmp_path)
        assert clear_cache(cache_base=tmp_path) == 2
        assert get_cached("q1", "s2", 20, cache_base=tmp_path) is None

    def test_clear_empty(self, tmp_path):
        assert clear_cache(cache_base=tmp_path) == 0
class TestStats:
    """cache_stats reports the entry count and total size on disk."""

    def test_stats_empty(self, tmp_path):
        stats = cache_stats(cache_base=tmp_path)
        assert stats["entries"] == 0
        assert stats["total_bytes"] == 0

    def test_stats_with_entries(self, tmp_path):
        put_cache("q1", "s2", 20, [{"id": "1"}], cache_base=tmp_path)
        stats = cache_stats(cache_base=tmp_path)
        assert stats["entries"] == 1
        assert stats["total_bytes"] > 0
class TestSearchDegradation:
    """search_papers must fall back to the on-disk cache when every backend
    fails, and must write successful results back into the cache.

    Idiom fix: the previous five-deep nested ``with patch(...)`` pyramids are
    flattened into single comma-joined ``with`` statements — identical patch
    semantics, far less indentation.
    """

    def test_search_uses_cache_on_failure(self, tmp_path):
        # Pre-populate the cache with a single Semantic Scholar record.
        cached_papers = [
            {
                "paper_id": "s2-123",
                "title": "Cached Paper",
                "authors": [],
                "year": 2024,
                "abstract": "",
                "venue": "",
                "citation_count": 10,
                "doi": "",
                "arxiv_id": "",
                "url": "",
                "source": "semantic_scholar",
            }
        ]
        put_cache(
            "test query",
            "semantic_scholar",
            20,
            cached_papers,
            cache_base=tmp_path,
        )
        # All three live backends raise, so the cached entry is the only
        # possible source of results.
        with patch(
            "researchclaw.literature.search.search_openalex",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.search.search_semantic_scholar",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.search.search_arxiv",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.cache._DEFAULT_CACHE_DIR", tmp_path
        ), patch(
            "researchclaw.literature.search.time.sleep", lambda _: None
        ):
            results = search_papers("test query", limit=20)
        assert len(results) >= 1
        assert results[0].title == "Cached Paper"

    def test_search_caches_successful_results(self, tmp_path):
        mock_paper = Paper(
            paper_id="s2-test",
            title="Test",
            authors=(Author(name="Smith"),),
            year=2024,
            abstract="abs",
            source="semantic_scholar",
        )
        with patch(
            "researchclaw.literature.search.search_semantic_scholar",
            return_value=[mock_paper],
        ), patch(
            "researchclaw.literature.search.search_arxiv", return_value=[]
        ), patch(
            "researchclaw.literature.cache._DEFAULT_CACHE_DIR", tmp_path
        ), patch(
            "researchclaw.literature.search.time.sleep", lambda _: None
        ):
            _ = search_papers("test", limit=20)
        # The successful semantic-scholar results must now be on disk.
        cached = get_cached("test", "semantic_scholar", 20, cache_base=tmp_path)
        assert cached is not None
        assert cached[0]["paper_id"] == "s2-test"
================================================
FILE: tests/test_rc_checkpoint.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnusedCallResult=false
"""Tests for checkpoint/resume and content metrics."""
from __future__ import annotations
import json
from pathlib import Path
from typing import cast
from researchclaw.pipeline.executor import StageResult
from researchclaw.pipeline.runner import (
_build_pipeline_summary,
_collect_content_metrics,
_write_checkpoint,
read_checkpoint,
resume_from_checkpoint,
)
from researchclaw.pipeline.stages import (
NONCRITICAL_STAGES,
STAGE_SEQUENCE,
Stage,
StageStatus,
)
class TestCheckpoint:
    """Writing, reading, and resuming from checkpoint.json."""

    def test_write_checkpoint(self, tmp_path: Path):
        _write_checkpoint(tmp_path, Stage.LITERATURE_COLLECT, "test-run")
        payload = json.loads((tmp_path / "checkpoint.json").read_text())
        assert payload["last_completed_stage"] == 4
        assert payload["last_completed_name"] == "LITERATURE_COLLECT"
        assert payload["run_id"] == "test-run"
        assert "timestamp" in payload

    def test_read_checkpoint_returns_next_stage(self, tmp_path: Path):
        _write_checkpoint(tmp_path, Stage.LITERATURE_COLLECT, "test-run")
        assert read_checkpoint(tmp_path) == Stage.LITERATURE_SCREEN

    def test_read_checkpoint_no_file(self, tmp_path: Path):
        assert read_checkpoint(tmp_path) is None

    def test_read_checkpoint_last_stage(self, tmp_path: Path):
        # No stage follows CITATION_VERIFY, so there is nothing to resume.
        _write_checkpoint(tmp_path, Stage.CITATION_VERIFY, "test-run")
        assert read_checkpoint(tmp_path) is None

    def test_read_checkpoint_corrupted(self, tmp_path: Path):
        (tmp_path / "checkpoint.json").write_text("not json", encoding="utf-8")
        assert read_checkpoint(tmp_path) is None

    def test_read_checkpoint_invalid_stage(self, tmp_path: Path):
        (tmp_path / "checkpoint.json").write_text(
            json.dumps({"last_completed_stage": 999}), encoding="utf-8"
        )
        assert read_checkpoint(tmp_path) is None

    def test_resume_from_checkpoint_uses_default(self, tmp_path: Path):
        assert resume_from_checkpoint(tmp_path) == Stage.TOPIC_INIT

    def test_resume_from_checkpoint_uses_next_stage(self, tmp_path: Path):
        _write_checkpoint(tmp_path, Stage.SEARCH_STRATEGY, "run-x")
        assert resume_from_checkpoint(tmp_path) == Stage.LITERATURE_COLLECT
class TestNoncriticalStages:
    """Criticality classification of pipeline stages."""

    def test_knowledge_archive_is_noncritical(self):
        assert Stage.KNOWLEDGE_ARCHIVE in NONCRITICAL_STAGES

    def test_citation_verify_is_critical(self):
        # T3.4: CITATION_VERIFY is now critical — hallucinated refs must block export
        assert Stage.CITATION_VERIFY not in NONCRITICAL_STAGES

    def test_topic_init_is_critical(self):
        assert Stage.TOPIC_INIT not in NONCRITICAL_STAGES

    def test_paper_draft_is_critical(self):
        assert Stage.PAPER_DRAFT not in NONCRITICAL_STAGES

    def test_stage_sequence_still_ends_with_citation_verify(self):
        # Pipeline ordering invariant the resume logic relies on.
        assert STAGE_SEQUENCE[-1] == Stage.CITATION_VERIFY
class TestContentMetrics:
    """_collect_content_metrics over various run-directory layouts."""

    def _write_report(self, run_dir: Path, payload: dict) -> None:
        """Drop a stage-23 verification_report.json with ``payload`` into ``run_dir``."""
        stage_dir = run_dir / "stage-23"
        stage_dir.mkdir()
        (stage_dir / "verification_report.json").write_text(
            json.dumps(payload), encoding="utf-8"
        )

    def test_metrics_empty_run_dir(self, tmp_path: Path):
        metrics = _collect_content_metrics(tmp_path)
        assert metrics["template_ratio"] is None
        assert metrics["citation_verify_score"] is None
        assert metrics["total_citations"] is None
        assert metrics["degraded_sources"] == []

    def test_metrics_with_draft(self, tmp_path: Path):
        draft_dir = tmp_path / "stage-17"
        draft_dir.mkdir()
        (draft_dir / "paper_draft.md").write_text(
            "This is a real academic paper about transformers and attention mechanisms. We propose a novel method for improving efficiency.",
            encoding="utf-8",
        )
        metrics = _collect_content_metrics(tmp_path)
        assert metrics["template_ratio"] is not None
        assert cast(float, metrics["template_ratio"]) < 0.5

    def test_metrics_with_verification(self, tmp_path: Path):
        self._write_report(
            tmp_path,
            {
                "summary": {
                    "total": 10,
                    "verified": 8,
                    "suspicious": 1,
                    "hallucinated": 1,
                    "skipped": 0,
                    "integrity_score": 0.8,
                },
                "results": [],
            },
        )
        metrics = _collect_content_metrics(tmp_path)
        assert metrics["total_citations"] == 10
        assert metrics["verified_citations"] == 8
        assert metrics["citation_verify_score"] == 0.8

    def test_metrics_no_stage23(self, tmp_path: Path):
        assert _collect_content_metrics(tmp_path)["citation_verify_score"] is None

    def test_metrics_with_non_dict_summary(self, tmp_path: Path):
        """Must not raise NameError when 'summary' is not a dict."""
        self._write_report(tmp_path, {"summary": "unexpected string"})
        metrics = _collect_content_metrics(tmp_path)
        assert metrics["total_citations"] is None
        assert metrics["verified_citations"] is None
        assert metrics["citation_verify_score"] is None

    def test_metrics_with_summary_missing_fields(self, tmp_path: Path):
        """summary dict without total/verified should not crash."""
        self._write_report(tmp_path, {"summary": {"notes": "incomplete"}})
        metrics = _collect_content_metrics(tmp_path)
        assert metrics["total_citations"] == 0
        assert metrics["verified_citations"] == 0
        assert metrics["citation_verify_score"] is None

    def test_summary_includes_content_metrics(self, tmp_path: Path):
        stage_results = [
            StageResult(
                stage=Stage.TOPIC_INIT,
                status=StageStatus.DONE,
                artifacts=("topic.json",),
            ),
        ]
        summary = _build_pipeline_summary(
            run_id="test",
            results=stage_results,
            from_stage=Stage.TOPIC_INIT,
            run_dir=tmp_path,
        )
        assert "content_metrics" in summary
        assert isinstance(summary["content_metrics"], dict)
================================================
FILE: tests/test_rc_citation_resolve.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false
"""Tests for BUG-194: Citation resolver must not replace correct bib entries
with garbage papers from search results.
Tests cover:
- _resolve_missing_citations: seminal lookup, API validation, rejection of
unrelated results, year mismatch rejection
- _load_seminal_papers_by_key: index construction
- _seminal_to_bibtex: BibTeX generation from YAML entries
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from researchclaw.literature.models import Author, Paper
# ---------------------------------------------------------------------------
# Helpers to build mock Paper objects
# ---------------------------------------------------------------------------
def _make_paper(
    title: str,
    year: int = 2020,
    authors: list[str] | None = None,
    bibtex_override: str = "",
) -> Paper:
    """Build a minimal Paper for tests.

    The paper_id is derived from the first ten characters of the title
    (lower-cased, spaces replaced with underscores).
    """
    author_names = authors if authors is not None else ["Unknown"]
    slug = title[:10].replace(' ', '_').lower()
    return Paper(
        paper_id=f"test_{slug}",
        title=title,
        authors=tuple(Author(name=name) for name in author_names),
        year=year,
        source="test",
        _bibtex_override=bibtex_override,
    )
# Patch target for search_papers — _resolve_missing_citations performs a lazy
# `from researchclaw.literature.search import search_papers` at call time, so
# patches must target the *source* module, not the importing module.
_SEARCH_PAPERS_PATH = "researchclaw.literature.search.search_papers"
# ---------------------------------------------------------------------------
# Tests for _load_seminal_papers_by_key
# ---------------------------------------------------------------------------
class TestLoadSeminalPapersByKey:
    """Test the seminal papers index builder."""

    def test_loads_well_known_keys(self):
        from researchclaw.pipeline.stage_impls._review_publish import (
            _load_seminal_papers_by_key,
        )

        index = _load_seminal_papers_by_key()
        # seminal_papers.yaml must ship these foundational papers.
        for key in ("he2016deep", "vaswani2017attention", "srivastava2014dropout"):
            assert key in index

    def test_entries_have_required_fields(self):
        from researchclaw.pipeline.stage_impls._review_publish import (
            _load_seminal_papers_by_key,
        )

        for key, entry in _load_seminal_papers_by_key().items():
            assert "title" in entry, f"Missing title for {key}"
            assert "year" in entry, f"Missing year for {key}"
            assert "authors" in entry, f"Missing authors for {key}"

    def test_graceful_on_load_failure(self):
        """If _load_all raises, _load_seminal_papers_by_key returns {}."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _load_seminal_papers_by_key,
        )

        broken_loader = patch(
            "researchclaw.data._load_all",
            side_effect=RuntimeError("disk error"),
        )
        with broken_loader:
            assert _load_seminal_papers_by_key() == {}
# ---------------------------------------------------------------------------
# Tests for _seminal_to_bibtex
# ---------------------------------------------------------------------------
class TestSeminalToBibtex:
    """Test BibTeX generation from seminal_papers.yaml entries."""

    @staticmethod
    def _render(entry: dict, key: str) -> str:
        # Local import keeps collection-time behavior identical to the
        # per-test imports used elsewhere in this file.
        from researchclaw.pipeline.stage_impls._review_publish import _seminal_to_bibtex

        return _seminal_to_bibtex(entry, key)

    def test_conference_paper(self):
        bib = self._render(
            {
                "title": "Deep Residual Learning for Image Recognition",
                "authors": "He et al.",
                "year": 2016,
                "venue": "CVPR",
            },
            "he2016deep",
        )
        assert "@inproceedings{he2016deep," in bib
        assert "Deep Residual Learning" in bib
        assert "He et al." in bib
        assert "2016" in bib
        assert "booktitle = {CVPR}" in bib

    def test_journal_paper(self):
        bib = self._render(
            {
                "title": "Dropout: A Simple Way to Prevent Neural Networks from Overfitting",
                "authors": "Srivastava et al.",
                "year": 2014,
                "venue": "JMLR",
            },
            "srivastava2014dropout",
        )
        assert "@article{srivastava2014dropout," in bib
        assert "Dropout" in bib
        assert "journal = {JMLR}" in bib

    def test_neurips_is_conference(self):
        bib = self._render(
            {
                "title": "Attention Is All You Need",
                "authors": "Vaswani et al.",
                "year": 2017,
                "venue": "NeurIPS",
            },
            "vaswani2017attention",
        )
        assert "@inproceedings{vaswani2017attention," in bib
# ---------------------------------------------------------------------------
# Tests for _resolve_missing_citations
# ---------------------------------------------------------------------------
class TestResolveMissingCitations:
    """Test the full resolution pipeline with BUG-194 fixes.

    Fix in this revision: test_completely_unrelated_title_rejected previously
    contained no assertions at all (it could never fail); it now pins the
    invariants it can, and test_picks_best_result_from_multiple gained an
    unconditional bookkeeping assertion.
    """

    def test_seminal_papers_resolved_without_api(self):
        """Foundational papers should be resolved from seminal_papers.yaml
        without any API calls."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        missing = {"he2016deep", "vaswani2017attention", "srivastava2014dropout"}
        existing_bib = ""
        # Patch search_papers so it FAILS if called — seminal papers shouldn't
        # need it.
        with patch(
            _SEARCH_PAPERS_PATH,
            side_effect=AssertionError("Should not be called for seminal papers"),
        ):
            resolved, entries = _resolve_missing_citations(missing, existing_bib)
        assert "he2016deep" in resolved
        assert "vaswani2017attention" in resolved
        assert "srivastava2014dropout" in resolved
        assert len(entries) == 3
        # Verify the BibTeX entries contain correct titles
        combined = "\n".join(entries)
        assert "Deep Residual Learning" in combined
        assert "Attention Is All You Need" in combined
        assert "Dropout" in combined

    def test_seminal_papers_not_duplicated_in_existing_bib(self):
        """If the key is already in existing_bib, don't add it again."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        existing_bib = "@article{he2016deep, title={Deep Residual Learning}}"
        missing = {"he2016deep"}
        # Mock search_papers to ensure no real API calls (key should be skipped
        # entirely since it's already in existing_bib).
        with patch(
            _SEARCH_PAPERS_PATH,
            side_effect=AssertionError("Should not call API for key in existing_bib"),
        ):
            resolved, entries = _resolve_missing_citations(missing, existing_bib)
        assert "he2016deep" not in resolved
        assert len(entries) == 0

    def test_garbage_results_rejected_by_similarity(self):
        """BUG-194 regression: unrelated search results must be rejected."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        # Mock a garbage result that has the right year but wrong title
        garbage_paper = _make_paper(
            title="Jokowi and the New Developmentalism",
            year=2016,
            authors=["He, Some Politician"],
            bibtex_override=(
                "@article{jokowi2016,\n"
                " title = {Jokowi and the New Developmentalism},\n"
                " author = {He, Some Politician},\n"
                " year = {2016},\n"
                "}"
            ),
        )
        # This key is NOT in seminal_papers.yaml
        missing = {"smith2016novel"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[garbage_paper]):
            resolved, entries = _resolve_missing_citations(missing, "")
        # The garbage result should be rejected (no overlap with "smith novel")
        assert "smith2016novel" not in resolved
        assert len(entries) == 0

    def test_year_mismatch_rejected(self):
        """Results with year > 1 year off from cite key are rejected."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        wrong_year_paper = _make_paper(
            title="Novel Deep Learning Approach by Smith",
            year=2020,  # cite key says 2016
            authors=["Smith, John"],
            bibtex_override=(
                "@article{smith2020,\n"
                " title = {Novel Deep Learning Approach by Smith},\n"
                " author = {Smith, John},\n"
                " year = {2020},\n"
                "}"
            ),
        )
        missing = {"smith2016novel"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[wrong_year_paper]):
            resolved, entries = _resolve_missing_citations(missing, "")
        assert "smith2016novel" not in resolved

    def test_good_api_result_accepted(self):
        """A search result with matching author + title words should be accepted."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        good_paper = _make_paper(
            title="Novel Approach to Feature Extraction in Deep Networks",
            year=2018,
            authors=["Chen, Wei"],
            bibtex_override=(
                "@article{chen2018something,\n"
                " title = {Novel Approach to Feature Extraction in Deep Networks},\n"
                " author = {Chen, Wei},\n"
                " year = {2018},\n"
                "}"
            ),
        )
        # cite key: chen2018novel — "chen" matches author, "novel" matches title
        missing = {"chen2018novel"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[good_paper]):
            resolved, entries = _resolve_missing_citations(missing, "")
        assert "chen2018novel" in resolved
        assert len(entries) == 1
        # The bib entry should use the original cite_key
        assert "chen2018novel" in entries[0]

    def test_empty_missing_keys_returns_empty(self):
        """No keys to resolve -> empty results."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        resolved, entries = _resolve_missing_citations(set(), "")
        assert len(resolved) == 0
        assert len(entries) == 0

    def test_unparseable_keys_skipped(self):
        """Keys that don't match author-year pattern are skipped."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        missing = {"notyearkey", "abc"}
        resolved, entries = _resolve_missing_citations(missing, "")
        assert len(resolved) == 0
        assert len(entries) == 0

    def test_import_failure_returns_seminal_only(self):
        """If search_papers can't be imported, seminal results still returned."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        # Mix of seminal and non-seminal keys
        missing = {"he2016deep", "unknownauthor2020something"}
        with patch(
            _SEARCH_PAPERS_PATH,
            side_effect=ImportError("mocked"),
        ):
            resolved, entries = _resolve_missing_citations(missing, "")
        # he2016deep should be resolved from seminal
        assert "he2016deep" in resolved
        # unknownauthor2020something would need API which fails
        assert "unknownauthor2020something" not in resolved

    def test_search_exception_handled_gracefully(self):
        """If search_papers raises, the key is skipped (no crash)."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        missing = {"unknownauthor2020something"}
        with patch(
            _SEARCH_PAPERS_PATH,
            side_effect=RuntimeError("API down"),
        ):
            resolved, entries = _resolve_missing_citations(missing, "")
        assert len(resolved) == 0

    def test_bug194_he2016deep_not_replaced_with_jokowi(self):
        """BUG-194 exact regression: he2016deep must NEVER resolve to
        'Jokowi and the New Developmentalism'."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        # he2016deep IS in seminal_papers.yaml, so it should resolve from there
        missing = {"he2016deep"}
        resolved, entries = _resolve_missing_citations(missing, "")
        assert "he2016deep" in resolved
        assert len(entries) == 1
        assert "Jokowi" not in entries[0]
        assert "Deep Residual Learning" in entries[0]

    def test_bug194_vaswani2017attention_not_replaced_with_health_supplement(self):
        """BUG-194 exact regression: vaswani2017attention must resolve to
        'Attention Is All You Need', not health supplement garbage."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        missing = {"vaswani2017attention"}
        resolved, entries = _resolve_missing_citations(missing, "")
        assert "vaswani2017attention" in resolved
        assert len(entries) == 1
        assert "Health Supplement" not in entries[0]
        assert "Attention Is All You Need" in entries[0]

    def test_bug194_srivastava2014dropout_not_replaced_with_cnn_sentence(self):
        """BUG-194 exact regression: srivastava2014dropout must resolve to
        Dropout paper, not CNN for Sentence Classification."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        missing = {"srivastava2014dropout"}
        resolved, entries = _resolve_missing_citations(missing, "")
        assert "srivastava2014dropout" in resolved
        assert len(entries) == 1
        assert "Sentence Classification" not in entries[0]
        assert "Dropout" in entries[0]

    def test_multiple_seminal_and_api_mixed(self):
        """Mix of seminal keys (resolved locally) and API keys."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        api_paper = _make_paper(
            title="Adaptive Learning Rate Methods for Deep Networks",
            year=2019,
            authors=["Zhang, Adaptive"],
            bibtex_override=(
                "@article{zhang2019something,\n"
                " title = {Adaptive Learning Rate Methods for Deep Networks},\n"
                " author = {Zhang, Adaptive},\n"
                " year = {2019},\n"
                "}"
            ),
        )
        missing = {"he2016deep", "zhang2019adaptive"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[api_paper]):
            resolved, entries = _resolve_missing_citations(missing, "")
        # he2016deep from seminal, zhang2019adaptive from API
        assert "he2016deep" in resolved
        assert "zhang2019adaptive" in resolved
        assert len(entries) == 2

    def test_no_results_from_api_skips(self):
        """If API returns empty list, key is skipped (not crashed)."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        missing = {"unknownauthor2020something"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[]):
            resolved, entries = _resolve_missing_citations(missing, "")
        assert len(resolved) == 0
        assert len(entries) == 0

    def test_close_year_accepted(self):
        """A result with year within 1 of the cite key year should be accepted
        (arXiv vs conference year difference)."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        paper = _make_paper(
            title="Novel Deep Feature Extraction by Li",
            year=2019,  # cite key says 2018, but 1 year off is OK
            authors=["Li, Novel"],
            bibtex_override=(
                "@article{li2019,\n"
                " title = {Novel Deep Feature Extraction by Li},\n"
                " author = {Li, Novel},\n"
                " year = {2019},\n"
                "}"
            ),
        )
        missing = {"li2018novel"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[paper]):
            resolved, entries = _resolve_missing_citations(missing, "")
        # Year 2019 vs 2018 — diff=1, should be accepted since title matches
        assert "li2018novel" in resolved

    def test_completely_unrelated_title_rejected(self):
        """The validation path must run for non-seminal keys.

        FIX: this test previously contained no assertions at all, so it could
        never fail. "vaswani" matches the author and "health" matches the
        title, so the candidate may legitimately pass similarity validation;
        what we can pin down is (1) non-seminal keys MUST hit the search API,
        and (2) resolution bookkeeping stays consistent either way.
        """
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        paper = _make_paper(
            title="AI-Assisted Pipeline for Dynamic Generation of Trustworthy Health Supplement Content at Scale",
            year=2017,
            authors=["Vaswani, Raj"],
            bibtex_override=(
                "@article{vaswani2017health,\n"
                " title = {AI-Assisted Pipeline for Dynamic Generation of Trustworthy Health Supplement Content at Scale},\n"
                " author = {Vaswani, Raj},\n"
                " year = {2017},\n"
                "}"
            ),
        )
        # Not in seminal_papers.yaml (different key)
        missing = {"vaswani2017health"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[paper]) as mock_search:
            resolved, entries = _resolve_missing_citations(missing, "")
        # The search is called only for non-seminal keys — so it must have
        # been consulted for this key.
        assert mock_search.called
        # One bib entry per resolved key, whether accepted or rejected.
        assert len(entries) == len(resolved)

    def test_picks_best_result_from_multiple(self):
        """When API returns multiple results, the one with best overlap wins."""
        from researchclaw.pipeline.stage_impls._review_publish import (
            _resolve_missing_citations,
        )

        bad_paper = _make_paper(
            title="Convolutional Neural Networks for Sentence Classification",
            year=2018,
            authors=["Kim, Yoon"],
        )
        good_paper = _make_paper(
            title="Feature Extraction via Progressive Learning",
            year=2018,
            authors=["Wang, Feature"],
            bibtex_override=(
                "@article{wang2018,\n"
                " title = {Feature Extraction via Progressive Learning},\n"
                " author = {Wang, Feature},\n"
                " year = {2018},\n"
                "}"
            ),
        )
        missing = {"wang2018feature"}
        with patch(_SEARCH_PAPERS_PATH, return_value=[bad_paper, good_paper]):
            resolved, entries = _resolve_missing_citations(missing, "")
        # Bookkeeping invariant holds regardless of accept/reject.
        assert len(entries) == len(resolved)
        if resolved:
            # If resolved, it should be the good paper, not the bad one
            assert "Sentence Classification" not in entries[0]
================================================
FILE: tests/test_rc_citation_verify.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false
from __future__ import annotations
import json
import textwrap
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.literature.verify import (
CitationResult,
VerificationReport,
VerifyStatus,
annotate_paper_hallucinations,
filter_verified_bibtex,
parse_bibtex_entries,
title_similarity,
verify_by_arxiv_id,
verify_by_doi,
verify_by_title_search,
verify_citations,
)
from researchclaw.literature.models import Author, Paper
# Three-entry fixture: an arXiv-id entry, a DOI entry, and one fabricated
# entry that exists nowhere (the hallucination case).
SAMPLE_BIB = textwrap.dedent("""\
@article{vaswani2017attention,
title = {Attention Is All You Need},
author = {Ashish Vaswani and Noam Shazeer},
year = {2017},
eprint = {1706.03762},
archiveprefix = {arXiv},
}
@inproceedings{devlin2019bert,
title = {BERT: Pre-training of Deep Bidirectional Transformers},
author = {Jacob Devlin},
year = {2019},
doi = {10.18653/v1/N19-1423},
booktitle = {NAACL},
}
@article{fakepaper2025hallucinated,
title = {A Completely Made Up Paper That Does Not Exist},
author = {Imaginary Author},
year = {2025},
}
""")
# Canned arXiv API payload for a real paper (1706.03762).
SAMPLE_ARXIV_VERIFY_RESPONSE = textwrap.dedent("""\
http://arxiv.org/abs/1706.03762v5
Attention Is All You Need
The dominant sequence transduction models...
Ashish Vaswani
""")
# Canned arXiv API payload for an invalid identifier (the error document
# returned for a malformed id).
SAMPLE_ARXIV_EMPTY_RESPONSE = textwrap.dedent("""\
http://arxiv.org/api/errors#incorrect_id_format_for_9999.99999
Error
incorrect id format for 9999.99999
""")
# Canned Crossref /works response matching the devlin2019bert DOI above.
SAMPLE_CROSSREF_RESPONSE = {
    "status": "ok",
    "message": {
        "DOI": "10.18653/v1/N19-1423",
        "title": [
            "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
        ],
        "author": [{"given": "Jacob", "family": "Devlin"}],
    },
}
class TestParseBibtexEntries:
    """parse_bibtex_entries turns raw BibTeX text into field dicts."""

    @staticmethod
    def _entries():
        return parse_bibtex_entries(SAMPLE_BIB)

    def test_parses_three_entries(self) -> None:
        assert len(self._entries()) == 3

    def test_entry_keys(self) -> None:
        keys = [entry["key"] for entry in self._entries()]
        assert "vaswani2017attention" in keys
        assert "devlin2019bert" in keys
        assert "fakepaper2025hallucinated" in keys

    def test_entry_fields(self) -> None:
        by_key = {entry["key"]: entry for entry in self._entries()}
        vaswani = by_key["vaswani2017attention"]
        assert vaswani["title"] == "Attention Is All You Need"
        assert vaswani["eprint"] == "1706.03762"
        assert vaswani["type"] == "article"

    def test_entry_type(self) -> None:
        by_key = {entry["key"]: entry for entry in self._entries()}
        devlin = by_key["devlin2019bert"]
        assert devlin["type"] == "inproceedings"
        assert devlin["doi"] == "10.18653/v1/N19-1423"

    def test_empty_bib(self) -> None:
        assert parse_bibtex_entries("") == []

    def test_malformed_bib(self) -> None:
        assert parse_bibtex_entries("not bibtex at all") == []
class TestTitleSimilarity:
    """Behaviour of the fuzzy title_similarity metric."""

    def test_identical(self) -> None:
        title = "Attention Is All You Need"
        assert title_similarity(title, title) == 1.0

    def test_case_insensitive(self) -> None:
        score = title_similarity(
            "attention is all you need", "ATTENTION IS ALL YOU NEED"
        )
        assert score == 1.0

    def test_high_similarity(self) -> None:
        score = title_similarity(
            "Attention Is All You Need",
            "Attention Is All You Need: A Transformer Architecture",
        )
        assert score >= 0.5

    def test_low_similarity(self) -> None:
        score = title_similarity(
            "Attention Is All You Need",
            "Protein Folding with AlphaFold",
        )
        assert score < 0.3

    def test_empty_strings(self) -> None:
        # Empty input on either side yields zero similarity.
        assert title_similarity("", "") == 0.0
        assert title_similarity("something", "") == 0.0
class TestVerifyByArxivId:
    """verify_by_arxiv_id against canned arXiv API responses."""

    @staticmethod
    def _canned_response(payload: str) -> MagicMock:
        # Context-manager mock standing in for urllib's response object.
        resp = MagicMock()
        resp.read.return_value = payload.encode("utf-8")
        resp.__enter__ = lambda s: s
        resp.__exit__ = MagicMock(return_value=False)
        return resp

    def test_verified_match(self) -> None:
        resp = self._canned_response(SAMPLE_ARXIV_VERIFY_RESPONSE)
        with patch("urllib.request.urlopen", return_value=resp):
            result = verify_by_arxiv_id("1706.03762", "Attention Is All You Need")
        assert result is not None
        assert result.status == VerifyStatus.VERIFIED
        assert result.method == "arxiv_id"
        assert result.confidence >= 0.80

    def test_hallucinated_error_response(self) -> None:
        resp = self._canned_response(SAMPLE_ARXIV_EMPTY_RESPONSE)
        with patch("urllib.request.urlopen", return_value=resp):
            result = verify_by_arxiv_id("9999.99999", "Fake Paper")
        assert result is not None
        assert result.status == VerifyStatus.HALLUCINATED

    def test_network_failure_returns_none(self) -> None:
        with patch("urllib.request.urlopen", side_effect=OSError("connection refused")):
            result = verify_by_arxiv_id("1706.03762", "Attention Is All You Need")
        assert result is None

    def test_title_mismatch_suspicious(self) -> None:
        payload = textwrap.dedent("""\
            http://arxiv.org/abs/1706.03762v5
            A Completely Different Paper Title About Quantum Computing
            Summary
            """)
        resp = self._canned_response(payload)
        with patch("urllib.request.urlopen", return_value=resp):
            result = verify_by_arxiv_id("1706.03762", "Attention Is All You Need")
        assert result is not None
        assert result.status == VerifyStatus.SUSPICIOUS
class TestVerifyByDoi:
    """verify_by_doi against canned Crossref responses."""

    @staticmethod
    def _json_response(payload: dict) -> MagicMock:
        # Context-manager mock whose read() yields the JSON-encoded payload.
        resp = MagicMock()
        resp.read.return_value = json.dumps(payload).encode("utf-8")
        resp.__enter__ = lambda s: s
        resp.__exit__ = MagicMock(return_value=False)
        return resp

    def test_verified_crossref(self) -> None:
        resp = self._json_response(SAMPLE_CROSSREF_RESPONSE)
        with patch("urllib.request.urlopen", return_value=resp):
            result = verify_by_doi(
                "10.18653/v1/N19-1423",
                "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding",
            )
        assert result is not None
        assert result.status == VerifyStatus.VERIFIED
        assert result.method == "doi"

    def test_doi_404_hallucinated(self) -> None:
        import urllib.error

        http_404 = urllib.error.HTTPError(
            "https://api.crossref.org/works/10.fake/doi",
            404,
            "Not Found",
            {},
            None,  # type: ignore[arg-type]
        )
        with patch("urllib.request.urlopen", side_effect=http_404):
            result = verify_by_doi("10.fake/doi", "Nonexistent Paper")
        assert result is not None
        assert result.status == VerifyStatus.HALLUCINATED

    def test_network_error_returns_none(self) -> None:
        with patch("urllib.request.urlopen", side_effect=OSError("timeout")):
            result = verify_by_doi("10.1234/test", "Test Paper")
        assert result is None

    def test_doi_exists_no_title(self) -> None:
        resp = self._json_response({"status": "ok", "message": {"DOI": "10.1234/test"}})
        with patch("urllib.request.urlopen", return_value=resp):
            result = verify_by_doi("10.1234/test", "Some Paper")
        assert result is not None
        assert result.status == VerifyStatus.VERIFIED
        assert "no title comparison" in result.details.lower()
class TestVerifyByTitleSearch:
    """verify_by_title_search with a patched search backend."""

    @staticmethod
    def _patched_search(results):
        # Patch the source module — verify imports search_papers lazily.
        return patch(
            "researchclaw.literature.search.search_papers",
            return_value=results,
        )

    def test_verified_via_search(self) -> None:
        hit = Paper(
            paper_id="s2-abc",
            title="Attention Is All You Need",
            authors=(Author(name="Vaswani"),),
            year=2017,
            source="semantic_scholar",
        )
        with self._patched_search([hit]):
            result = verify_by_title_search("Attention Is All You Need")
        assert result is not None
        assert result.status == VerifyStatus.VERIFIED
        assert result.matched_paper is not None

    def test_no_results_hallucinated(self) -> None:
        with self._patched_search([]):
            result = verify_by_title_search("A Completely Made Up Paper")
        assert result is not None
        assert result.status == VerifyStatus.HALLUCINATED

    def test_weak_match_hallucinated(self) -> None:
        hit = Paper(
            paper_id="s2-xyz",
            title="Quantum Computing for Protein Folding",
            year=2023,
            source="arxiv",
        )
        with self._patched_search([hit]):
            result = verify_by_title_search("A Completely Made Up Paper About Nothing")
        assert result is not None
        assert result.status == VerifyStatus.HALLUCINATED

    def test_partial_match_suspicious(self) -> None:
        hit = Paper(
            paper_id="s2-partial",
            title="Attention Mechanisms in Neural Networks",
            year=2019,
            source="semantic_scholar",
        )
        with self._patched_search([hit]):
            result = verify_by_title_search("Attention Neural Networks Survey Overview")
        assert result is not None
        assert result.status in (VerifyStatus.SUSPICIOUS, VerifyStatus.HALLUCINATED)

    def test_network_failure_returns_none(self) -> None:
        with patch(
            "researchclaw.literature.search.search_papers",
            side_effect=OSError("network down"),
        ):
            result = verify_by_title_search("Any Paper")
        assert result is None
class TestVerifyCitations:
    """End-to-end verify_citations runs over SAMPLE_BIB with all I/O mocked."""

    def test_full_pipeline_mocked(self) -> None:
        # Canned arXiv response — matches the vaswani entry's eprint.
        arxiv_resp = MagicMock()
        arxiv_resp.read.return_value = SAMPLE_ARXIV_VERIFY_RESPONSE.encode("utf-8")
        arxiv_resp.__enter__ = lambda s: s
        arxiv_resp.__exit__ = MagicMock(return_value=False)
        # Canned Crossref response — matches the devlin entry's DOI.
        crossref_resp = MagicMock()
        crossref_resp.read.return_value = json.dumps(SAMPLE_CROSSREF_RESPONSE).encode(
            "utf-8"
        )
        crossref_resp.__enter__ = lambda s: s
        crossref_resp.__exit__ = MagicMock(return_value=False)
        call_count = {"n": 0}

        def mock_urlopen(req: Any, **kwargs: Any) -> MagicMock:
            # Route by URL host: arXiv vs Crossref; anything else is a bug.
            call_count["n"] += 1
            url = req.full_url if hasattr(req, "full_url") else str(req)
            if "arxiv.org" in url:
                return arxiv_resp
            if "crossref.org" in url:
                return crossref_resp
            raise OSError("unexpected URL")

        # sleep is patched out so per-citation rate limiting costs nothing;
        # search_papers returns [] so the fake entry falls through to
        # "hallucinated".
        with (
            patch("researchclaw.literature.verify.time.sleep"),
            patch("urllib.request.urlopen", side_effect=mock_urlopen),
            patch("researchclaw.literature.search.search_papers", return_value=[]),
        ):
            report = verify_citations(SAMPLE_BIB, inter_verify_delay=0)
        assert report.total == 3
        assert report.verified >= 1
        assert report.hallucinated >= 1
        report_dict = report.to_dict()
        assert "summary" in report_dict
        assert "results" in report_dict
        assert report_dict["summary"]["total"] == 3

    def test_empty_bib(self) -> None:
        # Empty input: nothing to verify, perfect integrity by definition.
        report = verify_citations("")
        assert report.total == 0
        assert report.integrity_score == 1.0

    def test_no_title_entry_skipped(self) -> None:
        # An entry without a title cannot be verified — counted as skipped.
        bib = textwrap.dedent("""\
            @article{noauthor2025,
            author = {Some Author},
            year = {2025},
            }
            """)
        report = verify_citations(bib)
        assert report.total == 1
        assert report.skipped == 1
class TestVerificationReport:
    """Aggregate scoring and serialization of VerificationReport."""

    def test_integrity_score(self) -> None:
        report = VerificationReport(
            total=10, verified=7, suspicious=1, hallucinated=2, skipped=0
        )
        assert report.integrity_score == 0.7

    def test_integrity_score_with_skips(self) -> None:
        # 6 verified out of 8 non-skipped entries -> 0.75.
        report = VerificationReport(
            total=10, verified=6, suspicious=0, hallucinated=2, skipped=2
        )
        assert report.integrity_score == 0.75

    def test_integrity_score_all_skipped(self) -> None:
        everything_skipped = VerificationReport(
            total=3, verified=0, suspicious=0, hallucinated=0, skipped=3
        )
        assert everything_skipped.integrity_score == 1.0

    def test_to_dict(self) -> None:
        serialized = VerificationReport(total=2, verified=1, hallucinated=1).to_dict()
        assert serialized["summary"]["total"] == 2
        assert serialized["summary"]["integrity_score"] == 0.5
class TestFilterVerifiedBibtex:
    """filter_verified_bibtex keeps entries according to verification status."""

    def _make_report(self) -> VerificationReport:
        # (cite_key, title, status, confidence, method) — one row per entry
        # in SAMPLE_BIB, covering all three verification outcomes.
        rows = [
            ("vaswani2017attention", "Attention Is All You Need", VerifyStatus.VERIFIED, 1.0, "arxiv_id"),
            ("devlin2019bert", "BERT", VerifyStatus.SUSPICIOUS, 0.6, "doi"),
            ("fakepaper2025hallucinated", "Fake Paper", VerifyStatus.HALLUCINATED, 0.9, "title_search"),
        ]
        return VerificationReport(
            total=3,
            verified=1,
            suspicious=1,
            hallucinated=1,
            results=[
                CitationResult(
                    cite_key=key,
                    title=title,
                    status=status,
                    confidence=confidence,
                    method=method,
                )
                for key, title, status, confidence, method in rows
            ],
        )

    def test_includes_verified_and_suspicious(self) -> None:
        filtered = filter_verified_bibtex(
            SAMPLE_BIB, self._make_report(), include_suspicious=True
        )
        assert "vaswani2017attention" in filtered
        assert "devlin2019bert" in filtered
        assert "fakepaper2025hallucinated" not in filtered

    def test_excludes_suspicious(self) -> None:
        filtered = filter_verified_bibtex(
            SAMPLE_BIB, self._make_report(), include_suspicious=False
        )
        assert "vaswani2017attention" in filtered
        assert "devlin2019bert" not in filtered
        assert "fakepaper2025hallucinated" not in filtered

    def test_empty_bib(self) -> None:
        assert filter_verified_bibtex("", VerificationReport()) == ""
class TestAnnotatePaperHallucinations:
    """annotate_paper_hallucinations removes hallucinated cites, keeps the rest."""

    @staticmethod
    def _result(cite_key: str, status: VerifyStatus, confidence: float, method: str) -> CitationResult:
        # All annotation tests use empty titles; only key/status matter here.
        return CitationResult(
            cite_key=cite_key,
            title="",
            status=status,
            confidence=confidence,
            method=method,
        )

    def test_latex_citations(self) -> None:
        paper = r"As shown in \cite{vaswani2017attention} and \cite{fakepaper2025hallucinated}."
        report = VerificationReport(
            results=[
                self._result("vaswani2017attention", VerifyStatus.VERIFIED, 1.0, "arxiv_id"),
                self._result("fakepaper2025hallucinated", VerifyStatus.HALLUCINATED, 0.9, "title_search"),
            ],
        )
        annotated = annotate_paper_hallucinations(paper, report)
        assert r"\cite{vaswani2017attention}" in annotated
        # Hallucinated citations are removed, not annotated
        assert "fakepaper2025hallucinated" not in annotated

    def test_markdown_citations(self) -> None:
        paper = "As shown in [vaswani2017attention] and [fakepaper2025hallucinated]."
        report = VerificationReport(
            results=[
                self._result("vaswani2017attention", VerifyStatus.VERIFIED, 1.0, "arxiv_id"),
                self._result("fakepaper2025hallucinated", VerifyStatus.HALLUCINATED, 0.9, "title_search"),
            ],
        )
        annotated = annotate_paper_hallucinations(paper, report)
        assert "[vaswani2017attention]" in annotated
        # Hallucinated citations are removed, not annotated
        assert "fakepaper2025hallucinated" not in annotated

    def test_suspicious_annotation(self) -> None:
        """Suspicious citations are left unchanged (not removed)."""
        paper = r"\cite{devlin2019bert}"
        report = VerificationReport(
            results=[self._result("devlin2019bert", VerifyStatus.SUSPICIOUS, 0.6, "doi")],
        )
        annotated = annotate_paper_hallucinations(paper, report)
        assert r"\cite{devlin2019bert}" in annotated

    def test_no_modifications_all_verified(self) -> None:
        paper = r"See \cite{vaswani2017attention}."
        report = VerificationReport(
            results=[self._result("vaswani2017attention", VerifyStatus.VERIFIED, 1.0, "arxiv_id")],
        )
        assert annotate_paper_hallucinations(paper, report) == paper
class TestCitationResultSerialization:
    """Shape of CitationResult.to_dict output."""

    def test_to_dict_basic(self) -> None:
        serialized = CitationResult(
            cite_key="smith2024test",
            title="Test Paper",
            status=VerifyStatus.VERIFIED,
            confidence=0.95,
            method="arxiv_id",
            details="Confirmed",
        ).to_dict()
        assert serialized["cite_key"] == "smith2024test"
        # Enum status serializes to its lowercase string value.
        assert serialized["status"] == "verified"
        assert serialized["confidence"] == 0.95

    def test_to_dict_with_matched_paper(self) -> None:
        matched = Paper(
            paper_id="s2-abc",
            title="Found Paper",
            authors=(Author(name="Smith"),),
            year=2024,
            source="semantic_scholar",
        )
        serialized = CitationResult(
            cite_key="smith2024test",
            title="Test",
            status=VerifyStatus.VERIFIED,
            confidence=0.9,
            method="title_search",
            matched_paper=matched,
        ).to_dict()
        assert "matched_paper" in serialized
        assert serialized["matched_paper"]["title"] == "Found Paper"
class TestStage23Integration:
    """Stage 23 (CITATION_VERIFY) wiring: enum, sequence, contract, executor, phase map."""

    def test_stage_exists_in_enum(self) -> None:
        from researchclaw.pipeline.stages import Stage

        assert hasattr(Stage, "CITATION_VERIFY")
        assert Stage.CITATION_VERIFY == 23

    def test_stage_in_sequence(self) -> None:
        from researchclaw.pipeline.stages import NEXT_STAGE, STAGE_SEQUENCE, Stage

        assert Stage.CITATION_VERIFY in STAGE_SEQUENCE
        # CITATION_VERIFY is the terminal stage, directly after EXPORT_PUBLISH.
        assert NEXT_STAGE[Stage.EXPORT_PUBLISH] == Stage.CITATION_VERIFY
        assert NEXT_STAGE[Stage.CITATION_VERIFY] is None

    def test_contract_exists(self) -> None:
        from researchclaw.pipeline.contracts import CONTRACTS
        from researchclaw.pipeline.stages import Stage

        assert Stage.CITATION_VERIFY in CONTRACTS
        declared_outputs = CONTRACTS[Stage.CITATION_VERIFY].output_files
        assert "verification_report.json" in declared_outputs
        assert "references_verified.bib" in declared_outputs

    def test_executor_registered(self) -> None:
        from researchclaw.pipeline.executor import _STAGE_EXECUTORS
        from researchclaw.pipeline.stages import Stage

        assert Stage.CITATION_VERIFY in _STAGE_EXECUTORS

    def test_phase_map(self) -> None:
        from researchclaw.pipeline.stages import PHASE_MAP, Stage

        assert Stage.CITATION_VERIFY in PHASE_MAP["H: Finalization"]

    def test_total_stages_is_23(self) -> None:
        from researchclaw.pipeline.stages import STAGE_SEQUENCE

        assert len(STAGE_SEQUENCE) == 23
================================================
FILE: tests/test_rc_cli.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnusedCallResult=false, reportAttributeAccessIssue=false, reportUnknownLambdaType=false
from __future__ import annotations
import argparse
import re
from pathlib import Path
import pytest
from researchclaw import cli as rc_cli
from researchclaw.config import resolve_config_path
def _write_valid_config(path: Path) -> None:
path.write_text(
"""
project:
name: demo
mode: docs-first
research:
topic: Synthetic benchmark research
runtime:
timezone: UTC
notifications:
channel: test
knowledge_base:
backend: markdown
root: kb
openclaw_bridge: {}
llm:
provider: openai-compatible
base_url: http://localhost:1234/v1
api_key_env: TEST_KEY
""".strip()
+ "\n",
encoding="utf-8",
)
def test_main_with_no_args_returns_zero_and_prints_help(
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Bare invocation exits 0 and prints the top-level help text."""
    assert rc_cli.main([]) == 0
    help_output = capsys.readouterr().out
    assert "ResearchClaw" in help_output
    assert "usage:" in help_output


@pytest.mark.parametrize("argv", [["run", "--help"], ["validate", "--help"]])
def test_help_subcommands_exit_zero(argv: list[str]) -> None:
    """--help on subcommands raises SystemExit(0), as argparse always does."""
    with pytest.raises(SystemExit) as exc_info:
        rc_cli.main(argv)
    assert exc_info.value.code == 0


def test_generate_run_id_format() -> None:
    """Run ids look like rc-YYYYMMDD-HHMMSS-<6 hex chars>."""
    generated = rc_cli._generate_run_id("my topic")
    assert generated.startswith("rc-")
    assert re.fullmatch(r"rc-\d{8}-\d{6}-[0-9a-f]{6}", generated)
def test_cmd_run_missing_config_returns_one(
    tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """cmd_run fails fast (exit 1) when the config path does not exist."""
    missing_args = argparse.Namespace(
        config=str(tmp_path / "missing.yaml"),
        topic=None,
        output=None,
        from_stage=None,
        auto_approve=False,
        skip_preflight=True,
        resume=False,
        skip_noncritical_stage=False,
    )
    assert rc_cli.cmd_run(missing_args) == 1
    assert "config file not found" in capsys.readouterr().err


def test_cmd_validate_missing_config_returns_one(
    tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """cmd_validate fails fast (exit 1) when the config path does not exist."""
    missing_args = argparse.Namespace(
        config=str(tmp_path / "missing.yaml"), no_check_paths=False
    )
    assert rc_cli.cmd_validate(missing_args) == 1
    assert "config file not found" in capsys.readouterr().err


def test_cmd_validate_valid_config_returns_zero(
    tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """A syntactically valid config passes validation with exit 0."""
    config_path = tmp_path / "config.yaml"
    _write_valid_config(config_path)
    valid_args = argparse.Namespace(config=str(config_path), no_check_paths=True)
    assert rc_cli.cmd_validate(valid_args) == 0
    assert "Config validation passed" in capsys.readouterr().out
def test_main_dispatches_run_command(monkeypatch: pytest.MonkeyPatch) -> None:
    """main() routes `run` argv to cmd_run with all flags parsed into the namespace."""
    seen: dict[str, argparse.Namespace] = {}

    def record_run(args):
        seen["args"] = args
        return 0

    monkeypatch.setattr(rc_cli, "cmd_run", record_run)
    argv = [
        "run",
        "--topic", "new topic",
        "--config", "cfg.yaml",
        "--output", "out-dir",
        "--from-stage", "PAPER_OUTLINE",
        "--auto-approve",
    ]
    assert rc_cli.main(argv) == 0
    parsed = seen["args"]
    assert parsed.topic == "new topic"
    assert parsed.config == "cfg.yaml"
    assert parsed.output == "out-dir"
    assert parsed.from_stage == "PAPER_OUTLINE"
    assert parsed.auto_approve is True


def test_main_dispatches_validate_command(monkeypatch: pytest.MonkeyPatch) -> None:
    """main() routes `validate` argv to cmd_validate with its flags parsed."""
    seen: dict[str, argparse.Namespace] = {}

    def record_validate(args):
        seen["args"] = args
        return 0

    monkeypatch.setattr(rc_cli, "cmd_validate", record_validate)
    assert rc_cli.main(["validate", "--config", "cfg.yaml", "--no-check-paths"]) == 0
    parsed = seen["args"]
    assert parsed.config == "cfg.yaml"
    assert parsed.no_check_paths is True
@pytest.mark.parametrize(
    "argv",
    [
        ["run", "--topic", "x", "--config", "c.yaml"],
        ["run", "--output", "out", "--config", "c.yaml"],
        ["run", "--from-stage", "TOPIC_INIT", "--config", "c.yaml"],
        ["run", "--auto-approve", "--config", "c.yaml"],
    ],
)
def test_run_parser_accepts_required_flags(
    argv: list[str], monkeypatch: pytest.MonkeyPatch
) -> None:
    """Each supported `run` flag parses without error (handler stubbed out)."""
    monkeypatch.setattr(rc_cli, "cmd_run", lambda args: 0)
    assert rc_cli.main(argv) == 0


def test_validate_parser_accepts_config_flag(monkeypatch: pytest.MonkeyPatch) -> None:
    """`validate --config` parses without error (handler stubbed out)."""
    monkeypatch.setattr(rc_cli, "cmd_validate", lambda args: 0)
    assert rc_cli.main(["validate", "--config", "cfg.yaml"]) == 0
# --- resolve_config_path tests ---
def test_resolve_config_finds_arc_yaml_first(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """config.arc.yaml wins over config.yaml when both exist in cwd."""
    monkeypatch.chdir(tmp_path)
    (tmp_path / "config.arc.yaml").write_text("x: 1\n")
    (tmp_path / "config.yaml").write_text("x: 2\n")
    found = resolve_config_path(None)
    assert found is not None
    assert found.name == "config.arc.yaml"


def test_resolve_config_falls_back_to_config_yaml(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Without config.arc.yaml, the search falls back to config.yaml."""
    monkeypatch.chdir(tmp_path)
    (tmp_path / "config.yaml").write_text("x: 1\n")
    found = resolve_config_path(None)
    assert found is not None
    assert found.name == "config.yaml"


def test_resolve_config_returns_none_when_missing(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """An empty cwd yields no config path at all."""
    monkeypatch.chdir(tmp_path)
    assert resolve_config_path(None) is None


def test_resolve_config_explicit_path_no_search() -> None:
    """An explicit path is returned verbatim, bypassing the search order."""
    found = resolve_config_path("/some/explicit/path.yaml")
    assert found is not None
    assert str(found) == "/some/explicit/path.yaml"
# --- cmd_init tests ---
def _write_example_config(path: Path) -> None:
path.write_text(
"""\
project:
name: "my-research"
llm:
provider: "openai"
base_url: "https://api.openai.com/v1"
api_key_env: "OPENAI_API_KEY"
primary_model: "gpt-4o"
fallback_models:
- "gpt-4.1"
- "gpt-4o-mini"
""",
encoding="utf-8",
)
def test_cmd_init_creates_config(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
    """init copies the example config, defaulting to openai when stdin is not a TTY."""
    monkeypatch.chdir(tmp_path)
    _write_example_config(tmp_path / "config.researchclaw.example.yaml")
    # Simulate non-TTY (stdin not a tty) → defaults to openai
    monkeypatch.setattr("sys.stdin", type("FakeStdin", (), {"isatty": lambda self: False})())
    assert rc_cli.cmd_init(argparse.Namespace(force=False)) == 0
    created = tmp_path / "config.arc.yaml"
    assert created.exists()
    assert 'provider: "openai"' in created.read_text()
    assert "Created config.arc.yaml" in capsys.readouterr().out


def test_cmd_init_refuses_overwrite(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
    """init without --force leaves an existing config untouched and exits 1."""
    monkeypatch.chdir(tmp_path)
    _write_example_config(tmp_path / "config.researchclaw.example.yaml")
    (tmp_path / "config.arc.yaml").write_text("existing\n")
    assert rc_cli.cmd_init(argparse.Namespace(force=False)) == 1
    assert "already exists" in capsys.readouterr().err
    assert (tmp_path / "config.arc.yaml").read_text() == "existing\n"


def test_cmd_init_force_overwrites(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """--force replaces an existing config.arc.yaml."""
    monkeypatch.chdir(tmp_path)
    _write_example_config(tmp_path / "config.researchclaw.example.yaml")
    (tmp_path / "config.arc.yaml").write_text("old\n")
    monkeypatch.setattr("sys.stdin", type("FakeStdin", (), {"isatty": lambda self: False})())
    assert rc_cli.cmd_init(argparse.Namespace(force=True)) == 0
    assert (tmp_path / "config.arc.yaml").read_text() != "old\n"
def test_cmd_run_missing_config_shows_init_hint(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
    """With no config anywhere, cmd_run fails and points the user at `researchclaw init`."""
    monkeypatch.chdir(tmp_path)
    no_config_args = argparse.Namespace(
        config=None,
        topic=None,
        output=None,
        from_stage=None,
        auto_approve=False,
        skip_preflight=True,
        resume=False,
        skip_noncritical_stage=False,
    )
    assert rc_cli.cmd_run(no_config_args) == 1
    assert "researchclaw init" in capsys.readouterr().err
def test_resume_finds_existing_checkpoint_dir(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
    """BUG-119: --resume without --output should find the latest checkpoint dir."""
    import hashlib
    import json
    monkeypatch.chdir(tmp_path)
    # Write a valid config
    config_path = tmp_path / "config.arc.yaml"
    _write_valid_config(config_path)
    # Create a fake previous run directory with a checkpoint
    topic = "Synthetic benchmark research"  # matches _write_valid_config
    # Run-dir names end with the first 6 hex chars of the topic's sha256 digest;
    # this is how resume matches a prior run for the same topic (see BUG-119).
    topic_hash = hashlib.sha256(topic.encode()).hexdigest()[:6]
    old_run_dir = tmp_path / "artifacts" / f"rc-20260319-100000-{topic_hash}"
    old_run_dir.mkdir(parents=True)
    (old_run_dir / "checkpoint.json").write_text(
        json.dumps({"last_completed_stage": 5, "last_completed_name": "HYPOTHESIS_GEN",
        "run_id": old_run_dir.name, "timestamp": "2026-03-19T10:00:00Z"})
    )
    # Mock execute_pipeline so we don't actually run
    import researchclaw.pipeline.runner as runner_mod
    monkeypatch.setattr(runner_mod, "execute_pipeline", lambda **kw: [])
    # Also mock preflight
    from unittest.mock import MagicMock
    mock_client = MagicMock()
    mock_client.preflight.return_value = (True, "OK")
    import researchclaw.llm as llm_mod
    monkeypatch.setattr(llm_mod, "create_llm_client", lambda cfg: mock_client)
    args = argparse.Namespace(
        config=str(config_path),
        topic=None,
        output=None,  # deliberately unset: resume must locate the run dir itself
        from_stage=None,
        auto_approve=False,
        skip_preflight=True,
        resume=True,
        skip_noncritical_stage=False,
        no_graceful_degradation=False,
    )
    rc_cli.cmd_run(args)
    captured = capsys.readouterr()
    # cmd_run should announce the resumed run and name the old run directory.
    assert "Found existing run to resume" in captured.out
    assert old_run_dir.name in captured.out
def test_resume_no_checkpoint_warns(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
    """BUG-119: --resume with no matching checkpoint should warn and start new."""
    monkeypatch.chdir(tmp_path)
    config_path = tmp_path / "config.arc.yaml"
    _write_valid_config(config_path)
    # An artifacts dir with no run directories inside → nothing to resume.
    (tmp_path / "artifacts").mkdir()

    import researchclaw.pipeline.runner as runner_mod

    monkeypatch.setattr(runner_mod, "execute_pipeline", lambda **kw: [])

    from unittest.mock import MagicMock

    import researchclaw.llm as llm_mod

    stub_client = MagicMock()
    stub_client.preflight.return_value = (True, "OK")
    monkeypatch.setattr(llm_mod, "create_llm_client", lambda cfg: stub_client)

    rc_cli.cmd_run(
        argparse.Namespace(
            config=str(config_path),
            topic=None,
            output=None,
            from_stage=None,
            auto_approve=False,
            skip_preflight=True,
            resume=True,
            skip_noncritical_stage=False,
            no_graceful_degradation=False,
        )
    )
    assert "no checkpoint found" in capsys.readouterr().err
def test_main_dispatches_init(monkeypatch: pytest.MonkeyPatch) -> None:
    """main() routes `init --force` to cmd_init with force=True."""
    seen: dict[str, argparse.Namespace] = {}

    def record_init(args):
        seen["args"] = args
        return 0

    monkeypatch.setattr(rc_cli, "cmd_init", record_init)
    assert rc_cli.main(["init", "--force"]) == 0
    assert seen["args"].force is True
================================================
FILE: tests/test_rc_config.py
================================================
import json
from pathlib import Path
from typing import cast
import pytest
from researchclaw.config import (
ExperimentConfig,
RCConfig,
SandboxConfig,
SecurityConfig,
ValidationResult,
load_config,
validate_config,
)
def _write_valid_config(tmp_path: Path) -> Path:
kb_root = tmp_path / "docs" / "kb"
for name in (
"questions",
"literature",
"experiments",
"findings",
"decisions",
"reviews",
):
(kb_root / name).mkdir(parents=True, exist_ok=True)
config_path = tmp_path / "config.rc.yaml"
_ = config_path.write_text(
"""
project:
name: demo
mode: docs-first
research:
topic: Test topic
domains: [ml, agents]
runtime:
timezone: America/New_York
notifications:
channel: discord
knowledge_base:
backend: markdown
root: docs/kb
openclaw_bridge:
use_cron: true
use_message: true
use_memory: true
use_sessions_spawn: true
use_web_fetch: true
use_browser: false
llm:
provider: openai-compatible
base_url: https://example.invalid/v1
api_key_env: OPENAI_API_KEY
security:
hitl_required_stages: [5, 9, 20]
experiment:
mode: simulated
""".strip()
+ "\n",
encoding="utf-8",
)
return config_path
def _valid_config_data() -> dict[str, dict[str, object]]:
return {
"project": {"name": "demo", "mode": "docs-first"},
"research": {"topic": "Test topic", "domains": ["ml", "agents"]},
"runtime": {"timezone": "America/New_York"},
"notifications": {"channel": "discord"},
"knowledge_base": {"backend": "markdown", "root": "docs/kb"},
"openclaw_bridge": {
"use_cron": True,
"use_message": True,
"use_memory": True,
"use_sessions_spawn": True,
"use_web_fetch": True,
"use_browser": False,
},
"llm": {
"provider": "openai-compatible",
"base_url": "https://example.invalid/v1",
"api_key_env": "OPENAI_API_KEY",
"primary_model": "gpt-4.1",
"fallback_models": ["gpt-4o-mini", "gpt-4o"],
},
"security": {"hitl_required_stages": [5, 9, 20]},
"experiment": {
"mode": "simulated",
"metric_direction": "minimize",
},
}
def test_valid_config_data_helper_returns_expected_baseline_shape():
    """Sanity-check the shared fixture dict before other tests mutate copies of it."""
    baseline = _valid_config_data()
    assert baseline["project"]["name"] == "demo"
    assert baseline["knowledge_base"]["root"] == "docs/kb"
    assert baseline["security"]["hitl_required_stages"] == [5, 9, 20]


def test_validate_config_with_valid_data_returns_ok_true(tmp_path: Path):
    """A fully valid config validates with no errors."""
    outcome = validate_config(
        _valid_config_data(), project_root=tmp_path, check_paths=False
    )
    assert isinstance(outcome, ValidationResult)
    assert outcome.ok is True
    assert outcome.errors == ()


def test_validate_config_missing_required_fields_returns_errors(tmp_path: Path):
    """Emptying the research section surfaces the missing-topic error."""
    broken = _valid_config_data()
    broken["research"] = {}
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert "Missing required field: research.topic" in outcome.errors


def test_validate_config_rejects_invalid_project_mode(tmp_path: Path):
    broken = _valid_config_data()
    broken["project"]["mode"] = "invalid-mode"
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert "Invalid project.mode: invalid-mode" in outcome.errors


def test_validate_config_rejects_invalid_knowledge_base_backend(tmp_path: Path):
    broken = _valid_config_data()
    broken["knowledge_base"]["backend"] = "sqlite"
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert "Invalid knowledge_base.backend: sqlite" in outcome.errors


@pytest.mark.parametrize("entry", [0, 24, "5", 9.1])
def test_validate_config_rejects_invalid_hitl_required_stages_entries(
    tmp_path: Path, entry: object
):
    """Out-of-range or wrongly typed stage ids are rejected individually."""
    broken = _valid_config_data()
    broken["security"]["hitl_required_stages"] = [5, entry, 20]
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert f"Invalid security.hitl_required_stages entry: {entry}" in outcome.errors


def test_validate_config_rejects_non_list_hitl_required_stages(tmp_path: Path):
    broken = _valid_config_data()
    broken["security"]["hitl_required_stages"] = "5,9,20"
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert "security.hitl_required_stages must be a list" in outcome.errors


def test_validate_config_rejects_invalid_experiment_mode(tmp_path: Path):
    broken = _valid_config_data()
    broken["experiment"]["mode"] = "kubernetes"
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert "Invalid experiment.mode: kubernetes" in outcome.errors


def test_validate_config_accepts_docker_mode(tmp_path: Path):
    tweaked = _valid_config_data()
    tweaked["experiment"]["mode"] = "docker"
    assert validate_config(tweaked, project_root=tmp_path, check_paths=False).ok is True


def test_validate_config_rejects_invalid_metric_direction(tmp_path: Path):
    broken = _valid_config_data()
    broken["experiment"]["metric_direction"] = "upward"
    outcome = validate_config(broken, project_root=tmp_path, check_paths=False)
    assert outcome.ok is False
    assert "Invalid experiment.metric_direction: upward" in outcome.errors
def test_rcconfig_from_dict_happy_path(tmp_path: Path):
    """from_dict builds a typed RCConfig; list fields become tuples."""
    config = RCConfig.from_dict(
        _valid_config_data(), project_root=tmp_path, check_paths=False
    )
    assert isinstance(config, RCConfig)
    assert config.project.name == "demo"
    assert config.research.domains == ("ml", "agents")
    assert config.llm.fallback_models == ("gpt-4o-mini", "gpt-4o")


def test_rcconfig_from_dict_missing_fields_raises_value_error(tmp_path: Path):
    """Dropping a required section raises ValueError naming the missing field."""
    incomplete = _valid_config_data()
    del incomplete["runtime"]
    with pytest.raises(ValueError, match="Missing required field: runtime.timezone"):
        _ = RCConfig.from_dict(incomplete, project_root=tmp_path, check_paths=False)


def test_rcconfig_load_from_yaml_file(tmp_path: Path):
    loaded = RCConfig.load(_write_valid_config(tmp_path), project_root=tmp_path)
    assert isinstance(loaded, RCConfig)
    assert loaded.project.name == "demo"
    assert loaded.knowledge_base.root == "docs/kb"


def test_load_config_wrapper_returns_rcconfig(tmp_path: Path):
    loaded = load_config(_write_valid_config(tmp_path), project_root=tmp_path)
    assert isinstance(loaded, RCConfig)
    assert loaded.security.hitl_required_stages == (5, 9, 20)


def test_security_config_defaults_match_expected_values():
    security = SecurityConfig()
    assert security.hitl_required_stages == (5, 9, 20)
    assert security.allow_publish_without_approval is False
    assert security.redact_sensitive_logs is True


def test_experiment_config_defaults_mode_is_simulated():
    experiment = ExperimentConfig()
    assert experiment.mode == "simulated"
    assert experiment.metric_direction == "minimize"


def test_sandbox_config_defaults_match_expected_values():
    from researchclaw.config import DEFAULT_PYTHON_PATH

    sandbox = SandboxConfig()
    assert sandbox.python_path == DEFAULT_PYTHON_PATH
    assert sandbox.gpu_required is False
    assert sandbox.max_memory_mb == 4096
    assert "numpy" in sandbox.allowed_imports
def test_to_dict_roundtrip_rehydrates_equivalent_rcconfig(tmp_path: Path):
    """to_dict → JSON → from_dict yields an equal config; tuples survive to_dict."""
    original = RCConfig.from_dict(
        _valid_config_data(), project_root=tmp_path, check_paths=False
    )
    via_json = cast(dict[str, object], json.loads(json.dumps(original.to_dict())))
    rehydrated = RCConfig.from_dict(via_json, project_root=tmp_path, check_paths=False)
    assert rehydrated == original
    assert isinstance(original.to_dict()["security"]["hitl_required_stages"], tuple)


def test_check_paths_false_skips_missing_kb_root_validation(tmp_path: Path):
    """check_paths=False suppresses filesystem existence errors."""
    data = _valid_config_data()
    data["knowledge_base"]["root"] = "docs/missing-kb"
    outcome = validate_config(data, project_root=tmp_path, check_paths=False)
    assert outcome.ok is True
    assert not any(error.startswith("Missing path:") for error in outcome.errors)


def test_path_validation_missing_kb_root_is_error(tmp_path: Path):
    """check_paths=True flags a non-existent kb root as an error."""
    outcome = validate_config(
        _valid_config_data(), project_root=tmp_path, check_paths=True
    )
    assert outcome.ok is False
    assert any(error.startswith("Missing path:") for error in outcome.errors)


def test_validate_config_missing_kb_subdirs_emits_warnings(tmp_path: Path):
    """A bare kb root validates OK but warns once per recommended subdir."""
    data = _valid_config_data()
    (tmp_path / "docs" / "kb").mkdir(parents=True)
    outcome = validate_config(data, project_root=tmp_path, check_paths=True)
    assert outcome.ok is True
    assert len(outcome.warnings) == 6
    assert all(
        warning.startswith("Missing recommended kb subdir:")
        for warning in outcome.warnings
    )


def test_rcconfig_from_dict_uses_default_security_when_missing(tmp_path: Path):
    """An absent security section falls back to SecurityConfig defaults."""
    data = _valid_config_data()
    del data["security"]
    config = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
    assert config.security.hitl_required_stages == (5, 9, 20)


def test_load_uses_file_parent_as_default_project_root(tmp_path: Path):
    """Omitting project_root defaults it to the config file's parent directory."""
    config = RCConfig.load(_write_valid_config(tmp_path))
    assert config.project.name == "demo"
    assert config.knowledge_base.root == "docs/kb"
================================================
FILE: tests/test_rc_contracts.py
================================================
import re
import pytest
from researchclaw.pipeline.contracts import CONTRACTS, StageContract
from researchclaw.pipeline.stages import GATE_STAGES, STAGE_SEQUENCE, Stage
def test_contracts_dict_has_exactly_23_entries():
    """One contract per pipeline stage."""
    assert len(CONTRACTS) == 23


def test_every_stage_has_matching_contract_entry():
    assert set(CONTRACTS.keys()) == set(Stage)


@pytest.mark.parametrize("stage", STAGE_SEQUENCE)
def test_each_stage_member_resolves_to_stage_contract(stage: Stage):
    assert isinstance(CONTRACTS[stage], StageContract)


@pytest.mark.parametrize("stage,contract", tuple(CONTRACTS.items()))
def test_contract_stage_field_matches_dict_key(stage: Stage, contract: StageContract):
    assert contract.stage is stage


@pytest.mark.parametrize("contract", tuple(CONTRACTS.values()))
def test_output_files_is_non_empty_for_all_contracts(contract: StageContract):
    assert contract.output_files


@pytest.mark.parametrize("stage,contract", tuple(CONTRACTS.items()))
def test_error_code_starts_with_e_and_contains_stage_number(
    stage: Stage, contract: StageContract
):
    """Error codes follow E<NN>_<NAME>, with NN the zero-padded stage number."""
    code = contract.error_code
    assert code.startswith("E")
    assert f"{int(stage):02d}" in code
    assert re.match(r"^E\d{2}_[A-Z0-9_]+$", code)


@pytest.mark.parametrize("contract", tuple(CONTRACTS.values()))
def test_max_retries_is_non_negative_for_all_contracts(contract: StageContract):
    assert contract.max_retries >= 0
def test_gate_stages_have_expected_max_retries():
    """Human-gate stages are never retried automatically."""
    for gate in (Stage.LITERATURE_SCREEN, Stage.EXPERIMENT_DESIGN, Stage.QUALITY_GATE):
        assert CONTRACTS[gate].max_retries == 0


@pytest.mark.parametrize("stage", tuple(GATE_STAGES))
def test_gate_stage_contracts_are_never_retried(stage: Stage):
    assert CONTRACTS[stage].max_retries == 0


def test_topic_init_contract_has_expected_input_output_files():
    topic_init = CONTRACTS[Stage.TOPIC_INIT]
    assert topic_init.input_files == ()
    assert topic_init.output_files == ("goal.md", "hardware_profile.json")


def test_export_publish_contract_has_expected_outputs():
    assert CONTRACTS[Stage.EXPORT_PUBLISH].output_files == ("paper_final.md", "code/")


@pytest.mark.parametrize("contract", tuple(CONTRACTS.values()))
def test_dod_is_non_empty_string_for_all_contracts(contract: StageContract):
    assert isinstance(contract.dod, str)
    assert contract.dod.strip()


@pytest.mark.parametrize("contract", tuple(CONTRACTS.values()))
def test_input_files_is_tuple_of_strings(contract: StageContract):
    assert isinstance(contract.input_files, tuple)
    assert all(isinstance(path, str) and path for path in contract.input_files)


@pytest.mark.parametrize("contract", tuple(CONTRACTS.values()))
def test_output_files_is_tuple_of_strings(contract: StageContract):
    assert isinstance(contract.output_files, tuple)
    assert all(isinstance(path, str) and path for path in contract.output_files)


def test_error_codes_are_unique_across_contracts():
    codes = [contract.error_code for contract in CONTRACTS.values()]
    assert len(codes) == len(set(codes))


def test_contracts_follow_stage_sequence_order():
    assert tuple(CONTRACTS.keys()) == STAGE_SEQUENCE


@pytest.mark.parametrize("stage", STAGE_SEQUENCE)
def test_contract_stage_int_matches_stage_enum_value(stage: Stage):
    assert int(CONTRACTS[stage].stage) == int(stage)
================================================
FILE: tests/test_rc_docker_sandbox.py
================================================
"""Tests for DockerSandbox — all mocked, no real Docker needed."""
from __future__ import annotations
import subprocess
import threading
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.config import DockerSandboxConfig, ExperimentConfig
from researchclaw.experiment.docker_sandbox import DockerSandbox, _next_container_name
from researchclaw.experiment.factory import create_sandbox
from researchclaw.experiment.sandbox import SandboxResult
# ── SandboxResult contract ─────────────────────────────────────────────
def test_sandbox_result_fields():
    """SandboxResult exposes returncode, parsed metrics, and the timeout flag."""
    outcome = SandboxResult(
        returncode=0,
        stdout="primary_metric: 0.95\n",
        stderr="",
        elapsed_sec=1.2,
        metrics={"primary_metric": 0.95},
        timed_out=False,
    )
    assert outcome.returncode == 0
    assert outcome.metrics["primary_metric"] == 0.95
    assert outcome.timed_out is False
# ── DockerSandbox command building ─────────────────────────────────────
def test_build_run_command_network_none(tmp_path: Path):
    """network_policy='none' → --network none, --user UID:GID."""
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="none"), tmp_path / "work")
    cmd = sandbox._build_run_command(
        tmp_path / "staging",
        entry_point="main.py",
        container_name="rc-test-1",
    )
    for token in ("docker", "--gpus", "--network", "none",
                  "--memory=8192m", "--shm-size=2048m"):
        assert token in cmd
    assert cmd[-1] == "main.py"
    # Container must run as a non-root user.
    assert "--user" in cmd


def test_build_run_command_setup_only(tmp_path: Path):
    """Default network_policy='setup_only' → RC_SETUP_ONLY_NETWORK=1, --cap-add."""
    sandbox = DockerSandbox(DockerSandboxConfig(), tmp_path / "work")  # default policy
    cmd = sandbox._build_run_command(
        tmp_path / "staging",
        entry_point="main.py",
        container_name="rc-test-setup",
    )
    # Every value following a -e flag is an environment assignment.
    assert "-e" in cmd
    env_values = [cmd[i + 1] for i, token in enumerate(cmd) if token == "-e"]
    assert "RC_SETUP_ONLY_NETWORK=1" in env_values
    # NET_ADMIN is needed so the container can cut its own network after setup.
    assert "--cap-add=NET_ADMIN" in cmd
    # No --network flag at all: setup still needs network access.
    assert cmd.count("--network") == 0
    # Runs as the host user so the experiment can write results.json.
    assert "--user" in cmd
def test_build_run_command_full_network(tmp_path: Path):
    """network_policy='full' → no --network none, still has --user."""
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="full"), tmp_path / "work")
    cmd = sandbox._build_run_command(
        tmp_path / "staging",
        entry_point="main.py",
        container_name="rc-test-full",
    )
    assert cmd.count("--network") == 0
    assert "--user" in cmd


def test_build_run_command_no_gpu(tmp_path: Path):
    """gpu_enabled=False drops the --gpus flag entirely."""
    sandbox = DockerSandbox(
        DockerSandboxConfig(gpu_enabled=False, network_policy="none"), tmp_path / "work"
    )
    cmd = sandbox._build_run_command(
        tmp_path / "staging",
        entry_point="main.py",
        container_name="rc-test-2",
    )
    assert "--gpus" not in cmd


def test_build_run_command_specific_gpus(tmp_path: Path):
    """Explicit gpu_device_ids appear in the --gpus argument value."""
    sandbox = DockerSandbox(
        DockerSandboxConfig(gpu_device_ids=(0, 2), network_policy="none"),
        tmp_path / "work",
    )
    cmd = sandbox._build_run_command(
        tmp_path / "staging",
        entry_point="main.py",
        container_name="rc-test-3",
    )
    assert "--gpus" in cmd
    assert "0,2" in cmd[cmd.index("--gpus") + 1]
# ── Harness injection ─────────────────────────────────────────────────
def test_harness_injection(tmp_path: Path):
    """_inject_harness copies the harness template into the project dir."""
    template = (
        Path(__file__).parent.parent
        / "researchclaw" / "experiment" / "harness_template.py"
    )
    if not template.exists():
        pytest.skip("harness_template.py not found")
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    DockerSandbox._inject_harness(project_dir)
    assert (project_dir / "experiment_harness.py").exists()
# ── Factory ────────────────────────────────────────────────────────────
def test_factory_returns_experiment_sandbox(tmp_path: Path):
    """mode='sandbox' yields the plain subprocess ExperimentSandbox."""
    from researchclaw.experiment.sandbox import ExperimentSandbox

    built = create_sandbox(ExperimentConfig(mode="sandbox"), tmp_path / "work")
    assert isinstance(built, ExperimentSandbox)


@patch("researchclaw.experiment.docker_sandbox.DockerSandbox.ensure_image", return_value=True)
@patch("researchclaw.experiment.docker_sandbox.DockerSandbox.check_docker_available", return_value=True)
def test_factory_returns_docker_sandbox(mock_avail, mock_image, tmp_path: Path):
    """mode='docker' with Docker and image available yields DockerSandbox."""
    built = create_sandbox(ExperimentConfig(mode="docker"), tmp_path / "work")
    assert isinstance(built, DockerSandbox)


@patch("researchclaw.experiment.docker_sandbox.DockerSandbox.check_docker_available", return_value=False)
def test_factory_falls_back_when_docker_unavailable(mock_avail, tmp_path: Path):
    """BUG-002: missing Docker falls back to the subprocess sandbox, not an error."""
    from researchclaw.experiment.sandbox import ExperimentSandbox

    built = create_sandbox(ExperimentConfig(mode="docker"), tmp_path / "work")
    assert isinstance(built, ExperimentSandbox)


@patch("researchclaw.experiment.docker_sandbox.DockerSandbox.ensure_image", return_value=False)
@patch("researchclaw.experiment.docker_sandbox.DockerSandbox.check_docker_available", return_value=True)
def test_factory_raises_when_image_missing(mock_avail, mock_image, tmp_path: Path):
    """Docker present but image absent is a hard error."""
    with pytest.raises(RuntimeError, match="not found locally"):
        create_sandbox(ExperimentConfig(mode="docker"), tmp_path / "work")
# ── run() with mocked subprocess ──────────────────────────────────────
@patch("subprocess.run")
def test_docker_run_success(mock_run, tmp_path: Path):
    """A zero-exit container run parses metrics out of stdout."""
    mock_run.return_value = subprocess.CompletedProcess(
        args=["docker", "run"],
        returncode=0,
        stdout="primary_metric: 0.85\n",
        stderr="",
    )
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="none"), tmp_path / "work")
    outcome = sandbox.run("print('hello')", timeout_sec=60)
    assert outcome.returncode == 0
    assert outcome.metrics.get("primary_metric") == 0.85
    assert outcome.timed_out is False


@patch("subprocess.run")
def test_docker_run_timeout(mock_run, tmp_path: Path):
    """A TimeoutExpired from subprocess surfaces as timed_out with returncode -1."""
    mock_run.side_effect = subprocess.TimeoutExpired(cmd="docker run", timeout=10)
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="none"), tmp_path / "work")
    outcome = sandbox.run("import time; time.sleep(999)", timeout_sec=10)
    assert outcome.timed_out is True
    assert outcome.returncode == -1
# ── Dep detection ─────────────────────────────────────────────────────
def test_detect_pip_packages(tmp_path: Path):
    """Built-in packages (numpy, torchdiffeq, PIL) are excluded from detection."""
    source = "import torchdiffeq\nimport numpy\nfrom PIL import Image\n"
    (tmp_path / "main.py").write_text(source)
    detected = DockerSandbox._detect_pip_packages(tmp_path)
    assert "numpy" not in detected
    assert "torchdiffeq" not in detected


def test_detect_pip_packages_finds_unknown(tmp_path: Path):
    """Imports outside the builtin set must still be reported."""
    (tmp_path / "main.py").write_text("import some_new_package\nimport numpy\n")
    detected = DockerSandbox._detect_pip_packages(tmp_path)
    assert "some_new_package" in detected
    assert "numpy" not in detected


def test_detect_pip_packages_skips_setup_py(tmp_path: Path):
    """setup.py is packaging metadata, not experiment code — never scanned."""
    (tmp_path / "setup.py").write_text("import some_setup_dep\n")
    (tmp_path / "main.py").write_text("import numpy\n")
    detected = DockerSandbox._detect_pip_packages(tmp_path)
    assert "some_setup_dep" not in detected


def test_detect_pip_packages_maps_imports(tmp_path: Path):
    """Known import-to-pip mappings are applied (cv2 -> opencv-python)."""
    (tmp_path / "main.py").write_text("import cv2\nimport wandb\n")
    detected = DockerSandbox._detect_pip_packages(tmp_path)
    assert "opencv-python" in detected
    assert "wandb" in detected
def test_next_container_name_is_thread_safe():
    """Concurrent callers must never receive duplicate container names."""
    collected: list[str] = []
    guard = threading.Lock()

    def grab_names() -> None:
        # 20 names per worker; appends are serialized through the lock.
        for _ in range(20):
            candidate = _next_container_name()
            with guard:
                collected.append(candidate)

    workers = [threading.Thread(target=grab_names) for _ in range(5)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    assert len(collected) == 100
    assert len(collected) == len(set(collected))
# ── requirements.txt generation ──────────────────────────────────────
def test_write_requirements_txt_from_auto_detect(tmp_path: Path):
    """Auto-detected packages should be written to requirements.txt."""
    staging = tmp_path / "staging"
    staging.mkdir()
    (staging / "main.py").write_text("import wandb\nimport optuna\n")
    sandbox = DockerSandbox(DockerSandboxConfig(auto_install_deps=True), tmp_path / "work")
    sandbox._write_requirements_txt(staging)
    req_path = staging / "requirements.txt"
    assert req_path.exists()
    body = req_path.read_text()
    assert "wandb" in body
    assert "optuna" in body


def test_write_requirements_txt_with_pip_pre_install(tmp_path: Path):
    """pip_pre_install packages should be added to requirements.txt."""
    staging = tmp_path / "staging"
    staging.mkdir()
    (staging / "main.py").write_text("import numpy\n")
    sandbox = DockerSandbox(
        DockerSandboxConfig(pip_pre_install=("einops==0.8.0", "kornia")),
        tmp_path / "work",
    )
    sandbox._write_requirements_txt(staging)
    req_path = staging / "requirements.txt"
    assert req_path.exists()
    body = req_path.read_text()
    assert "einops==0.8.0" in body
    assert "kornia" in body


def test_write_requirements_txt_respects_existing(tmp_path: Path):
    """If LLM already generated requirements.txt, append only new packages."""
    staging = tmp_path / "staging"
    staging.mkdir()
    (staging / "main.py").write_text("import numpy\n")
    (staging / "requirements.txt").write_text("wandb\n")
    sandbox = DockerSandbox(
        DockerSandboxConfig(pip_pre_install=("wandb", "einops")), tmp_path / "work"
    )
    sandbox._write_requirements_txt(staging)
    body = (staging / "requirements.txt").read_text()
    # wandb already in existing file, should not be duplicated
    assert body.count("wandb") == 1
    # einops should be appended
    assert "einops" in body


def test_write_requirements_txt_no_packages(tmp_path: Path):
    """No requirements.txt if no packages needed."""
    staging = tmp_path / "staging"
    staging.mkdir()
    (staging / "main.py").write_text("import numpy\n")
    sandbox = DockerSandbox(DockerSandboxConfig(), tmp_path / "work")
    sandbox._write_requirements_txt(staging)
    assert not (staging / "requirements.txt").exists()
# ── Static checks (mocked) ────────────────────────────────────────────
@patch("subprocess.run")
def test_check_docker_available_true(mock_run):
    """A zero exit status from the docker CLI probe reports availability."""
    mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0)
    assert DockerSandbox.check_docker_available() is True


@patch("subprocess.run")
def test_check_docker_available_false(mock_run):
    """A non-zero exit status means Docker is not usable."""
    mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=1)
    assert DockerSandbox.check_docker_available() is False


@patch("subprocess.run", side_effect=FileNotFoundError)
def test_check_docker_available_no_binary(mock_run):
    """A missing docker binary (FileNotFoundError) is treated as unavailable."""
    assert DockerSandbox.check_docker_available() is False


@patch("subprocess.run")
def test_ensure_image_true(mock_run):
    """Exit 0 from the image probe means the image exists locally."""
    mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0)
    assert DockerSandbox.ensure_image("researchclaw/experiment:latest") is True


@patch("subprocess.run")
def test_ensure_image_false(mock_run):
    """A non-zero probe result means the image is missing."""
    mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=1)
    assert DockerSandbox.ensure_image("nonexistent:latest") is False
# ── Default config values ────────────────────────────────────────────
def test_default_network_policy_is_setup_only():
    """Default network_policy should be 'setup_only', not 'none'."""
    assert DockerSandboxConfig().network_policy == "setup_only"


def test_default_auto_install_deps_enabled():
    """Dependency auto-install is on by default."""
    assert DockerSandboxConfig().auto_install_deps is True
# ── Entry point path traversal validation ─────────────────────────────
@patch("researchclaw.experiment.docker_sandbox.subprocess.run")
def test_run_project_rejects_path_traversal(mock_run: MagicMock, tmp_path: Path):
    """run_project() must reject entry_point with '..' components."""
    project = tmp_path / "proj"
    project.mkdir()
    (project / "main.py").write_text("print('hi')")
    work = tmp_path / "work"
    sandbox = DockerSandbox(DockerSandboxConfig(), work)
    # Create escape target so .exists() alone wouldn't catch it
    work.mkdir(parents=True, exist_ok=True)
    (work / "escape.py").write_text("print('escaped!')")
    outcome = sandbox.run_project(project, entry_point="../escape.py")
    assert outcome.returncode == -1
    assert ".." in outcome.stderr
    mock_run.assert_not_called()


@patch("researchclaw.experiment.docker_sandbox.subprocess.run")
def test_run_project_rejects_absolute_path(mock_run: MagicMock, tmp_path: Path):
    """run_project() must reject absolute entry_point paths."""
    project = tmp_path / "proj"
    project.mkdir()
    (project / "main.py").write_text("print('hi')")
    sandbox = DockerSandbox(DockerSandboxConfig(), tmp_path / "work")
    outcome = sandbox.run_project(project, entry_point="/etc/passwd")
    assert outcome.returncode == -1
    lowered = outcome.stderr.lower()
    assert "relative" in lowered or "absolute" in lowered
    mock_run.assert_not_called()
# ── Container cleanup behavior ────────────────────────────────────────
@patch.object(DockerSandbox, "_remove_container")
@patch("subprocess.run")
def test_cleanup_on_normal_exit(mock_run: MagicMock, mock_remove: MagicMock, tmp_path: Path):
    """_remove_container is called on normal successful exit."""
    mock_run.return_value = subprocess.CompletedProcess(
        args=["docker", "run"], returncode=0, stdout="metric: 1.0\n", stderr="",
    )
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="none"), tmp_path / "work")
    outcome = sandbox.run("print('ok')", timeout_sec=60)
    assert outcome.returncode == 0
    mock_remove.assert_called_once()


@patch.object(DockerSandbox, "_remove_container")
@patch.object(DockerSandbox, "_kill_container")
@patch("subprocess.run")
def test_cleanup_on_timeout(
    mock_run: MagicMock, mock_kill: MagicMock, mock_remove: MagicMock, tmp_path: Path,
):
    """Both _kill_container and _remove_container are called on timeout."""
    mock_run.side_effect = subprocess.TimeoutExpired(cmd="docker run", timeout=10)
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="none"), tmp_path / "work")
    outcome = sandbox.run("import time; time.sleep(999)", timeout_sec=10)
    assert outcome.timed_out is True
    mock_kill.assert_called_once()
    mock_remove.assert_called_once()


@patch.object(DockerSandbox, "_remove_container")
@patch("subprocess.run")
def test_cleanup_on_exception(mock_run: MagicMock, mock_remove: MagicMock, tmp_path: Path):
    """_remove_container is called even when subprocess.run raises an unexpected exception."""
    mock_run.side_effect = OSError("Docker daemon not responding")
    sandbox = DockerSandbox(DockerSandboxConfig(network_policy="none"), tmp_path / "work")
    outcome = sandbox.run("print('hi')", timeout_sec=60)
    assert outcome.returncode == -1
    assert "Docker execution error" in outcome.stderr
    mock_remove.assert_called_once()


@patch.object(DockerSandbox, "_remove_container")
@patch.object(DockerSandbox, "_kill_container")
@patch("subprocess.run")
def test_keep_containers_skips_removal(
    mock_run: MagicMock, mock_kill: MagicMock, mock_remove: MagicMock, tmp_path: Path,
):
    """When keep_containers=True, _remove_container is never called."""
    mock_run.return_value = subprocess.CompletedProcess(
        args=["docker", "run"], returncode=0, stdout="", stderr="",
    )
    sandbox = DockerSandbox(
        DockerSandboxConfig(network_policy="none", keep_containers=True),
        tmp_path / "work",
    )
    sandbox.run("print('ok')", timeout_sec=60)
    mock_remove.assert_not_called()
================================================
FILE: tests/test_rc_e2e_regression.py
================================================
# pyright: reportMissingImports=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportPrivateUsage=false, reportUnknownLambdaType=false
from __future__ import annotations
import json
import urllib.error
from email.message import Message
from pathlib import Path
from unittest.mock import patch
import pytest
class _DummyResponse:
def __init__(self, payload: bytes) -> None:
self._payload: bytes = payload
def read(self) -> bytes:
return self._payload
def __enter__(self) -> _DummyResponse:
return self
def __exit__(self, exc_type, exc, tb) -> None:
_ = exc_type, exc, tb
return None
class TestRateLimitRetry:
    """HTTP 429 retry behaviour of the Semantic Scholar client."""

    def test_s2_429_retries_and_succeeds(self) -> None:
        from researchclaw.literature.semantic_scholar import (
            _reset_circuit_breaker,
            search_semantic_scholar,
        )

        _reset_circuit_breaker()  # ensure clean CB state from prior tests
        attempts: list[int] = []

        def fake_urlopen(req, **kwargs):
            _ = kwargs
            attempts.append(1)
            if len(attempts) == 1:
                # First attempt is rate-limited; the client should retry.
                raise urllib.error.HTTPError(
                    req.full_url if hasattr(req, "full_url") else str(req),
                    429,
                    "Too Many Requests",
                    Message(),
                    None,
                )
            record = {
                "paperId": "abc123",
                "title": "Test Paper",
                "authors": [{"name": "Smith"}],
                "year": 2024,
                "abstract": "test abstract",
                "venue": "NeurIPS",
                "citationCount": 10,
                "externalIds": {"DOI": "10.1234/test"},
                "url": "https://example.com",
            }
            return _DummyResponse(json.dumps({"data": [record]}).encode("utf-8"))

        with patch("urllib.request.urlopen", side_effect=fake_urlopen), patch("time.sleep"):
            papers = search_semantic_scholar("test query", limit=5)
        assert len(attempts) >= 2
        assert len(papers) == 1

    def test_s2_persistent_429_exhausts_retries_and_returns_empty(self) -> None:
        from researchclaw.literature.semantic_scholar import (
            _MAX_RETRIES,
            _reset_circuit_breaker,
            search_semantic_scholar,
        )

        _reset_circuit_breaker()  # ensure clean CB state from prior tests
        attempts: list[int] = []

        def fake_urlopen(req, **kwargs):
            _ = kwargs
            attempts.append(1)
            # Every attempt is rate-limited; the client must give up eventually.
            raise urllib.error.HTTPError(
                req.full_url if hasattr(req, "full_url") else str(req),
                429,
                "Too Many Requests",
                Message(),
                None,
            )

        with patch("urllib.request.urlopen", side_effect=fake_urlopen), patch("time.sleep"):
            papers = search_semantic_scholar("test query", limit=5)
        assert papers == []
        assert len(attempts) == _MAX_RETRIES
class TestDegradationChain:
    """Search degradation: live APIs → cached results → empty list."""

    def test_search_degrades_to_cache_on_api_failure(self, tmp_path: Path) -> None:
        from researchclaw.literature.cache import put_cache
        from researchclaw.literature.search import search_papers

        cached = [
            {
                "paper_id": "cached-1",
                "title": "Cached Paper",
                "authors": [],
                "year": 2024,
                "abstract": "cached",
                "venue": "",
                "citation_count": 5,
                "doi": "",
                "arxiv_id": "",
                "url": "",
                "source": "semantic_scholar",
            }
        ]
        put_cache(
            "test degradation", "semantic_scholar", 20, cached, cache_base=tmp_path
        )
        # Both live backends fail; the cache (redirected to tmp_path) must serve.
        with patch(
            "researchclaw.literature.search.search_semantic_scholar",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.search.search_arxiv",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.cache._DEFAULT_CACHE_DIR", tmp_path
        ):
            papers = search_papers("test degradation", limit=20)
        assert len(papers) >= 1
        assert any(p.title == "Cached Paper" for p in papers)

    def test_search_empty_on_total_failure(self, tmp_path: Path) -> None:
        from researchclaw.literature.search import search_papers

        # All backends fail and the cache directory is empty → empty result.
        with patch(
            "researchclaw.literature.search.search_openalex",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.search.search_semantic_scholar",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.search.search_arxiv",
            side_effect=RuntimeError("API down"),
        ), patch(
            "researchclaw.literature.cache._DEFAULT_CACHE_DIR",
            tmp_path / "empty-cache",
        ):
            papers = search_papers("no results query", limit=20)
        assert papers == []
class TestLLMFallback:
    """Fallback-model selection and API-key preflight in LLMClient."""

    def test_primary_403_forbidden_fallback_succeeds(self) -> None:
        from researchclaw.llm.client import LLMClient, LLMConfig, LLMResponse

        client = LLMClient(
            LLMConfig(
                base_url="https://api.example.com/v1",
                api_key="test-key",
                primary_model="gpt-blocked",
                fallback_models=["gpt-fallback"],
                max_retries=1,
            )
        )
        seen_models: list[str] = []

        def fake_raw_call(model, messages, max_tokens, temperature, json_mode):
            _ = messages, max_tokens, temperature, json_mode
            seen_models.append(model)
            if model == "gpt-blocked":
                # Primary model rejected with 403 → client must fall back.
                raise urllib.error.HTTPError(
                    "url", 403, "not allowed to use model", Message(), None
                )
            return LLMResponse(content="ok", model=model)

        with patch.object(client, "_raw_call", side_effect=fake_raw_call):
            resp = client.chat([{"role": "user", "content": "test"}])
        assert resp.content == "ok"
        assert "gpt-blocked" in seen_models
        assert "gpt-fallback" in seen_models

    def test_preflight_detects_401(self) -> None:
        from researchclaw.llm.client import LLMClient, LLMConfig

        client = LLMClient(
            LLMConfig(
                base_url="https://api.example.com/v1",
                api_key="bad-key",
                primary_model="gpt-test",
                fallback_models=[],
                max_retries=1,
            )
        )
        if not hasattr(client, "preflight"):
            pytest.skip("preflight() not yet implemented")
        unauthorized = urllib.error.HTTPError("url", 401, "Unauthorized", Message(), None)
        with patch.object(client, "chat", side_effect=unauthorized):
            ok, msg = client.preflight()
        assert ok is False
        assert "Invalid API key" in msg
class TestNoncriticalStageSkip:
    """Pipeline behaviour when stages fail under skip_noncritical=True."""

    @staticmethod
    def _make_rc_config(tmp_path: Path):
        """Build a minimal RCConfig from an in-memory dict (no path checks)."""
        from researchclaw.config import RCConfig

        raw = {
            "project": {"name": "rc-e2e-regression", "mode": "docs-first"},
            "research": {"topic": "pipeline regression"},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local"},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "RC_TEST_KEY",
                "api_key": "inline",
            },
        }
        return RCConfig.from_dict(raw, project_root=tmp_path, check_paths=False)

    def test_noncritical_stage_failure_is_skipped(self, tmp_path: Path) -> None:
        from researchclaw.adapters import AdapterBundle
        from researchclaw.pipeline import runner as rc_runner
        from researchclaw.pipeline.executor import StageResult
        from researchclaw.pipeline.stages import STAGE_SEQUENCE, Stage, StageStatus

        run_dir = tmp_path / "run"
        run_dir.mkdir()
        config = self._make_rc_config(tmp_path)
        adapters = AdapterBundle()

        def fake_execute_stage(stage: Stage, **kwargs) -> StageResult:
            _ = kwargs
            # Only the (non-critical) archive stage fails; everything else succeeds.
            if stage is Stage.KNOWLEDGE_ARCHIVE:
                return StageResult(
                    stage=stage,
                    status=StageStatus.FAILED,
                    artifacts=(),
                    error="archive error",
                )
            return StageResult(stage=stage, status=StageStatus.DONE, artifacts=("ok.md",))

        with patch.object(rc_runner, "execute_stage", side_effect=fake_execute_stage):
            results = rc_runner.execute_pipeline(
                run_dir=run_dir,
                run_id="run-skip-noncritical",
                config=config,
                adapters=adapters,
                skip_noncritical=True,
            )
        assert len(results) == len(STAGE_SEQUENCE)
        assert results[-1].stage is Stage.CITATION_VERIFY
        assert any(
            r.stage is Stage.KNOWLEDGE_ARCHIVE and r.status is StageStatus.FAILED
            for r in results
        )

    def test_critical_stage_failure_still_aborts(self, tmp_path: Path) -> None:
        from researchclaw.adapters import AdapterBundle
        from researchclaw.pipeline import runner as rc_runner
        from researchclaw.pipeline.executor import StageResult
        from researchclaw.pipeline.stages import Stage, StageStatus

        run_dir = tmp_path / "run-critical"
        run_dir.mkdir()
        config = self._make_rc_config(tmp_path)
        adapters = AdapterBundle()

        def fake_execute_stage(stage: Stage, **kwargs) -> StageResult:
            _ = kwargs
            # The paper draft is a critical stage: its failure must abort the run.
            if stage is Stage.PAPER_DRAFT:
                return StageResult(
                    stage=stage,
                    status=StageStatus.FAILED,
                    artifacts=(),
                    error="draft error",
                )
            return StageResult(stage=stage, status=StageStatus.DONE, artifacts=("ok.md",))

        with patch.object(rc_runner, "execute_stage", side_effect=fake_execute_stage):
            results = rc_runner.execute_pipeline(
                run_dir=run_dir,
                run_id="run-fail-critical",
                config=config,
                adapters=adapters,
                skip_noncritical=True,
            )
        assert results[-1].stage is Stage.PAPER_DRAFT
        assert results[-1].status is StageStatus.FAILED
================================================
FILE: tests/test_rc_evolution.py
================================================
# pyright: reportPrivateUsage=false
"""Tests for the evolution (self-learning) system."""
from __future__ import annotations
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
import pytest
from researchclaw.evolution import (
EvolutionStore,
LessonCategory,
LessonEntry,
extract_lessons,
_classify_error,
_time_weight,
)
# ── LessonEntry tests ──
class TestLessonEntry:
    """Serialization round-trips and defaulting for LessonEntry."""

    def test_to_dict_and_from_dict_roundtrip(self) -> None:
        original = LessonEntry(
            stage_name="hypothesis_gen",
            stage_num=8,
            category="experiment",
            severity="error",
            description="Code validation failed",
            timestamp="2026-03-10T12:00:00+00:00",
            run_id="run-1",
        )
        restored = LessonEntry.from_dict(original.to_dict())
        assert restored.stage_name == "hypothesis_gen"
        assert restored.stage_num == 8
        assert restored.category == "experiment"
        assert restored.severity == "error"

    def test_from_dict_handles_missing_fields(self) -> None:
        blank = LessonEntry.from_dict({})
        assert blank.stage_name == ""
        assert blank.stage_num == 0
        assert blank.category == "pipeline"
# ── Classification tests ──
class TestClassifyError:
    """Stage-name / message routing in _classify_error."""

    def test_timeout_classified_as_system(self) -> None:
        category = _classify_error("experiment_run", "Connection timeout after 30s")
        assert category == "system"

    def test_validation_classified_as_experiment(self) -> None:
        category = _classify_error("code_generation", "Syntax error in code")
        assert category == "experiment"

    def test_citation_classified_as_literature(self) -> None:
        category = _classify_error("citation_verify", "Hallucinated reference")
        assert category == "literature"

    def test_paper_classified_as_writing(self) -> None:
        category = _classify_error("paper_draft", "Draft quality too low")
        assert category == "writing"

    def test_unknown_defaults_to_pipeline(self) -> None:
        category = _classify_error("unknown_stage", "something random")
        assert category == "pipeline"
# ── Time weight tests ──
class TestTimeWeight:
    """Age-based weighting of lessons via _time_weight."""

    def test_recent_lesson_has_high_weight(self) -> None:
        fresh = datetime.now(timezone.utc).isoformat(timespec="seconds")
        assert _time_weight(fresh) > 0.9

    def test_30_day_old_has_half_weight(self) -> None:
        ts = (datetime.now(timezone.utc) - timedelta(days=30)).isoformat(timespec="seconds")
        # Expected to land near ~0.5 at 30 days of age.
        assert 0.4 < _time_weight(ts) < 0.6

    def test_90_day_old_returns_zero(self) -> None:
        ts = (datetime.now(timezone.utc) - timedelta(days=91)).isoformat(timespec="seconds")
        assert _time_weight(ts) == 0.0

    def test_invalid_timestamp_returns_zero(self) -> None:
        assert _time_weight("not-a-date") == 0.0

    def test_empty_timestamp_returns_zero(self) -> None:
        assert _time_weight("") == 0.0
# ── Extract lessons tests ──
class TestExtractLessons:
    """Lesson extraction from stage results and per-run artifact files."""

    def _make_result(self, stage_num, status, error=None, decision="proceed"):
        """Build a minimal stage-result stand-in via SimpleNamespace."""
        from types import SimpleNamespace
        from researchclaw.pipeline.stages import Stage, StageStatus

        return SimpleNamespace(
            stage=Stage(stage_num),
            status=StageStatus(status),
            error=error,
            decision=decision,
        )

    def test_extracts_lesson_from_failed_stage(self) -> None:
        lessons = extract_lessons(
            [self._make_result(4, "failed", error="API rate limited")],
            run_id="test-run",
        )
        assert len(lessons) == 1
        assert lessons[0].severity == "error"
        assert "rate limited" in lessons[0].description

    def test_extracts_lesson_from_blocked_stage(self) -> None:
        lessons = extract_lessons(
            [self._make_result(5, "blocked_approval")], run_id="test-run"
        )
        assert len(lessons) == 1
        assert lessons[0].severity == "warning"
        assert "blocked" in lessons[0].description

    def test_extracts_lesson_from_pivot_decision(self) -> None:
        lessons = extract_lessons(
            [self._make_result(15, "done", decision="pivot")], run_id="test-run"
        )
        assert len(lessons) == 1
        assert "PIVOT" in lessons[0].description

    def test_no_lessons_from_successful_proceed(self) -> None:
        lessons = extract_lessons([self._make_result(1, "done", decision="proceed")])
        assert len(lessons) == 0

    def test_multiple_results_multiple_lessons(self) -> None:
        outcomes = [
            self._make_result(4, "failed", error="timeout"),
            self._make_result(5, "blocked_approval"),
            self._make_result(15, "done", decision="refine"),
        ]
        assert len(extract_lessons(outcomes)) == 3

    def test_extracts_decision_rationale(self, tmp_path: Path) -> None:
        run_dir = tmp_path / "run"
        stage_dir = run_dir / "stage-15"
        stage_dir.mkdir(parents=True)
        (stage_dir / "decision_structured.json").write_text(
            json.dumps({"decision": "pivot", "rationale": "NaN in metrics"}),
            encoding="utf-8",
        )
        lessons = extract_lessons(
            [self._make_result(15, "done", decision="pivot")],
            run_id="test",
            run_dir=run_dir,
        )
        assert any("NaN in metrics" in lesson.description for lesson in lessons)

    def test_extracts_rationale_from_raw_text_excerpt(self, tmp_path: Path) -> None:
        run_dir = tmp_path / "run"
        stage_dir = run_dir / "stage-15"
        stage_dir.mkdir(parents=True)
        (stage_dir / "decision_structured.json").write_text(
            json.dumps({
                "decision": "refine",
                "raw_text_excerpt": (
                    "## Decision\n**REFINE**\n\n"
                    "## Justification\n"
                    "The analysis provides promising evidence but lacks statistical rigor."
                ),
                "generated": "2026-03-11T05:15:43+00:00",
            }),
            encoding="utf-8",
        )
        lessons = extract_lessons(
            [self._make_result(15, "done", decision="refine")],
            run_id="test",
            run_dir=run_dir,
        )
        assert any("statistical rigor" in lesson.description for lesson in lessons)

    def test_extracts_stderr_runtime_lesson(self, tmp_path: Path) -> None:
        run_dir = tmp_path / "run"
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "metrics": {"loss": 0.5},
                "stderr": "RuntimeWarning: invalid value encountered in divide",
            }),
            encoding="utf-8",
        )
        lessons = extract_lessons([self._make_result(12, "done")], run_dir=run_dir)
        assert any("RuntimeWarning" in lesson.description for lesson in lessons)

    def test_extracts_nan_metric_lesson(self, tmp_path: Path) -> None:
        run_dir = tmp_path / "run"
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({"metrics": {"accuracy": "nan"}}),
            encoding="utf-8",
        )
        lessons = extract_lessons([self._make_result(12, "done")], run_dir=run_dir)
        assert any(
            "accuracy" in lesson.description and "nan" in lesson.description.lower()
            for lesson in lessons
        )

    def test_no_runtime_lessons_without_run_dir(self) -> None:
        lessons = extract_lessons([self._make_result(12, "done")])
        assert len(lessons) == 0
# ── EvolutionStore tests ──
class TestEvolutionStore:
    """Persistence, querying, and overlay building for EvolutionStore."""

    def test_append_and_load(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        store.append(
            LessonEntry(
                stage_name="hypothesis_gen",
                stage_num=8,
                category="pipeline",
                severity="warning",
                description="PIVOT triggered",
                timestamp=datetime.now(timezone.utc).isoformat(timespec="seconds"),
            )
        )
        loaded = store.load_all()
        assert len(loaded) == 1
        assert loaded[0].stage_name == "hypothesis_gen"

    def test_append_many(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        batch = [
            LessonEntry("s1", 1, "system", "error", "err1",
                        datetime.now(timezone.utc).isoformat()),
            LessonEntry("s2", 2, "pipeline", "info", "info1",
                        datetime.now(timezone.utc).isoformat()),
        ]
        store.append_many(batch)
        assert store.count() == 2

    def test_append_many_empty_is_noop(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        store.append_many([])
        assert store.count() == 0

    def test_load_all_empty_store(self, tmp_path: Path) -> None:
        assert EvolutionStore(tmp_path / "evo").load_all() == []

    def test_query_for_stage_returns_relevant_lessons(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        now = datetime.now(timezone.utc).isoformat(timespec="seconds")
        store.append(LessonEntry("hypothesis_gen", 8, "pipeline", "error",
                                 "Failed hypothesis", now))
        store.append(LessonEntry("paper_draft", 17, "writing", "warning",
                                 "Draft too short", now))
        matched = store.query_for_stage("hypothesis_gen", max_lessons=5)
        # The same-stage lesson should be boosted to the front.
        assert len(matched) >= 1
        assert matched[0].stage_name == "hypothesis_gen"

    def test_query_respects_max_lessons(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        now = datetime.now(timezone.utc).isoformat(timespec="seconds")
        for i in range(10):
            store.append(LessonEntry("stage_1", 1, "system", "error",
                                     f"Error {i}", now))
        assert len(store.query_for_stage("stage_1", max_lessons=3)) == 3

    def test_build_overlay_returns_empty_for_no_lessons(self, tmp_path: Path) -> None:
        assert EvolutionStore(tmp_path / "evo").build_overlay("hypothesis_gen") == ""

    def test_build_overlay_returns_formatted_text(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        now = datetime.now(timezone.utc).isoformat(timespec="seconds")
        store.append(LessonEntry("hypothesis_gen", 8, "experiment", "error",
                                 "Code syntax error in experiment", now))
        overlay = store.build_overlay("hypothesis_gen")
        assert "Lessons from Prior Runs" in overlay
        assert "Code syntax error" in overlay
        assert "❌" in overlay

    def test_old_lessons_filtered_by_time_weight(self, tmp_path: Path) -> None:
        store = EvolutionStore(tmp_path / "evo")
        old_ts = (datetime.now(timezone.utc) - timedelta(days=100)).isoformat()
        store.append(LessonEntry("stage_1", 1, "system", "error", "Old error", old_ts))
        # Lessons older than 90 days carry zero weight and are dropped.
        assert len(store.query_for_stage("stage_1")) == 0

    def test_creates_directory_if_not_exists(self, tmp_path: Path) -> None:
        store_dir = tmp_path / "nested" / "evo"
        _ = EvolutionStore(store_dir)
        assert store_dir.exists()
# ── PromptManager evolution overlay integration ──
class TestPromptManagerEvolutionOverlay:
    """Evolution overlays are appended to the user prompt only when non-empty."""

    def test_overlay_appended_to_user_prompt(self) -> None:
        from researchclaw.prompts import PromptManager

        pm = PromptManager()
        overlay = "## Lessons\n1. Avoid timeout errors."
        sp = pm.for_stage(
            "topic_init",
            evolution_overlay=overlay,
            topic="test",
            domains="ml",
            project_name="p1",
            quality_threshold="8.0",
        )
        assert "Avoid timeout errors" in sp.user

    def test_no_overlay_when_empty(self) -> None:
        from researchclaw.prompts import PromptManager

        pm = PromptManager()
        stage_kwargs = {
            "topic": "test",
            "domains": "ml",
            "project_name": "p1",
            "quality_threshold": "8.0",
        }
        without_overlay = pm.for_stage("topic_init", **stage_kwargs)
        with_empty_overlay = pm.for_stage(
            "topic_init", evolution_overlay="", **stage_kwargs
        )
        assert without_overlay.user == with_empty_overlay.user
================================================
FILE: tests/test_rc_executor.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnusedCallResult=false, reportAttributeAccessIssue=false, reportUnknownLambdaType=false
from __future__ import annotations
import json
import re
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast
import pytest
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.pipeline import executor as rc_executor
from researchclaw.pipeline.stages import Stage, StageStatus
class FakeLLMClient:
    """Test double that records every chat() call and replies with canned text."""

    def __init__(self, response_text: str = "mock response"):
        self.response_text: str = response_text
        # One entry per chat() invocation: the full message list passed in.
        self.calls: list[list[dict[str, str]]] = []

    def chat(self, messages: list[dict[str, str]], **kwargs: object):
        """Record the messages and return a canned LLMResponse."""
        _ = kwargs
        self.calls.append(messages)
        from researchclaw.llm.client import LLMResponse

        return LLMResponse(content=self.response_text, model="fake-model")
class FakeLLMClientWithConfig(FakeLLMClient):
    """FakeLLMClient variant that also exposes a ``.config`` attribute."""

    def __init__(self, response_text: str = "mock response"):
        super().__init__(response_text=response_text)
        self.config: SimpleNamespace = SimpleNamespace(
            base_url="http://fake", api_key="fake-key"
        )
@pytest.fixture()
def rc_config(tmp_path: Path) -> RCConfig:
    """Full RCConfig built from an in-memory dict — no files, no path checks."""
    payload = {
        "project": {"name": "rc-test", "mode": "docs-first"},
        "research": {
            "topic": "test-driven science",
            "domains": ["ml", "systems"],
            "daily_paper_count": 2,
            "quality_threshold": 8.2,
        },
        "runtime": {"timezone": "UTC"},
        "notifications": {
            "channel": "local",
            "on_stage_start": True,
            "on_stage_fail": False,
            "on_gate_required": True,
        },
        "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
        "openclaw_bridge": {"use_memory": True, "use_message": True},
        "llm": {
            "provider": "openai-compatible",
            "base_url": "http://localhost:1234/v1",
            "api_key_env": "RC_TEST_KEY",
            "api_key": "inline-test-key",
            "primary_model": "fake-model",
            "fallback_models": [],
        },
        "security": {"hitl_required_stages": [5, 9, 20]},
        "experiment": {"mode": "sandbox"},
    }
    return RCConfig.from_dict(payload, project_root=tmp_path, check_paths=False)
@pytest.fixture()
def adapters() -> AdapterBundle:
    """Fresh adapter bundle for each test."""
    return AdapterBundle()


@pytest.fixture()
def run_dir(tmp_path: Path) -> Path:
    """Create and return an empty run directory under tmp_path."""
    target = tmp_path / "run"
    target.mkdir()
    return target
def _write_prior_artifact(
run_dir: Path, stage_num: int, filename: str, content: str
) -> None:
stage_dir = run_dir / f"stage-{stage_num:02d}"
stage_dir.mkdir(parents=True, exist_ok=True)
(stage_dir / filename).write_text(content, encoding="utf-8")
def test_executor_map_has_23_entries() -> None:
    """The pipeline registers exactly 23 stage executors."""
    mapping = getattr(rc_executor, "EXECUTOR_MAP", rc_executor._STAGE_EXECUTORS)
    assert len(mapping) == 23


def test_every_stage_member_has_matching_executor() -> None:
    """Every Stage enum member has a registered executor — no more, no fewer."""
    mapping = getattr(rc_executor, "EXECUTOR_MAP", rc_executor._STAGE_EXECUTORS)
    assert set(mapping.keys()) == set(Stage)


def test_stage_result_dataclass_fields() -> None:
    """StageResult defaults: error=None, decision='proceed', empty evidence refs."""
    result = rc_executor.StageResult(
        stage=Stage.TOPIC_INIT, status=StageStatus.DONE, artifacts=("goal.md",)
    )
    assert result.stage == Stage.TOPIC_INIT
    assert result.status == StageStatus.DONE
    assert result.artifacts == ("goal.md",)
    assert result.error is None
    assert result.decision == "proceed"
    assert result.evidence_refs == ()


def test_utcnow_iso_returns_valid_iso_timestamp() -> None:
    """_utcnow_iso() yields an ISO-8601 timestamp with an explicit UTC offset."""
    ts = rc_executor._utcnow_iso()
    assert ts.endswith("+00:00")
    assert "T" in ts
@pytest.mark.parametrize(
    ("text", "expected"),
    [
        ("before\n```yaml\na: 1\n```\nafter", "a: 1"),
        ("```yml\nkey: value\n```", "key: value"),
        ("```\nplain: true\n```", "plain: true"),
        (" x: y ", "x: y"),
    ],
)
def test_extract_yaml_block_variants(text: str, expected: str) -> None:
    """_extract_yaml_block handles ```yaml/```yml/bare fences and strips whitespace."""
    assert rc_executor._extract_yaml_block(text) == expected
@pytest.mark.parametrize(
    ("payload", "default", "expected"),
    [
        ('{"ok": true}', {"fallback": True}, {"ok": True}),
        ("[1, 2, 3]", {"fallback": True}, [1, 2, 3]),
        ("not-json", {"fallback": True}, {"fallback": True}),
    ],
)
def test_safe_json_loads_valid_and_invalid(payload: str, default, expected) -> None:
    """_safe_json_loads parses valid JSON (objects and arrays) and returns the
    supplied default for unparseable input instead of raising."""
    assert rc_executor._safe_json_loads(payload, default) == expected
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ("a/b", "a_b"),
        ("a\\b", "a_b"),
        ("../secret", "__secret"),
        ("name with spaces!.md", "name_with_spaces_.md"),
        ("", "unnamed"),
    ],
)
def test_safe_filename_sanitization(raw: str, expected: str) -> None:
    """_safe_filename replaces path separators and unsafe characters with '_'
    (blocking traversal like '../') and maps the empty name to 'unnamed'."""
    assert rc_executor._safe_filename(raw) == expected
def test_safe_filename_truncates_to_100_chars() -> None:
    """Names longer than 100 characters are cut down to exactly 100."""
    cleaned = rc_executor._safe_filename("x" * 120)
    assert cleaned == "x" * 100
    assert len(cleaned) == 100
def test_build_context_preamble_basic_fields(
    rc_config: RCConfig, run_dir: Path
) -> None:
    """The preamble always carries the heading, the topic, and the domain list."""
    preamble = rc_executor._build_context_preamble(rc_config, run_dir)
    for fragment in ("## Research Context", "test-driven science", "ml, systems"):
        assert fragment in preamble
def test_build_context_preamble_includes_selected_prior_artifacts(
    rc_config: RCConfig, run_dir: Path
) -> None:
    """Opt-in include_* flags pull prior goal/hypotheses/synthesis artifacts
    into the preamble, each under its own markdown sub-heading."""
    _write_prior_artifact(run_dir, 1, "goal.md", "goal content")
    _write_prior_artifact(run_dir, 8, "hypotheses.md", "hyp content")
    _write_prior_artifact(run_dir, 7, "synthesis.md", "synth content")
    text = rc_executor._build_context_preamble(
        rc_config,
        run_dir,
        include_goal=True,
        include_hypotheses=True,
        include_synthesis=True,
    )
    assert "### Goal" in text
    assert "goal content" in text
    assert "### Hypotheses" in text
    assert "hyp content" in text
    assert "### Synthesis" in text
    assert "synth content" in text
def test_read_prior_artifact_finds_newest_file(run_dir: Path) -> None:
    """When the same artifact exists in several stages, the later-written
    (higher-stage) copy wins — both orderings coincide here."""
    _write_prior_artifact(run_dir, 1, "goal.md", "old")
    _write_prior_artifact(run_dir, 3, "goal.md", "new")
    assert rc_executor._read_prior_artifact(run_dir, "goal.md") == "new"
def test_read_prior_artifact_finds_directory_path(run_dir: Path) -> None:
    """A trailing-slash artifact name resolves to the directory path string."""
    cards_dir = run_dir / "stage-06" / "cards"
    cards_dir.mkdir(parents=True)
    (cards_dir / "card-1.json").write_text("{}", encoding="utf-8")
    resolved = rc_executor._read_prior_artifact(run_dir, "cards/")
    assert resolved == str(cards_dir)
def test_read_prior_artifact_returns_none_when_not_found(run_dir: Path) -> None:
    """An absent artifact yields None rather than raising."""
    missing = rc_executor._read_prior_artifact(run_dir, "missing.md")
    assert missing is None
def test_read_best_analysis_prefers_best_file(run_dir: Path) -> None:
    """BUG-225: _read_best_analysis prefers analysis_best.md at run root."""
    from researchclaw.pipeline._helpers import _read_best_analysis

    # Create degenerate analysis in stage-14 and best at run root; the
    # run-root copy must shadow the per-stage one.
    s14 = run_dir / "stage-14"
    s14.mkdir(parents=True)
    (s14 / "analysis.md").write_text("Degenerate analysis", encoding="utf-8")
    (run_dir / "analysis_best.md").write_text("Best analysis", encoding="utf-8")
    result = _read_best_analysis(run_dir)
    assert result == "Best analysis"
def test_read_best_analysis_falls_back_to_prior_artifact(run_dir: Path) -> None:
    """BUG-225: Falls back to _read_prior_artifact when no analysis_best.md."""
    from researchclaw.pipeline._helpers import _read_best_analysis

    # Only a per-stage analysis exists; it should be returned as-is.
    s14 = run_dir / "stage-14"
    s14.mkdir(parents=True)
    (s14 / "analysis.md").write_text("Only analysis", encoding="utf-8")
    result = _read_best_analysis(run_dir)
    assert result == "Only analysis"
def test_read_best_analysis_returns_empty_when_none(run_dir: Path) -> None:
    """BUG-225: Returns empty string when no analysis exists at all."""
    from researchclaw.pipeline._helpers import _read_best_analysis

    assert _read_best_analysis(run_dir) == ""
def test_write_stage_meta_writes_expected_json(run_dir: Path) -> None:
    """_write_stage_meta serializes a StageResult into decision.json with
    stage/run identifiers, artifact lists, the next stage number, and a
    timestamp."""
    stage_dir = run_dir / "stage-01"
    stage_dir.mkdir()
    result = rc_executor.StageResult(
        stage=Stage.TOPIC_INIT,
        status=StageStatus.DONE,
        artifacts=("goal.md",),
        decision="proceed",
        evidence_refs=("stage-01/goal.md",),
    )
    rc_executor._write_stage_meta(stage_dir, Stage.TOPIC_INIT, "run-abc", result)
    payload = cast(
        dict[str, Any],
        json.loads((stage_dir / "decision.json").read_text(encoding="utf-8")),
    )
    assert payload["stage_id"] == "01-topic_init"
    assert payload["run_id"] == "run-abc"
    assert payload["status"] == "done"
    assert payload["decision"] == "proceed"
    # Tuples are serialized as JSON arrays.
    assert payload["output_artifacts"] == ["goal.md"]
    assert payload["evidence_refs"] == ["stage-01/goal.md"]
    # Stage 1 points at stage 2 as the follow-up.
    assert payload["next_stage"] == 2
    # Timestamp must at least start like an ISO-8601 date-time.
    assert re.match(r"\d{4}-\d{2}-\d{2}T", payload["ts"])
def test_execute_stage_creates_stage_dir_writes_artifacts_and_meta(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """Happy path for stage 1: the stage directory is created and goal.md,
    hardware_profile.json, and decision.json are written."""
    fake_llm = FakeLLMClientWithConfig("# Goal\n\nMocked goal body")
    # Substitute the fake client for the real LLM factory.
    monkeypatch.setattr(
        "researchclaw.pipeline.executor.LLMClient.from_rc_config",
        lambda _config: fake_llm,
    )
    result = rc_executor.execute_stage(
        Stage.TOPIC_INIT,
        run_dir=run_dir,
        run_id="run-1",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert result.status == StageStatus.DONE
    assert "goal.md" in result.artifacts
    assert "hardware_profile.json" in result.artifacts
    assert (run_dir / "stage-01").is_dir()
    assert (
        (run_dir / "stage-01" / "goal.md")
        .read_text(encoding="utf-8")
        .startswith("# Goal")
    )
    assert (run_dir / "stage-01" / "hardware_profile.json").exists()
    # Exactly one LLM round-trip is made for the goal document.
    assert len(fake_llm.calls) == 1
    decision = cast(
        dict[str, Any],
        json.loads(
            (run_dir / "stage-01" / "decision.json").read_text(encoding="utf-8")
        ),
    )
    assert decision["run_id"] == "run-1"
    assert decision["status"] == "done"
    assert decision["output_artifacts"] == ["goal.md", "hardware_profile.json"]
def test_execute_stage_contract_validation_missing_output_file_marks_failed(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """An executor that claims DONE without actually writing its declared
    output file must be downgraded to FAILED by contract validation."""
    def bad_executor(
        _stage_dir: Path,
        _run_dir: Path,
        _config: RCConfig,
        _adapters: AdapterBundle,
        *,
        llm: object = None,
    ):
        # Reports "goal.md" as an artifact but never writes it.
        _ = llm
        return rc_executor.StageResult(
            stage=Stage.TOPIC_INIT, status=StageStatus.DONE, artifacts=("goal.md",)
        )

    monkeypatch.setitem(rc_executor._STAGE_EXECUTORS, Stage.TOPIC_INIT, bad_executor)
    result = rc_executor.execute_stage(
        Stage.TOPIC_INIT,
        run_dir=run_dir,
        run_id="run-2",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert result.status == StageStatus.FAILED
    assert "Missing or empty output: goal.md" in (result.error or "")
def test_execute_stage_contract_validation_missing_output_directory_marks_failed(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """Contract validation also covers directory artifacts (trailing slash):
    declaring "cards/" without creating the directory yields FAILED."""
    # Satisfy stage 6's required input so the executor itself is reached.
    _write_prior_artifact(run_dir, 5, "shortlist.jsonl", '{"title": "x"}')

    def bad_executor(
        _stage_dir: Path,
        _run_dir: Path,
        _config: RCConfig,
        _adapters: AdapterBundle,
        *,
        llm: object = None,
    ):
        # Declares the "cards/" directory but never creates it.
        _ = llm
        return rc_executor.StageResult(
            stage=Stage.KNOWLEDGE_EXTRACT,
            status=StageStatus.DONE,
            artifacts=("cards/",),
        )

    monkeypatch.setitem(
        rc_executor._STAGE_EXECUTORS, Stage.KNOWLEDGE_EXTRACT, bad_executor
    )
    result = rc_executor.execute_stage(
        Stage.KNOWLEDGE_EXTRACT,
        run_dir=run_dir,
        run_id="run-3",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert result.status == StageStatus.FAILED
    assert "Missing output directory: cards/" in (result.error or "")
def test_execute_stage_missing_required_input_returns_failed(
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """Stage 2 requires goal.md from stage 1; without it the stage fails fast."""
    outcome = rc_executor.execute_stage(
        Stage.PROBLEM_DECOMPOSE,
        run_dir=run_dir,
        run_id="run-4",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert outcome.status == StageStatus.FAILED
    assert "Missing input: goal.md" in (outcome.error or "")
def test_execute_stage_gate_behavior_auto_approve_true_keeps_done(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """Stage 5 is a HITL gate (see rc_config security settings). With
    auto_approve_gates=True the stage stays DONE and the auto-approval is
    recorded in the memory adapter under the "gates" namespace."""
    _write_prior_artifact(run_dir, 4, "candidates.jsonl", '{"title": "paper"}')

    def good_executor(
        stage_dir: Path,
        _run_dir: Path,
        _config: RCConfig,
        _adapters: AdapterBundle,
        *,
        llm: object = None,
        **_kwargs: object,
    ):
        # Writes its declared output so contract validation passes.
        _ = llm
        (stage_dir / "shortlist.jsonl").write_text(
            '{"title": "paper"}\n', encoding="utf-8"
        )
        return rc_executor.StageResult(
            stage=Stage.LITERATURE_SCREEN,
            status=StageStatus.DONE,
            artifacts=("shortlist.jsonl",),
        )

    monkeypatch.setitem(
        rc_executor._STAGE_EXECUTORS, Stage.LITERATURE_SCREEN, good_executor
    )
    result = rc_executor.execute_stage(
        Stage.LITERATURE_SCREEN,
        run_dir=run_dir,
        run_id="run-5",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert result.status == StageStatus.DONE
    memory_entries = getattr(adapters.memory, "entries", [])
    assert any(
        ns == "gates" and "auto-approved" in content for ns, content in memory_entries
    )
def test_execute_stage_gate_behavior_auto_approve_false_blocks(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """With auto_approve_gates=False a gated stage ends BLOCKED_APPROVAL with
    decision "block" and an approval-request message is sent."""
    _write_prior_artifact(run_dir, 4, "candidates.jsonl", '{"title": "paper"}')

    def good_executor(
        stage_dir: Path,
        _run_dir: Path,
        _config: RCConfig,
        _adapters: AdapterBundle,
        *,
        llm: object = None,
        **_kwargs: object,
    ):
        # Same successful executor as the auto-approve test; only the gate
        # handling differs.
        _ = llm
        (stage_dir / "shortlist.jsonl").write_text(
            '{"title": "paper"}\n', encoding="utf-8"
        )
        return rc_executor.StageResult(
            stage=Stage.LITERATURE_SCREEN,
            status=StageStatus.DONE,
            artifacts=("shortlist.jsonl",),
        )

    monkeypatch.setitem(
        rc_executor._STAGE_EXECUTORS, Stage.LITERATURE_SCREEN, good_executor
    )
    result = rc_executor.execute_stage(
        Stage.LITERATURE_SCREEN,
        run_dir=run_dir,
        run_id="run-6",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=False,
    )
    assert result.status == StageStatus.BLOCKED_APPROVAL
    assert result.decision == "block"
    message_calls = getattr(adapters.message, "calls", [])
    assert message_calls
    # The last message sent carries the approval request text.
    assert "Approval required" in message_calls[-1][2]
def test_execute_stage_llm_client_creation_error_falls_back_without_crash(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """If the LLM client cannot be constructed, the stage must still complete
    (DONE) and write goal.md rather than propagate the exception."""
    def boom(_config: RCConfig):
        # Simulate a failure inside the LLM client factory.
        raise RuntimeError("llm init failed")

    monkeypatch.setattr("researchclaw.pipeline.executor.LLMClient.from_rc_config", boom)
    result = rc_executor.execute_stage(
        Stage.TOPIC_INIT,
        run_dir=run_dir,
        run_id="run-7",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert result.status == StageStatus.DONE
    assert (run_dir / "stage-01" / "goal.md").exists()
def test_execute_stage_executor_exception_returns_failed(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """An exception raised inside a stage executor is caught and surfaces as
    FAILED with decision "retry" and the exception text in the error."""
    def raising_executor(
        _stage_dir: Path,
        _run_dir: Path,
        _config: RCConfig,
        _adapters: AdapterBundle,
        *,
        llm: object = None,
        **_kwargs: object,
    ):
        _ = llm
        raise RuntimeError("stage exploded")

    monkeypatch.setitem(
        rc_executor._STAGE_EXECUTORS, Stage.TOPIC_INIT, raising_executor
    )
    result = rc_executor.execute_stage(
        Stage.TOPIC_INIT,
        run_dir=run_dir,
        run_id="run-8",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert result.status == StageStatus.FAILED
    assert result.decision == "retry"
    assert "stage exploded" in (result.error or "")
@pytest.mark.parametrize(
    "stage",
    [
        Stage.TOPIC_INIT,
        Stage.PROBLEM_DECOMPOSE,
        Stage.SEARCH_STRATEGY,
        Stage.LITERATURE_COLLECT,
        Stage.LITERATURE_SCREEN,
        Stage.KNOWLEDGE_EXTRACT,
        Stage.SYNTHESIS,
        Stage.HYPOTHESIS_GEN,
        Stage.EXPERIMENT_DESIGN,
        Stage.CODE_GENERATION,
    ],
)
def test_stage_executor_mapping_values_are_callable(stage: Stage) -> None:
    """Each entry in the executor map (spot-checked over the first ten
    stages) must be a callable."""
    assert callable(rc_executor._STAGE_EXECUTORS[stage])
class TestStageHealth:
    """Tests for the stage_health.json telemetry written by execute_stage."""

    def test_stage_health_json_written(self, tmp_path: Path) -> None:
        """Running stage 1 produces stage-01/stage_health.json."""
        from researchclaw.pipeline.executor import execute_stage
        from researchclaw.pipeline.stages import Stage

        # Uses the repository's example config rather than the fixture.
        config = RCConfig.load(
            Path(__file__).parent.parent / "config.researchclaw.example.yaml",
            check_paths=False,
        )
        result = execute_stage(
            Stage.TOPIC_INIT,
            run_dir=tmp_path,
            run_id="test-health",
            config=config,
            adapters=AdapterBundle(),
            auto_approve_gates=True,
        )
        health_path = tmp_path / "stage-01" / "stage_health.json"
        assert result is not None
        assert health_path.exists()

    def test_stage_health_has_required_fields(self, tmp_path: Path) -> None:
        """stage_health.json carries id/run/duration/status/timestamp fields."""
        from unittest.mock import MagicMock, patch
        from researchclaw.pipeline.executor import execute_stage
        from researchclaw.pipeline.stages import Stage

        config = RCConfig.load(
            Path(__file__).parent.parent / "config.researchclaw.example.yaml",
            check_paths=False,
        )
        # Patch the whole LLMClient class so from_rc_config returns a mock.
        with patch("researchclaw.pipeline.executor.LLMClient") as mock_llm_cls:
            mock_client = MagicMock()
            mock_client.chat.return_value = MagicMock(
                content='{"topic": "test", "research_questions": ["q1"]}'
            )
            mock_llm_cls.from_rc_config.return_value = mock_client
            execute_stage(
                Stage.TOPIC_INIT,
                run_dir=tmp_path,
                run_id="test-health-fields",
                config=config,
                adapters=AdapterBundle(),
                auto_approve_gates=True,
            )
        health_path = tmp_path / "stage-01" / "stage_health.json"
        # Field checks only run if the health file was produced.
        if health_path.exists():
            data = json.loads(health_path.read_text(encoding="utf-8"))
            assert "stage_id" in data
            assert "run_id" in data
            assert "duration_sec" in data
            assert "status" in data
            assert "timestamp" in data
            assert data["duration_sec"] >= 0

    def test_stage_health_duration_positive(self, tmp_path: Path) -> None:
        """Recorded duration_sec is never negative."""
        from unittest.mock import MagicMock, patch
        from researchclaw.pipeline.executor import execute_stage
        from researchclaw.pipeline.stages import Stage

        config = RCConfig.load(
            Path(__file__).parent.parent / "config.researchclaw.example.yaml",
            check_paths=False,
        )
        with patch("researchclaw.pipeline.executor.LLMClient") as mock_llm_cls:
            mock_client = MagicMock()
            mock_client.chat.return_value = MagicMock(
                content='{"topic": "test", "sub_problems": []}'
            )
            mock_llm_cls.from_rc_config.return_value = mock_client
            execute_stage(
                Stage.TOPIC_INIT,
                run_dir=tmp_path,
                run_id="test-duration",
                config=config,
                adapters=AdapterBundle(),
                auto_approve_gates=True,
            )
        health_path = tmp_path / "stage-01" / "stage_health.json"
        if health_path.exists():
            data = json.loads(health_path.read_text(encoding="utf-8"))
            assert data["duration_sec"] >= 0
# Contracts import for Stage 13/22 preservation features.
from researchclaw.pipeline.contracts import CONTRACTS
class TestIterativeRefine:
    """Tests for _execute_iterative_refine() (Stage 13).

    Fix: the ~40-line sandbox config dict was duplicated verbatim in
    test_refine_converges_after_no_improvement and
    test_refine_sandbox_mode_runs_code; it now lives in _sandbox_config().
    """

    def _prepare_refine_inputs(self, run_dir: Path) -> None:
        """Seed the stage-10 experiment code and a stage-12 run record that
        stage 13 consumes."""
        _write_prior_artifact(
            run_dir,
            10,
            "experiment.py",
            (
                "import random\n"
                "random.seed(42)\n"
                "for i in range(5):\n"
                " print(f'val_loss: {0.5 - i*0.05:.4f}')\n"
            ),
        )
        # Nested artifact path: the parent dir must exist beforehand.
        (run_dir / "stage-12" / "runs").mkdir(parents=True, exist_ok=True)
        _write_prior_artifact(
            run_dir,
            12,
            "runs/run-1.json",
            json.dumps(
                {
                    "run_id": "run-1",
                    "status": "completed",
                    "metrics": {"val_loss": 0.35},
                }
            ),
        )

    def _sandbox_config(self, tmp_path: Path) -> RCConfig:
        """Build an RCConfig whose experiment section runs a real sandbox
        (shared by the convergence and sandbox-execution tests)."""
        import sys

        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {
                "topic": "test-driven science",
                "domains": ["ml", "systems"],
                "daily_paper_count": 2,
                "quality_threshold": 8.2,
            },
            "runtime": {"timezone": "UTC"},
            "notifications": {
                "channel": "local",
                "on_stage_start": True,
                "on_stage_fail": False,
                "on_gate_required": True,
            },
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "RC_TEST_KEY",
                "api_key": "inline-test-key",
                "primary_model": "fake-model",
                "fallback_models": [],
            },
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 3,
                "metric_key": "val_loss",
                "metric_direction": "minimize",
                "sandbox": {
                    # Use the running interpreter so the sandbox works anywhere.
                    "python_path": sys.executable,
                    "gpu_required": False,
                    "max_memory_mb": 1024,
                },
            },
        }
        return RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)

    def test_refine_simulated_mode_skips(
        self,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """R10-Fix3: Simulated mode should skip iterative refinement entirely."""
        self._prepare_refine_inputs(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Force simulated mode to test the skip behavior
        import copy

        sim_cfg = copy.deepcopy(rc_config)
        object.__setattr__(sim_cfg.experiment, "mode", "simulated")
        result = rc_executor._execute_iterative_refine(
            stage_dir,
            run_dir,
            sim_cfg,
            adapters,
            llm=None,
        )
        payload = json.loads(
            (stage_dir / "refinement_log.json").read_text(encoding="utf-8")
        )
        assert payload["skipped"] is True
        assert payload["mode"] == "simulated"
        assert result.status == StageStatus.DONE
        # Original code should be copied as final
        assert (stage_dir / "experiment_final.py").exists()

    def test_refine_no_llm_saves_original_as_final(
        self,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """Without an LLM the original code is promoted to experiment_final.py
        and the log records llm_unavailable."""
        self._prepare_refine_inputs(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        result = rc_executor._execute_iterative_refine(
            stage_dir,
            run_dir,
            rc_config,
            adapters,
            llm=None,
        )
        original_code = (run_dir / "stage-10" / "experiment.py").read_text(
            encoding="utf-8"
        )
        final_code = (stage_dir / "experiment_final.py").read_text(encoding="utf-8")
        assert original_code == final_code
        payload = json.loads(
            (stage_dir / "refinement_log.json").read_text(encoding="utf-8")
        )
        assert payload["stop_reason"] == "llm_unavailable"
        assert result.status == StageStatus.DONE

    def test_refine_with_llm_generates_improved_code(
        self,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """With an LLM, versioned experiment dirs and a final script appear,
        and the log records at least one iteration."""
        self._prepare_refine_inputs(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        llm = FakeLLMClient(
            "```python\n"
            "import random\n"
            "random.seed(42)\n"
            "for i in range(10):\n"
            " print(f'val_loss: {0.4 - i*0.03:.4f}')\n"
            "```"
        )
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, rc_config, adapters, llm=llm
        )
        assert (stage_dir / "experiment_v1").is_dir()
        assert (stage_dir / "experiment_final.py").exists()
        payload = json.loads(
            (stage_dir / "refinement_log.json").read_text(encoding="utf-8")
        )
        assert isinstance(payload.get("iterations"), list)
        assert payload["iterations"]

    def test_refine_converges_after_no_improvement(
        self,
        tmp_path: Path,
        run_dir: Path,
        adapters: AdapterBundle,
    ) -> None:
        """Two consecutive non-improving iterations trigger convergence."""
        self._prepare_refine_inputs(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        sandbox_config = self._sandbox_config(tmp_path)
        # The fake LLM always emits code with a constant val_loss, so the
        # metric never improves.
        llm = FakeLLMClient(
            "```python\nfor _ in range(3):\n print('val_loss: 0.5000')\n```"
        )
        rc_executor._execute_iterative_refine(
            stage_dir,
            run_dir,
            sandbox_config,
            adapters,
            llm=llm,
        )
        payload = json.loads(
            (stage_dir / "refinement_log.json").read_text(encoding="utf-8")
        )
        assert payload["converged"] is True
        assert payload["stop_reason"] == "no_improvement_for_2_iterations"

    def test_refine_artifacts_include_version_files(
        self,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """The StageResult lists the log, the final dir, and versioned dirs."""
        self._prepare_refine_inputs(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        llm = FakeLLMClient(
            "```python\n"
            "import random\n"
            "random.seed(42)\n"
            "for i in range(10):\n"
            " print(f'val_loss: {0.4 - i*0.03:.4f}')\n"
            "```"
        )
        result = rc_executor._execute_iterative_refine(
            stage_dir,
            run_dir,
            rc_config,
            adapters,
            llm=llm,
        )
        assert "refinement_log.json" in result.artifacts
        assert "experiment_final/" in result.artifacts
        assert any(
            artifact.startswith("experiment_v") and artifact.endswith("/")
            for artifact in result.artifacts
        )

    def test_refine_sandbox_mode_runs_code(
        self,
        tmp_path: Path,
        run_dir: Path,
        adapters: AdapterBundle,
    ) -> None:
        """In sandbox mode the candidate code is executed and every logged
        iteration carries sandbox results."""
        self._prepare_refine_inputs(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        sandbox_config = self._sandbox_config(tmp_path)
        llm = FakeLLMClient(
            "```python\n"
            "import random\n"
            "random.seed(42)\n"
            "for i in range(10):\n"
            " print(f'val_loss: {0.4 - i*0.03:.4f}')\n"
            "```"
        )
        rc_executor._execute_iterative_refine(
            stage_dir,
            run_dir,
            sandbox_config,
            adapters,
            llm=llm,
        )
        payload = json.loads(
            (stage_dir / "refinement_log.json").read_text(encoding="utf-8")
        )
        assert any(
            "sandbox" in iteration for iteration in payload.get("iterations", [])
        )
class TestExportPublishCodePackage:
    """Tests for the code/ packaging done by _execute_export_publish() (Stage 22).

    Fix: the stage-19 paper seeding and stage-22 dir creation were duplicated
    in all five tests; they now live in _write_paper()/_make_stage_dir().
    """

    def _write_paper(self, run_dir: Path, title: str = "Test Paper") -> None:
        """Seed the stage-19 revised paper that stage 22 packages."""
        _write_prior_artifact(
            run_dir, 19, "paper_revised.md", f"# {title}\n\nSome content..."
        )

    def _make_stage_dir(self, tmp_path: Path) -> Path:
        """Create and return the stage-22 output directory under tmp_path."""
        stage_dir = tmp_path / "run" / "stage-22"
        stage_dir.mkdir(parents=True, exist_ok=True)
        return stage_dir

    def test_export_packages_experiment_final(
        self,
        tmp_path: Path,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """Stage-13 final code is packaged with a README and requirements."""
        self._write_paper(run_dir)
        _write_prior_artifact(
            run_dir,
            13,
            "experiment_final.py",
            'import numpy\nprint("val_loss: 0.1")\n',
        )
        stage_dir = self._make_stage_dir(tmp_path)
        rc_executor._execute_export_publish(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        assert (stage_dir / "code" / "experiment.py").exists()
        assert (stage_dir / "code" / "README.md").exists()
        req_text = (stage_dir / "code" / "requirements.txt").read_text(encoding="utf-8")
        # The numpy import must be detected as a dependency.
        assert "numpy" in req_text

    def test_export_falls_back_to_experiment_py(
        self,
        tmp_path: Path,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """Without stage-13 output, the stage-10 experiment.py is packaged."""
        self._write_paper(run_dir)
        _write_prior_artifact(
            run_dir,
            10,
            "experiment.py",
            'import numpy\nprint("val_loss: 0.1")\n',
        )
        stage_dir = self._make_stage_dir(tmp_path)
        rc_executor._execute_export_publish(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        code_text = (stage_dir / "code" / "experiment.py").read_text(encoding="utf-8")
        assert "val_loss: 0.1" in code_text

    def test_export_no_experiment_skips_code_dir(
        self,
        tmp_path: Path,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """Without any experiment code the code/ package is omitted entirely."""
        self._write_paper(run_dir)
        stage_dir = self._make_stage_dir(tmp_path)
        result = rc_executor._execute_export_publish(
            stage_dir,
            run_dir,
            rc_config,
            adapters,
            llm=None,
        )
        assert not (stage_dir / "code").exists()
        assert "code/" not in result.artifacts

    def test_export_detects_multiple_dependencies(
        self,
        tmp_path: Path,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """Imports are mapped to pip package names (sklearn -> scikit-learn)."""
        self._write_paper(run_dir)
        _write_prior_artifact(
            run_dir,
            13,
            "experiment_final.py",
            (
                "import numpy\n"
                "import torch\n"
                "from sklearn.metrics import accuracy_score\n"
                "print(accuracy_score([1], [1]))\n"
            ),
        )
        stage_dir = self._make_stage_dir(tmp_path)
        rc_executor._execute_export_publish(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        requirements = (stage_dir / "code" / "requirements.txt").read_text(
            encoding="utf-8"
        )
        assert "numpy" in requirements
        assert "torch" in requirements
        assert "scikit-learn" in requirements

    def test_export_code_readme_contains_title(
        self,
        tmp_path: Path,
        run_dir: Path,
        rc_config: RCConfig,
        adapters: AdapterBundle,
    ) -> None:
        """The generated code README carries the paper title."""
        self._write_paper(run_dir, title="My Great Paper")
        _write_prior_artifact(
            run_dir,
            13,
            "experiment_final.py",
            'print("val_loss: 0.1")\n',
        )
        stage_dir = self._make_stage_dir(tmp_path)
        rc_executor._execute_export_publish(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        readme = (stage_dir / "code" / "README.md").read_text(encoding="utf-8")
        assert "My Great Paper" in readme
def test_contracts_stage13_includes_experiment_final() -> None:
    """Stage 13's contract must declare the experiment_final/ directory."""
    outputs = CONTRACTS[Stage.ITERATIVE_REFINE].output_files
    assert "experiment_final/" in outputs
def test_contracts_stage22_includes_code_dir() -> None:
    """Stage 22's contract must declare the code/ directory."""
    outputs = CONTRACTS[Stage.EXPORT_PUBLISH].output_files
    assert "code/" in outputs
# ── P1-1: Topic keyword extraction tests ──
class TestExtractTopicKeywords:
    """Tests for _extract_topic_keywords() (P1-1)."""

    def test_basic_extraction(self) -> None:
        """Content words are lower-cased and kept; stop words are dropped."""
        keywords = rc_executor._extract_topic_keywords(
            "Agent-based Reinforcement Learning for Automated Scientific Discovery"
        )
        assert "agent-based" in keywords
        assert "reinforcement" in keywords
        assert "learning" in keywords
        assert "automated" in keywords
        assert "scientific" in keywords
        assert "discovery" in keywords
        # Stop words excluded
        assert "for" not in keywords

    def test_includes_domain_keywords(self) -> None:
        """Configured domain names are appended to the keyword list."""
        keywords = rc_executor._extract_topic_keywords(
            "Neural network pruning", domains=("ml", "optimization")
        )
        assert "neural" in keywords
        assert "network" in keywords
        assert "pruning" in keywords
        assert "ml" in keywords
        assert "optimization" in keywords

    def test_deduplication(self) -> None:
        """A word appearing in both topic and domains is kept only once."""
        keywords = rc_executor._extract_topic_keywords(
            "Learning to learn meta-learning", domains=("learning",)
        )
        assert keywords.count("learning") == 1

    def test_empty_topic(self) -> None:
        """An empty topic yields an empty keyword list."""
        keywords = rc_executor._extract_topic_keywords("")
        assert keywords == []
# ── P1-2: Topic constraint block test ──
class TestTopicConstraintBlock:
    """Tests for the _topic_constraint_block() prompt scaffolding (P1-2)."""

    def test_contains_topic(self) -> None:
        topic = "Transformer attention for time series"
        assert topic in rc_executor._topic_constraint_block(topic)

    def test_contains_prohibition(self) -> None:
        block = rc_executor._topic_constraint_block("anything")
        lowered = block.lower()
        assert "PROHIBITED" in block
        assert "environment" in lowered
        assert "infrastructure" in lowered

    def test_hard_constraint_markers(self) -> None:
        block = rc_executor._topic_constraint_block("test")
        for marker in ("HARD TOPIC CONSTRAINT", "END CONSTRAINT"):
            assert marker in block
# ── Multi-perspective debate tests ──
class TestParseDecision:
    """Tests for _parse_decision() verdict-keyword extraction."""

    def test_proceed_default(self) -> None:
        # No recognized keyword: default is "proceed".
        assert rc_executor._parse_decision("Some random text") == "proceed"

    def test_proceed_explicit(self) -> None:
        text = "## Decision\nPROCEED\n## Justification\nGood results."
        assert rc_executor._parse_decision(text) == "proceed"

    def test_pivot_detected(self) -> None:
        text = "## Decision\nPIVOT\n## Justification\nHypotheses flawed."
        assert rc_executor._parse_decision(text) == "pivot"

    def test_refine_detected(self) -> None:
        text = "## Decision\nREFINE\n## Justification\nNeed more tuning."
        assert rc_executor._parse_decision(text) == "refine"

    def test_pivot_case_insensitive(self) -> None:
        # Lower-case keywords must also be recognized.
        text = "## Decision\npivot\n## Justification\nBad approach."
        assert rc_executor._parse_decision(text) == "pivot"

    def test_pivot_takes_priority_over_proceed(self) -> None:
        # When both keywords appear, PIVOT outranks PROCEED.
        text = "## Decision\nPIVOT\nWe should not PROCEED."
        assert rc_executor._parse_decision(text) == "pivot"

    def test_decision_in_body_not_heading(self) -> None:
        # The keyword counts even outside a "## Decision" heading.
        text = "The results suggest we should PIVOT to a new approach."
        assert rc_executor._parse_decision(text) == "pivot"
class TestResearchDecisionStructured:
    """Tests for _execute_research_decision() structured output (Stage 15).

    Fix: removed a redundant function-local ``import json`` — the module-level
    ``json`` import is already used throughout this file.
    """

    def test_decision_produces_structured_json(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        """The stage emits decision_structured.json mirroring the verdict."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        stage_dir = run_dir / "stage-15"
        stage_dir.mkdir(parents=True)
        _write_prior_artifact(run_dir, 14, "analysis.md", "# Analysis\nResults ok.")
        fake_llm = FakeLLMClient("## Decision\nPROCEED\n## Justification\nGood.")
        result = rc_executor._execute_research_decision(
            stage_dir, run_dir, rc_config, adapters, llm=fake_llm
        )
        assert result.decision == "proceed"
        assert "decision_structured.json" in result.artifacts
        data = json.loads((stage_dir / "decision_structured.json").read_text())
        assert data["decision"] == "proceed"

    def test_pivot_decision_from_llm(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        """A PIVOT verdict from the LLM propagates to the stage result."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        stage_dir = run_dir / "stage-15"
        stage_dir.mkdir(parents=True)
        _write_prior_artifact(run_dir, 14, "analysis.md", "# Analysis\nBad results.")
        fake_llm = FakeLLMClient("## Decision\nPIVOT\n## Justification\nFlawed.")
        result = rc_executor._execute_research_decision(
            stage_dir, run_dir, rc_config, adapters, llm=fake_llm
        )
        assert result.decision == "pivot"

    def test_no_llm_defaults_to_proceed(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        """Without an LLM the stage conservatively decides to proceed."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        stage_dir = run_dir / "stage-15"
        stage_dir.mkdir(parents=True)
        result = rc_executor._execute_research_decision(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        assert result.decision == "proceed"
class TestMultiPerspectiveGenerate:
    """Tests for _multi_perspective_generate() fan-out over debate roles."""

    def test_generates_all_perspectives(self, tmp_path: Path) -> None:
        """One LLM call and one markdown file are produced per role."""
        roles = {
            "role_a": {"system": "You are A.", "user": "Do A for {topic}."},
            "role_b": {"system": "You are B.", "user": "Do B for {topic}."},
        }
        fake_llm = FakeLLMClient("perspective output")
        perspectives_dir = tmp_path / "perspectives"
        result = rc_executor._multi_perspective_generate(
            fake_llm, roles, {"topic": "test"}, perspectives_dir
        )
        assert set(result.keys()) == {"role_a", "role_b"}
        assert (perspectives_dir / "role_a.md").exists()
        assert (perspectives_dir / "role_b.md").exists()
        assert len(fake_llm.calls) == 2

    def test_saves_perspective_content(self, tmp_path: Path) -> None:
        """The raw LLM output is persisted verbatim per role."""
        roles = {"critic": {"system": "Be critical.", "user": "Criticize {topic}."}}
        fake_llm = FakeLLMClient("critical analysis here")
        perspectives_dir = tmp_path / "perspectives"
        rc_executor._multi_perspective_generate(
            fake_llm, roles, {"topic": "ml"}, perspectives_dir
        )
        content = (perspectives_dir / "critic.md").read_text()
        assert content == "critical analysis here"

    def test_renders_variables_in_prompts(self, tmp_path: Path) -> None:
        """{topic}-style placeholders are substituted before the LLM is called."""
        roles = {"r1": {"system": "Sys for {topic}.", "user": "User for {topic}."}}
        fake_llm = FakeLLMClient("ok")
        rc_executor._multi_perspective_generate(
            fake_llm, roles, {"topic": "RL"}, tmp_path / "p"
        )
        call = fake_llm.calls[0]
        assert "RL" in call[0]["content"]
class TestSynthesizePerspectives:
    """Tests for _synthesize_perspectives() merging debate outputs."""

    def test_combines_perspectives(self) -> None:
        fake_llm = FakeLLMClient("synthesized result")
        prompt_manager = rc_executor.PromptManager()
        outputs = {"innovator": "idea A", "contrarian": "idea B"}
        merged = rc_executor._synthesize_perspectives(
            fake_llm, outputs, "hypothesis_synthesize", prompt_manager
        )
        assert merged == "synthesized result"
        # Both role names must appear in the user prompt sent to the LLM.
        user_prompt = fake_llm.calls[0][0]["content"]
        assert "innovator" in user_prompt
        assert "contrarian" in user_prompt
class TestHypothesisGenDebate:
    """Tests for the multi-perspective debate in _execute_hypothesis_gen()."""

    def test_hypothesis_gen_with_llm_creates_perspectives(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        """With an LLM, stage 8 writes hypotheses.md plus three perspective files."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        stage_dir = run_dir / "stage-08"
        stage_dir.mkdir(parents=True)
        _write_prior_artifact(run_dir, 7, "synthesis.md", "# Synthesis\nGap found.")
        fake_llm = FakeLLMClient("## H1\nTest hypothesis")
        result = rc_executor._execute_hypothesis_gen(
            stage_dir, run_dir, rc_config, adapters, llm=fake_llm
        )
        assert result.status == StageStatus.DONE
        assert "hypotheses.md" in result.artifacts
        perspectives_dir = stage_dir / "perspectives"
        assert perspectives_dir.exists()
        # Should have 3 perspective files (innovator, pragmatist, contrarian)
        perspective_files = list(perspectives_dir.glob("*.md"))
        assert len(perspective_files) == 3

    def test_hypothesis_gen_without_llm_no_perspectives(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        """Without an LLM the debate is skipped but hypotheses.md is still written."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        stage_dir = run_dir / "stage-08"
        stage_dir.mkdir(parents=True)
        _write_prior_artifact(run_dir, 7, "synthesis.md", "# Synthesis\nGap found.")
        result = rc_executor._execute_hypothesis_gen(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        assert result.status == StageStatus.DONE
        assert "hypotheses.md" in result.artifacts
        # No perspectives directory when no LLM
        assert not (stage_dir / "perspectives").exists()
class TestResultAnalysisDebate:
    """Stage 14 (result analysis) multi-perspective debate behaviour."""

    def test_result_analysis_with_llm_creates_perspectives(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        run_dir = tmp_path / "run"
        stage_dir = run_dir / "stage-14"
        stage_dir.mkdir(parents=True)
        _write_prior_artifact(run_dir, 1, "goal.md", "# Goal\nTest")
        _write_prior_artifact(run_dir, 8, "hypotheses.md", "# H1\nTest")
        result = rc_executor._execute_result_analysis(
            stage_dir, run_dir, rc_config, adapters,
            llm=FakeLLMClient("## Analysis\nResults look good."),
        )
        assert result.status == StageStatus.DONE
        assert "analysis.md" in result.artifacts
        # Debate mode writes one file per role (optimist, skeptic, methodologist).
        perspectives_dir = stage_dir / "perspectives"
        assert perspectives_dir.exists()
        assert len(list(perspectives_dir.glob("*.md"))) == 3

    def test_result_analysis_without_llm_no_perspectives(
        self, tmp_path: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        run_dir = tmp_path / "run"
        stage_dir = run_dir / "stage-14"
        stage_dir.mkdir(parents=True)
        result = rc_executor._execute_result_analysis(
            stage_dir, run_dir, rc_config, adapters, llm=None
        )
        assert result.status == StageStatus.DONE
        assert "analysis.md" in result.artifacts
        assert not (stage_dir / "perspectives").exists()
class TestParseMetricsFromStdout:
    """Tests for _parse_metrics_from_stdout() helper.

    Uses the module-level ``rc_executor`` alias (as the other test classes in
    this file do, e.g. ``rc_executor._detect_runtime_issues``) instead of
    re-importing the helper function inside every single test.
    """

    def test_parses_simple_name_value(self) -> None:
        # Plain "name: value" lines become float-valued metric entries.
        metrics = rc_executor._parse_metrics_from_stdout("loss: 0.0042\naccuracy: 0.95")
        assert metrics["loss"] == pytest.approx(0.0042)
        assert metrics["accuracy"] == pytest.approx(0.95)

    def test_parses_compound_names(self) -> None:
        # Metric names may contain spaces and parentheses.
        stdout = (
            "UCB (Stochastic) cumulative_regret: 361.9233\n"
            "EXP3 (Adversarial) total_rewards: 13368.4811"
        )
        metrics = rc_executor._parse_metrics_from_stdout(stdout)
        assert "UCB (Stochastic) cumulative_regret" in metrics
        assert metrics["UCB (Stochastic) cumulative_regret"] == pytest.approx(361.9233)

    def test_ignores_non_numeric_lines(self) -> None:
        metrics = rc_executor._parse_metrics_from_stdout(
            "Running experiment...\nloss: 0.5\nDone."
        )
        assert len(metrics) == 1
        assert metrics["loss"] == pytest.approx(0.5)

    def test_empty_stdout_returns_empty_dict(self) -> None:
        assert rc_executor._parse_metrics_from_stdout("") == {}

    def test_handles_negative_values(self) -> None:
        metrics = rc_executor._parse_metrics_from_stdout(
            "UCB (Adversarial) cumulative_regret: -3877.5323"
        )
        assert metrics["UCB (Adversarial) cumulative_regret"] == pytest.approx(-3877.5323)

    def test_filters_log_lines(self) -> None:
        # Progress/log lines ending in ": <number>" must not become metrics.
        stdout = (
            "Running experiments for support set size: 1\n"
            "Loading model weights: 42\n"
            "Training epoch: 5\n"
            "loss: 0.123\n"
            "accuracy: 0.95\n"
        )
        metrics = rc_executor._parse_metrics_from_stdout(stdout)
        assert "loss" in metrics
        assert "accuracy" in metrics
        assert len(metrics) == 2  # log lines should be excluded

    def test_filters_long_name_lines(self) -> None:
        # Overly long "names" are treated as status messages, not metrics.
        stdout = "this is a very long status message that should not be a metric: 42\n"
        assert rc_executor._parse_metrics_from_stdout(stdout) == {}
class TestDetectRuntimeIssues:
    """Tests for _detect_runtime_issues() helper."""

    def _make_sandbox_result(
        self,
        metrics: dict | None = None,
        stdout: str = "",
        stderr: str = "",
    ):
        # Minimal stand-in for a sandbox run-result object.
        from types import SimpleNamespace
        return SimpleNamespace(
            metrics=metrics or {},
            stdout=stdout,
            stderr=stderr,
            returncode=0,
            elapsed_sec=1.0,
            timed_out=False,
        )

    def test_no_issues_returns_empty_string(self) -> None:
        clean = self._make_sandbox_result(metrics={"loss": 0.5}, stdout="loss: 0.5")
        assert rc_executor._detect_runtime_issues(clean) == ""

    def test_detects_nan_in_metrics(self) -> None:
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(metrics={"loss": float("nan")})
        )
        assert "NaN" in issues
        assert "loss" in issues

    def test_detects_inf_in_metrics(self) -> None:
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(metrics={"loss": float("inf")})
        )
        assert "Inf" in issues

    def test_detects_nan_in_stdout(self) -> None:
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(stdout="accuracy: nan\nloss: 0.5")
        )
        assert "NaN" in issues or "nan" in issues

    def test_detects_runtime_warning_in_stderr(self) -> None:
        stderr = (
            "optimizers.py:76: RuntimeWarning: invalid value encountered in divide\n"
            " directions = np.vstack((directions[1:], new_direction / norm))\n"
        )
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(stderr=stderr)
        )
        assert "RuntimeWarning" in issues
        assert "invalid value" in issues

    def test_detects_division_error_in_stderr(self) -> None:
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(stderr="ZeroDivisionError: division by zero\n")
        )
        assert "Error" in issues

    def test_ignores_benign_stderr(self) -> None:
        # Non-warning stderr should be ignored
        benign = self._make_sandbox_result(stderr="Loading module...\nDone.\n")
        assert rc_executor._detect_runtime_issues(benign) == ""

    def test_combined_nan_and_stderr(self) -> None:
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(
                metrics={"accuracy": float("nan")},
                stderr="RuntimeWarning: invalid value\n",
            )
        )
        assert "NaN" in issues
        assert "RuntimeWarning" in issues

    def test_detects_dummy_metric_identical_values(self) -> None:
        # A metric whose value is identical across all conditions looks fabricated.
        stdout = (
            "UCB (Stochastic) convergence_rate: 1.0000\n"
            "UCB (Adversarial) convergence_rate: 1.0000\n"
            "Thompson (Stochastic) convergence_rate: 1.0000\n"
            "Thompson (Adversarial) convergence_rate: 1.0000\n"
        )
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(stdout=stdout)
        )
        assert "DUMMY" in issues
        assert "convergence_rate" in issues

    def test_no_dummy_metric_when_values_differ(self) -> None:
        stdout = (
            "UCB (Stochastic) regret: 78.5\n"
            "Thompson (Stochastic) regret: 121.0\n"
            "EpsilonGreedy (Stochastic) regret: 42.1\n"
        )
        issues = rc_executor._detect_runtime_issues(
            self._make_sandbox_result(stdout=stdout)
        )
        assert "DUMMY" not in issues
class TestRemoveBibtexEntries:
    """Tests for _remove_bibtex_entries() helper."""

    def test_removes_specified_keys(self) -> None:
        bib = (
            '@article{smith2024,\n title={Good Paper},\n author={Smith},\n}\n\n'
            '@article{venus2024,\n title={Venus Exploration},\n author={NASA},\n}\n'
        )
        cleaned = rc_executor._remove_bibtex_entries(bib, {"venus2024"})
        assert "smith2024" in cleaned
        assert "venus2024" not in cleaned

    def test_keeps_all_when_no_match(self) -> None:
        bib = '@article{smith2024,\n title={Paper},\n}\n'
        assert "smith2024" in rc_executor._remove_bibtex_entries(bib, {"other_key"})

    def test_empty_bib(self) -> None:
        # An empty bibliography passes through unchanged.
        assert rc_executor._remove_bibtex_entries("", {"key"}) == ""
class TestRemoveCitationsFromText:
    """Tests for _remove_citations_from_text() helper."""

    def test_removes_latex_cite(self) -> None:
        text = r"As shown in \cite{venus2024}, the results are..."
        cleaned = rc_executor._remove_citations_from_text(text, {"venus2024"})
        assert "venus2024" not in cleaned
        assert "results are" in cleaned

    def test_removes_markdown_cite(self) -> None:
        cleaned = rc_executor._remove_citations_from_text(
            "Prior work [venus2024] explored this topic.", {"venus2024"}
        )
        assert "venus2024" not in cleaned

    def test_cleans_multi_cite_comma(self) -> None:
        # Dropping one key from a multi-key \cite keeps the remaining key intact.
        cleaned = rc_executor._remove_citations_from_text(
            r"\cite{good2024,venus2024}", {"venus2024"}
        )
        assert r"\cite{good2024}" in cleaned
class TestCollectRawExperimentMetrics:
    """Tests for _collect_raw_experiment_metrics() helper."""

    @staticmethod
    def _write_run(tmp_path: Path, payload: dict) -> Path:
        # Lay out <run>/stage-12/runs/run-1.json and return the run directory.
        run_dir = tmp_path / "run"
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True)
        (runs_dir / "run-1.json").write_text(json.dumps(payload))
        return run_dir

    def test_returns_empty_when_no_runs(self, tmp_path: Path) -> None:
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        block, has_parsed = rc_executor._collect_raw_experiment_metrics(run_dir)
        assert block == ""
        assert not has_parsed

    def test_extracts_metrics_from_stdout(self, tmp_path: Path) -> None:
        run_dir = self._write_run(tmp_path, {
            "metrics": {},
            "stdout": "UCB regret: 361.92\nThompson regret: 576.24\n",
        })
        block, has_parsed = rc_executor._collect_raw_experiment_metrics(run_dir)
        assert "361.92" in block
        assert "576.24" in block
        assert "1 run(s)" in block
        assert not has_parsed

    def test_extracts_from_metrics_dict(self, tmp_path: Path) -> None:
        run_dir = self._write_run(
            tmp_path, {"metrics": {"loss": 0.042, "accuracy": 0.95}, "stdout": ""}
        )
        block, has_parsed = rc_executor._collect_raw_experiment_metrics(run_dir)
        assert "loss" in block
        assert "0.042" in block
        assert has_parsed

    def test_deduplicates_metrics(self, tmp_path: Path) -> None:
        run_dir = self._write_run(
            tmp_path, {"metrics": {"loss": 0.5}, "stdout": "loss: 0.5\nloss: 0.5\n"}
        )
        block, _ = rc_executor._collect_raw_experiment_metrics(run_dir)
        # "loss: 0.5" should appear only once (deduplicated)
        assert block.count("loss: 0.5") == 1
class TestCollectExperimentEvidence:
    """Tests for _collect_experiment_evidence() helper."""

    @staticmethod
    def _write_json(path: Path, payload: dict) -> None:
        # Create parent directories as needed and dump ``payload`` as JSON.
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(payload), encoding="utf-8")

    def test_returns_empty_when_no_artifacts(self, tmp_path: Path) -> None:
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        assert rc_executor._collect_experiment_evidence(run_dir) == ""

    def test_includes_main_py_code(self, run_dir: Path) -> None:
        exp_dir = run_dir / "stage-10" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("print('hello')", encoding="utf-8")
        evidence = rc_executor._collect_experiment_evidence(run_dir)
        assert "main.py" in evidence
        assert "hello" in evidence

    def test_includes_run_metrics(self, run_dir: Path) -> None:
        self._write_json(
            run_dir / "stage-12" / "runs" / "run-1.json",
            {"metrics": {"loss": 0.5}, "elapsed_sec": 3.2},
        )
        evidence = rc_executor._collect_experiment_evidence(run_dir)
        assert "loss" in evidence
        assert "0.5" in evidence

    def test_includes_stderr_excerpt(self, run_dir: Path) -> None:
        self._write_json(
            run_dir / "stage-12" / "runs" / "run-1.json",
            {"metrics": {"loss": 0.5}, "stderr": "RuntimeWarning: divide by zero"},
        )
        assert "divide by zero" in rc_executor._collect_experiment_evidence(run_dir)

    def test_includes_refinement_summary(self, run_dir: Path) -> None:
        self._write_json(
            run_dir / "stage-13" / "refinement_log.json",
            {
                "iterations": [{"iteration": 1}, {"iteration": 2}],
                "converged": True,
                "stop_reason": "no_improvement_for_2_iterations",
                "best_metric": 0.3,
            },
        )
        evidence = rc_executor._collect_experiment_evidence(run_dir)
        assert "iterations_executed" in evidence
        assert "2" in evidence

    def test_includes_actual_trial_count(self, run_dir: Path) -> None:
        self._write_json(
            run_dir / "stage-12" / "runs" / "run-1.json",
            {"metrics": {"loss": 0.5}},
        )
        evidence = rc_executor._collect_experiment_evidence(run_dir)
        assert "1 time(s)" in evidence
        assert "CRITICAL" in evidence
class TestWritePaperSections:
    """Tests for _write_paper_sections() multi-call writing.

    Cleanups vs. the previous version: removed the unused ``call_count``
    dict and the vacuous ``assert llm.calls is not None`` (the attribute is
    always a list after ``__init__``).
    """

    def test_produces_three_part_draft(self) -> None:
        # Canned responses, one per section-writing call.
        parts = [
            "# Test Title\n\n## Abstract\nTest abstract.\n\n## Introduction\nTest intro.\n\n## Related Work\nTest related.",
            "## Method\nTest method.\n\n## Experiments\nTest experiments.",
            "## Results\nTest results.\n\n## Discussion\nTest discussion.\n\n## Limitations\nTest limits.\n\n## Conclusion\nTest conclusion.",
        ]

        class MultiCallLLM:
            """Fake LLM returning one canned part per chat() call."""

            def __init__(self):
                self.calls: list = []

            def chat(self, messages, **kwargs):
                self.calls.append(messages)
                from researchclaw.llm.client import LLMResponse
                idx = len(self.calls) - 1
                return LLMResponse(content=parts[min(idx, 2)], model="fake")

        llm = MultiCallLLM()
        from researchclaw.prompts import PromptManager
        draft = rc_executor._write_paper_sections(
            llm=llm,
            pm=PromptManager(),
            preamble="Test preamble",
            topic_constraint="",
            exp_metrics_instruction="",
            citation_instruction="",
            outline="Test outline",
        )
        # One LLM call per part, and every major section present in the draft.
        assert len(llm.calls) == 3
        for heading in ("## Abstract", "## Method", "## Results", "## Conclusion"):
            assert heading in draft

    def test_each_call_receives_prior_context(self) -> None:
        class ContextTrackingLLM:
            """Fake LLM that records every user-role prompt it receives."""

            def __init__(self):
                self.user_prompts: list[str] = []

            def chat(self, messages, **kwargs):
                for m in messages:
                    if m.get("role") == "user":
                        self.user_prompts.append(m["content"])
                from researchclaw.llm.client import LLMResponse
                return LLMResponse(content="## Section\nContent here.", model="fake")

        llm = ContextTrackingLLM()
        from researchclaw.prompts import PromptManager
        rc_executor._write_paper_sections(
            llm=llm,
            pm=PromptManager(),
            preamble="Preamble",
            topic_constraint="",
            exp_metrics_instruction="",
            citation_instruction="",
            outline="Outline",
        )
        assert len(llm.user_prompts) == 3
        # Calls 2 and 3 must carry forward the previously written sections.
        assert "sections written so far" in llm.user_prompts[1]
        assert "completing a paper" in llm.user_prompts[2]
class TestLoadHardwareProfile:
    """Tests for _load_hardware_profile()."""

    @pytest.fixture()
    def run_dir(self, tmp_path: Path) -> Path:
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        return run_dir

    def test_loads_valid_profile(self, run_dir: Path) -> None:
        stage = run_dir / "stage-01"
        stage.mkdir()
        (stage / "hardware_profile.json").write_text(
            json.dumps({"has_gpu": True, "gpu_type": "mps", "tier": "limited"}),
            encoding="utf-8",
        )
        profile = rc_executor._load_hardware_profile(run_dir)
        assert profile is not None
        assert profile["gpu_type"] == "mps"

    def test_returns_none_when_missing(self, run_dir: Path) -> None:
        assert rc_executor._load_hardware_profile(run_dir) is None

    def test_returns_none_on_invalid_json(self, run_dir: Path) -> None:
        # Malformed JSON is treated the same as a missing profile.
        stage = run_dir / "stage-01"
        stage.mkdir()
        (stage / "hardware_profile.json").write_text("not json", encoding="utf-8")
        assert rc_executor._load_hardware_profile(run_dir) is None
class TestExpandSearchQueries:
    """Tests for _expand_search_queries()."""

    def test_adds_broader_queries(self) -> None:
        queries = ["gradient descent optimization algorithms"]
        topic = "Comparing gradient descent optimization algorithms on benchmark functions"
        expanded = rc_executor._expand_search_queries(queries, topic)
        assert len(expanded) > len(queries)

    def test_deduplicates(self) -> None:
        expanded = rc_executor._expand_search_queries(
            ["gradient descent survey"], "gradient descent optimization"
        )
        # No two entries may be equal after case/whitespace normalization.
        normalized = [q.lower().strip() for q in expanded]
        assert len(normalized) == len(set(normalized))

    def test_preserves_original_queries(self) -> None:
        expanded = rc_executor._expand_search_queries(
            ["query A", "query B"],
            "some research topic about machine learning methods",
        )
        # Originals must come first, in their original order.
        assert expanded[0] == "query A"
        assert expanded[1] == "query B"

    def test_adds_survey_benchmark_variants(self) -> None:
        expanded = rc_executor._expand_search_queries(
            ["deep learning"],
            "deep learning for image classification with limited data",
        )
        assert any("survey" in q.lower() for q in expanded)
        assert any("benchmark" in q.lower() for q in expanded)
# ── R4-1: Experiment Budget Guard Tests ──────────────────────────────
class TestComputeBudgetBlock:
    """Test compute_budget prompt block injection (R4-1a)."""

    def test_compute_budget_block_exists_in_prompt_manager(self) -> None:
        # The default prompt library must ship a "compute_budget" block.
        from researchclaw.prompts import PromptManager
        pm = PromptManager()
        block = pm.block("compute_budget")
        assert "time_budget_sec" in block or "Compute Budget" in block

    def test_compute_budget_injected_into_code_generation(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        import sys
        # Full config built by hand so we control experiment.time_budget_sec
        # (60s) — that value is what we expect to surface in the LLM prompt.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {
                "topic": "optimizer comparison",
                "domains": ["ml"],
                "daily_paper_count": 2,
                "quality_threshold": 8.2,
            },
            "runtime": {"timezone": "UTC"},
            "notifications": {
                "channel": "local",
                "on_stage_start": True,
                "on_stage_fail": False,
                "on_gate_required": True,
            },
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "RC_TEST_KEY",
                "api_key": "inline-test-key",
                "primary_model": "fake-model",
                "fallback_models": [],
            },
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 60,
                "metric_key": "best_loss",
                "metric_direction": "minimize",
                "sandbox": {
                    "python_path": sys.executable,
                    "gpu_required": False,
                    "max_memory_mb": 1024,
                },
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        # Write exp_plan prior artifact
        _write_prior_artifact(run_dir, 10, "exp_plan.yaml", "objectives: test")
        # Capture what the LLM receives
        llm = FakeLLMClient(
            "```filename:main.py\nimport numpy as np\nprint('best_loss: 0.1')\n```"
        )
        stage_dir = run_dir / "stage-11"
        stage_dir.mkdir(parents=True, exist_ok=True)
        rc_executor._execute_code_generation(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        # The LLM should have received compute budget info in some call
        # (may be first call in legacy mode, or second call with CodeAgent)
        assert len(llm.calls) >= 1
        all_user_msgs = " ".join(
            call[-1]["content"] for call in llm.calls if call
        )
        assert "60" in all_user_msgs or "Compute Budget" in all_user_msgs
class TestPartialTimeoutStatus:
    """Test partial status for timed-out experiments with data (R4-1c)."""

    def test_timed_out_with_metrics_sets_partial_status(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        import sys
        # Tiny 2-second budget so the 10-second sleep in the generated
        # experiment script below is guaranteed to hit the timeout.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {
                "topic": "test topic",
                "domains": ["ml"],
                "daily_paper_count": 2,
                "quality_threshold": 8.2,
            },
            "runtime": {"timezone": "UTC"},
            "notifications": {
                "channel": "local",
                "on_stage_start": True,
                "on_stage_fail": False,
                "on_gate_required": True,
            },
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "RC_TEST_KEY",
                "api_key": "inline-test-key",
                "primary_model": "fake-model",
                "fallback_models": [],
            },
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 2,
                "metric_key": "best_loss",
                "metric_direction": "minimize",
                "sandbox": {
                    "python_path": sys.executable,
                    "gpu_required": False,
                    "max_memory_mb": 1024,
                },
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        # Write experiment code that prints some metrics then sleeps
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text(
            "import time, sys\n"
            "print('best_loss: 0.5', flush=True)\n"
            "sys.stdout.flush()\n"
            "time.sleep(10)\n",
            encoding="utf-8",
        )
        stage_dir = run_dir / "stage-12"
        stage_dir.mkdir(parents=True, exist_ok=True)
        rc_executor._execute_experiment_run(
            stage_dir, run_dir, cfg, adapters
        )
        run_file = stage_dir / "runs" / "run-1.json"
        assert run_file.exists()
        payload = json.loads(run_file.read_text(encoding="utf-8"))
        # Should be "partial" since metrics were captured before timeout
        assert payload["timed_out"] is True
        # Status should be "partial" if metrics captured, "failed" if not
        if payload["metrics"]:
            assert payload["status"] == "partial"
        else:
            # Subprocess stdout may not flush before kill on some platforms
            assert payload["status"] == "failed"
class TestTimeoutAwareRefine:
    """Test timeout-aware prompt injection in iterative refine (R4-1b)."""

    def _prepare_timed_out_run(self, run_dir: Path) -> None:
        """Create a prior run that timed out with no metrics."""
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "run_id": "run-1",
                "task_id": "sandbox-main",
                "status": "failed",
                "metrics": {},
                "timed_out": True,
                "elapsed_sec": 120.0,
            }),
            encoding="utf-8",
        )
        # Write experiment code
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text(
            "print('best_loss: 0.1')\n",
            encoding="utf-8",
        )

    def test_timeout_refine_injects_scale_reduction_prompt(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        self._prepare_timed_out_run(run_dir)
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Budget (120s) matches the elapsed_sec of the timed-out prior run, so
        # the refine prompt should mention both the timeout and that number.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {
                "topic": "test topic",
                "domains": ["ml"],
                "daily_paper_count": 2,
                "quality_threshold": 8.2,
            },
            "runtime": {"timezone": "UTC"},
            "notifications": {
                "channel": "local",
                "on_stage_start": True,
                "on_stage_fail": False,
                "on_gate_required": True,
            },
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "RC_TEST_KEY",
                "api_key": "inline-test-key",
                "primary_model": "fake-model",
                "fallback_models": [],
            },
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 120,
                "max_iterations": 1,
                "metric_key": "best_loss",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        llm = FakeLLMClient(
            "```python\nimport numpy as np\nprint('best_loss: 0.1')\n```"
        )
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        # The LLM should have received the timeout-aware prompt
        assert len(llm.calls) >= 1
        user_msg = llm.calls[0][-1]["content"]
        assert "TIMED OUT" in user_msg
        assert "120" in user_msg
# ── R4-2: Data Integrity Enforcement Tests ───────────────────────────
class TestDataIntegrityBlock:
    """Test paper draft blocked when no metrics exist (R4-2a)."""

    def test_paper_draft_blocked_with_no_metrics(
        self, tmp_path: Path, run_dir: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        _write_prior_artifact(run_dir, 16, "outline.md", "# Outline\n## Abstract\n")
        # Only a failed, metric-less run exists; no experiment_summary.json.
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        payload = {"run_id": "run-1", "status": "failed", "metrics": {}, "timed_out": True}
        (runs_dir / "run-1.json").write_text(json.dumps(payload), encoding="utf-8")
        stage_dir = run_dir / "stage-17"
        stage_dir.mkdir(parents=True, exist_ok=True)
        llm = FakeLLMClient("should not be called")
        result = rc_executor._execute_paper_draft(
            stage_dir, run_dir, rc_config, adapters, llm=llm
        )
        assert result.status == StageStatus.FAILED
        draft = (stage_dir / "paper_draft.md").read_text(encoding="utf-8")
        assert "Blocked" in draft or "BLOCKED" in draft or "no metrics" in draft.lower()
        # The LLM must never be invoked when the draft is blocked.
        assert len(llm.calls) == 0

    def test_paper_draft_proceeds_with_metrics(
        self, tmp_path: Path, run_dir: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        _write_prior_artifact(run_dir, 16, "outline.md", "# Outline\n## Abstract\n")
        # Real metrics exist, so drafting should go ahead.
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        payload = {
            "run_id": "run-1",
            "status": "completed",
            "metrics": {"best_loss": 0.123},
            "stdout": "best_loss: 0.123\n",
        }
        (runs_dir / "run-1.json").write_text(json.dumps(payload), encoding="utf-8")
        stage_dir = run_dir / "stage-17"
        stage_dir.mkdir(parents=True, exist_ok=True)
        llm = FakeLLMClient("# Paper Title\n## Abstract\nSome abstract text.")
        rc_executor._execute_paper_draft(
            stage_dir, run_dir, rc_config, adapters, llm=llm
        )
        assert len(llm.calls) >= 1
        # Anti-fabrication instructions must appear in the prompts the LLM saw.
        all_prompts = " ".join(
            msg["content"] for call in llm.calls for msg in call
        )
        assert "Data Integrity" in all_prompts or "ONLY report numbers" in all_prompts
# ── R4-3: Conference-Grade Title Guidelines Tests ────────────────────
class TestTitleGuidelines:
    """Test title_guidelines and abstract_structure blocks (R4-3)."""

    def test_title_guidelines_block_exists(self) -> None:
        from researchclaw.prompts import PromptManager
        block = PromptManager().block("title_guidelines")
        assert "novelty" in block.lower() or "TITLE RULES" in block
        assert "14 words" in block or "15 words" in block or "concrete" in block.lower()

    def test_abstract_structure_block_exists(self) -> None:
        from researchclaw.prompts import PromptManager
        block = PromptManager().block("abstract_structure")
        assert "5-sentence" in block or "problem" in block.lower()

    def test_title_guidelines_injected_into_paper_draft(
        self, tmp_path: Path, run_dir: Path, rc_config: RCConfig, adapters: AdapterBundle
    ) -> None:
        _write_prior_artifact(run_dir, 16, "outline.md", "# Outline\n")
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        run_payload = {
            "run_id": "run-1",
            "status": "completed",
            "metrics": {"best_loss": 0.1},
            "stdout": "best_loss: 0.1\n",
        }
        (runs_dir / "run-1.json").write_text(json.dumps(run_payload), encoding="utf-8")
        stage_dir = run_dir / "stage-17"
        stage_dir.mkdir(parents=True, exist_ok=True)
        llm = FakeLLMClient("# Paper Title\n## Abstract\nText.")
        rc_executor._execute_paper_draft(
            stage_dir, run_dir, rc_config, adapters, llm=llm
        )
        # Title guidance must appear somewhere in the prompts the LLM saw.
        all_prompts = " ".join(
            msg["content"] for call in llm.calls for msg in call
        )
        assert "Title" in all_prompts or "TITLE" in all_prompts
# ── R4-4: Conference-Grade Writing Quality Tests ─────────────────────
class TestConferenceWritingQuality:
    """Test enhanced writing prompts and writing_guide.py (R4-4)."""

    def test_writing_guide_format_all(self) -> None:
        # With no section filter, every tip category is included.
        from researchclaw.writing_guide import format_writing_tips
        result = format_writing_tips()
        assert "Conference Writing Best Practices" in result
        assert "Title" in result
        assert "Common Rejections" in result

    def test_writing_guide_format_subset(self) -> None:
        # Requesting specific sections must exclude the others.
        from researchclaw.writing_guide import format_writing_tips
        result = format_writing_tips(["title", "abstract"])
        assert "Title" in result
        assert "Abstract" in result
        assert "Common Rejections" not in result

    def test_paper_draft_system_includes_principles(self) -> None:
        from researchclaw.prompts import PromptManager
        pm = PromptManager()
        sp = pm.for_stage(
            "paper_draft",
            preamble="test",
            topic_constraint="test",
            exp_metrics_instruction="test",
            citation_instruction="test",
            outline="test",
        )
        # System prompt should mention key principles. The case-insensitive
        # check subsumes the former redundant literal "NOVELTY" clause.
        assert "novelty" in sp.system.lower()
        assert "fabricate" in sp.system.lower() or "real experimental" in sp.system.lower()
# ── R5-1 & R5-2: Bug Fixes Tests ────────────────────────────────────
class TestRefineTimeoutAndIterationCap:
    """Test R5-1 (no 120s cap) and R5-2 (iteration cap raised to 10).

    Cleanup vs. the previous version: removed the unused ``ast.parse`` call
    (its ``tree`` result was never read) and the duplicate
    ``inspect.getsource`` invocation.
    """

    def test_refine_timeout_uses_full_budget(self) -> None:
        """R5-1: Refine sandbox should NOT cap at 120s."""
        import inspect
        source_text = inspect.getsource(rc_executor._execute_iterative_refine)
        # Should NOT contain min(..., 120)
        assert "min(config.experiment.time_budget_sec, 120)" not in source_text

    def test_iteration_cap_is_10(self) -> None:
        """R5-2: Max iterations should be capped at 10, not 3."""
        import inspect
        source = inspect.getsource(rc_executor._execute_iterative_refine)
        assert "min(requested_iterations, 10)" in source
        assert "min(requested_iterations, 3)" not in source

    def test_refine_respects_high_iteration_count(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        """R5-2: Setting max_iterations=7 should actually allow 7 iterations."""
        # Write prior run artifacts
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({"run_id": "run-1", "status": "completed",
                        "metrics": {"best_loss": 0.5}}),
            encoding="utf-8",
        )
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("print('best_loss: 0.5')\n", encoding="utf-8")
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 300,
                "max_iterations": 7,
                "metric_key": "best_loss",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        # LLM always returns same code — will trigger no_improvement early stop
        llm = FakeLLMClient("```python\nprint('best_loss: 0.5')\n```")
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        log = json.loads((stage_dir / "refinement_log.json").read_text(encoding="utf-8"))
        # Should have been allowed more than 3 iterations (capped at 7)
        assert log["max_iterations_executed"] == 7
        # But may have stopped early due to no_improvement_for_2_iterations
        assert len(log["iterations"]) >= 2
# ── R5-3: NaN/Divergence Fast-Fail Tests ────────────────────────────
class TestNaNDivergenceDetection:
    """Test NaN/Inf filtering and divergence detection (R5-3).

    Cleanup vs. the previous version: dropped redundant case-sensitive
    boolean operands (e.g. ``"NaN" in x or "nan" in x.lower()``) that were
    fully subsumed by the case-insensitive ``.lower()`` check.
    """

    def test_parse_metrics_filters_nan(self) -> None:
        from researchclaw.experiment.sandbox import parse_metrics
        stdout = "best_loss: 0.5\nbad_metric: nan\ngood_metric: 1.23\n"
        metrics = parse_metrics(stdout)
        assert "best_loss" in metrics
        assert "good_metric" in metrics
        assert "bad_metric" not in metrics  # NaN should be filtered

    def test_parse_metrics_filters_inf(self) -> None:
        from researchclaw.experiment.sandbox import parse_metrics
        stdout = "metric_a: inf\nmetric_b: -inf\nmetric_c: 0.42\n"
        metrics = parse_metrics(stdout)
        assert "metric_c" in metrics
        assert "metric_a" not in metrics
        assert "metric_b" not in metrics

    def test_detect_nan_divergence_finds_nan(self) -> None:
        from researchclaw.experiment.sandbox import detect_nan_divergence
        result = detect_nan_divergence("loss: nan\nstep 5 done", "")
        assert result is not None
        assert "nan" in result.lower()

    def test_detect_nan_divergence_finds_diverging_loss(self) -> None:
        from researchclaw.experiment.sandbox import detect_nan_divergence
        result = detect_nan_divergence("best_loss: 999.5\n", "")
        assert result is not None
        assert "loss" in result.lower() or "999" in result

    def test_detect_nan_divergence_returns_none_for_clean(self) -> None:
        from researchclaw.experiment.sandbox import detect_nan_divergence
        result = detect_nan_divergence("best_loss: 0.123\naccuracy: 0.95\n", "")
        assert result is None

    def test_runtime_issues_detects_diverging_loss(self) -> None:
        from types import SimpleNamespace
        fake_result = SimpleNamespace(
            metrics={"best_loss": 500.0},
            stdout="best_loss: 500.0\n",
            stderr="",
        )
        issues = rc_executor._detect_runtime_issues(fake_result)
        assert "diverging" in issues.lower()

    def test_compute_budget_includes_nan_guard(self) -> None:
        from researchclaw.prompts import PromptManager
        pm = PromptManager()
        block = pm.block("compute_budget")
        assert "nan" in block.lower() or "divergence" in block.lower()
# ── R5-4: Experiment Harness Template Tests ──────────────────────────
class TestExperimentHarness:
    """Test the immutable experiment harness (R5-4)."""

    def test_harness_should_stop(self) -> None:
        """should_stop() is False right after creation, True past ~80% of budget."""
        from researchclaw.experiment.harness_template import ExperimentHarness
        h = ExperimentHarness(time_budget=1)
        assert not h.should_stop()  # Just created, not at 80% yet
        import time
        time.sleep(0.9)
        assert h.should_stop()  # Should be past 80% of 1s

    def test_harness_report_metric(self, capsys: pytest.CaptureFixture[str]) -> None:
        """report_metric() prints 'name: value' to stdout and records it internally."""
        from researchclaw.experiment.harness_template import ExperimentHarness
        h = ExperimentHarness(time_budget=60)
        h.report_metric("best_loss", 0.123)
        captured = capsys.readouterr()
        assert "best_loss: 0.123" in captured.out
        assert h._metrics["best_loss"] == 0.123

    def test_harness_rejects_nan(self, capsys: pytest.CaptureFixture[str]) -> None:
        """NaN values are not recorded and a warning is written to stderr."""
        from researchclaw.experiment.harness_template import ExperimentHarness
        h = ExperimentHarness(time_budget=60)
        h.report_metric("bad", float("nan"))
        captured = capsys.readouterr()
        assert "bad" not in h._metrics
        assert "non-finite" in captured.err.lower() or "WARNING" in captured.err

    def test_harness_rejects_inf(self, capsys: pytest.CaptureFixture[str]) -> None:
        """Inf values are not recorded either."""
        from researchclaw.experiment.harness_template import ExperimentHarness
        h = ExperimentHarness(time_budget=60)
        h.report_metric("bad", float("inf"))
        assert "bad" not in h._metrics

    def test_harness_finalize(self, tmp_path: Path) -> None:
        """finalize() writes metrics and logged results to ./results.json."""
        import os
        from researchclaw.experiment.harness_template import ExperimentHarness
        old_cwd = os.getcwd()
        # finalize() writes relative to the current working directory,
        # so chdir into tmp_path for the duration of the test.
        os.chdir(tmp_path)
        try:
            h = ExperimentHarness(time_budget=60)
            h.report_metric("accuracy", 0.95)
            h.report_metric("loss", 0.05)
            h.log_result({"condition": "A", "value": 1.0})
            h.finalize()
            results = json.loads((tmp_path / "results.json").read_text(encoding="utf-8"))
            assert results["metrics"]["accuracy"] == 0.95
            assert results["metrics"]["loss"] == 0.05
            assert len(results["results"]) == 1
        finally:
            os.chdir(old_cwd)

    def test_harness_progress(self) -> None:
        """progress reports elapsed fraction of the budget, clamped to [0, 1]."""
        from researchclaw.experiment.harness_template import ExperimentHarness
        h = ExperimentHarness(time_budget=1000)
        assert h.progress < 0.01  # Just started
        assert 0.0 <= h.progress <= 1.0

    def test_harness_injected_into_sandbox(self, tmp_path: Path) -> None:
        """run_project() injects the real experiment_harness.py into the copy."""
        import sys
        from researchclaw.config import SandboxConfig
        from researchclaw.experiment.sandbox import ExperimentSandbox
        config = SandboxConfig(python_path=sys.executable)
        sandbox = ExperimentSandbox(config, tmp_path / "sandbox")
        # Create a project dir
        project = tmp_path / "project"
        project.mkdir()
        (project / "main.py").write_text("print('test: 1.0')\n", encoding="utf-8")
        sandbox.run_project(project, timeout_sec=5)
        # Check that harness was injected (BUG-DA8-06: dir is now _project_{N})
        project_dirs = list((tmp_path / "sandbox").glob("_project_*"))
        assert project_dirs, "No _project_N directory found"
        harness_path = project_dirs[0] / "experiment_harness.py"
        assert harness_path.exists()
        content = harness_path.read_text(encoding="utf-8")
        assert "ExperimentHarness" in content

    def test_harness_not_overwritten_by_project(self, tmp_path: Path) -> None:
        """A project-supplied experiment_harness.py must be replaced by the real one."""
        import sys
        from researchclaw.config import SandboxConfig
        from researchclaw.experiment.sandbox import ExperimentSandbox
        config = SandboxConfig(python_path=sys.executable)
        sandbox = ExperimentSandbox(config, tmp_path / "sandbox")
        # Create a project with a fake experiment_harness.py
        project = tmp_path / "project"
        project.mkdir()
        (project / "main.py").write_text("print('test: 1.0')\n", encoding="utf-8")
        (project / "experiment_harness.py").write_text("# FAKE HARNESS", encoding="utf-8")
        sandbox.run_project(project, timeout_sec=5)
        # The real harness should be there, not the fake one (BUG-DA8-06)
        project_dirs = list((tmp_path / "sandbox").glob("_project_*"))
        assert project_dirs
        harness_path = project_dirs[0] / "experiment_harness.py"
        content = harness_path.read_text(encoding="utf-8")
        assert "ExperimentHarness" in content
        assert "FAKE HARNESS" not in content

    def test_prompt_mentions_harness(self) -> None:
        """The compute_budget prompt block must reference the harness."""
        from researchclaw.prompts import PromptManager
        pm = PromptManager()
        block = pm.block("compute_budget")
        assert "experiment_harness" in block or "ExperimentHarness" in block
# ── R5-5: Stdout Truncation Tests ────────────────────────────────────
class TestStdoutTruncation:
    """Test stdout/stderr truncation in refine run summaries (R5-5)."""

    def test_long_stdout_truncated_in_refine(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        """A 200-line stdout from a prior run must reach the refine LLM truncated."""
        # Create a run with very long stdout
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        long_stdout = "\n".join(f"step {i}: loss={0.5 - i * 0.001:.6f}" for i in range(200))
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "run_id": "run-1",
                "status": "completed",
                "metrics": {"best_loss": 0.3},
                "stdout": long_stdout,
            }),
            encoding="utf-8",
        )
        # Prior experiment code that the refine loop starts from.
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("print('best_loss: 0.3')\n", encoding="utf-8")
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Minimal full config; the experiment section caps refine at 1 iteration.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 1,
                "metric_key": "best_loss",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        llm = FakeLLMClient("```python\nprint('best_loss: 0.3')\n```")
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        # The LLM should have received truncated stdout, not all 200 lines
        assert len(llm.calls) >= 1
        user_msg = llm.calls[0][-1]["content"]
        # Should contain truncation indicator
        assert "truncated" in user_msg or len(user_msg) < len(long_stdout)
# ===================================================================
# R6 Tests — Post-E2E Failure Analysis Fixes
# ===================================================================
class TestNoImproveStreakFix:
    """R6-1: no_improve_streak should only count iterations with real metrics."""

    def test_empty_metrics_dont_increment_streak(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        """When metrics are empty (None), the streak should NOT increment."""
        # Seed a failed prior run whose payload carries no metrics at all.
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "run_id": "run-1",
                "status": "failed",
                "metrics": {},
                "stdout": "FAIL: NaN/divergence detected",
            }),
            encoding="utf-8",
        )
        # Prior experiment code the refine loop starts from.
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("print('hello')\n", encoding="utf-8")
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Minimal full config; max_iterations=4 leaves room for the early abort.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 4,
                "metric_key": "primary_metric",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        # LLM returns code that won't produce metrics in simulated mode
        llm = FakeLLMClient("```python\nprint('no metrics here')\n```")
        result = rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        # Should abort after 3 consecutive no-metrics iterations
        log_path = stage_dir / "refinement_log.json"
        log_data = json.loads(log_path.read_text())
        # consecutive_no_metrics triggers early abort after 3 iterations
        assert len(log_data["iterations"]) == 3
        assert log_data.get("stop_reason") == "consecutive_no_metrics"
class TestStdoutFailureDetection:
    """R6-2: Detect stdout failure signals even when exit code is 0."""

    @staticmethod
    def _run_stage_and_read_payload(tmp_path: Path, script: str) -> dict:
        """Lay out a minimal run tree, execute stage-12 with *script* as
        main.py, and return the recorded run-1.json payload."""
        from researchclaw.pipeline.executor import _execute_experiment_run

        run_dir = tmp_path / "run"
        run_dir.mkdir()
        (run_dir / "stage-10").mkdir()
        exp_dir = run_dir / "stage-10" / "experiment"
        exp_dir.mkdir()
        (exp_dir / "main.py").write_text(script, encoding="utf-8")
        (run_dir / "stage-11").mkdir()
        (run_dir / "stage-11" / "schedule.json").write_text("{}", encoding="utf-8")
        stage_dir = run_dir / "stage-12"
        stage_dir.mkdir()
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 1,
                "metric_key": "primary_metric",
                "metric_direction": "minimize",
                "sandbox": {
                    "python_path": sys.executable,
                    "gpu_required": False,
                    "max_memory_mb": 512,
                    "allowed_imports": ["json"],
                },
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        _execute_experiment_run(stage_dir, run_dir, cfg, AdapterBundle())
        run_file = stage_dir / "runs" / "run-1.json"
        assert run_file.exists()
        return json.loads(run_file.read_text())

    def test_fail_signal_in_stdout_marks_failed(self, tmp_path: Path) -> None:
        """Exit code 0 + 'FAIL:' in stdout + no metrics → status='failed'."""
        payload = self._run_stage_and_read_payload(
            tmp_path, "print('FAIL: NaN/divergence detected')\n"
        )
        assert payload["status"] == "failed"

    def test_clean_exit_no_fail_signal_marks_completed(self, tmp_path: Path) -> None:
        """Exit code 0 + valid metrics + no FAIL signal → status='completed'."""
        payload = self._run_stage_and_read_payload(
            tmp_path, "print('primary_metric: 0.95')\n"
        )
        assert payload["status"] == "completed"
class TestMetricValUndefined:
    """R6-3: metric_val should be initialized to None before conditional block."""

    def test_metric_val_initialized_before_use(self) -> None:
        """Verify the code pattern: metric_val = None before if block."""
        import inspect

        code = inspect.getsource(rc_executor._execute_iterative_refine)
        # Locate both markers and confirm the guard precedes the branch.
        guard_idx = code.find("metric_val = None")
        branch_idx = code.find("if validation.ok and config.experiment.mode")
        assert guard_idx != -1, "metric_val = None not found"
        assert branch_idx != -1, "sandbox block not found"
        assert guard_idx < branch_idx, "metric_val = None should come before sandbox block"
class TestConsecutiveEmptyMetrics:
    """R6-4: Pipeline should detect consecutive empty-metrics REFINE cycles."""

    @staticmethod
    def _write_cycle(run_dir: Path, stage_name: str, best_metrics: dict) -> None:
        """Create *stage_name* under *run_dir* with a summary fixture whose
        best run carries *best_metrics*."""
        stage = run_dir / stage_name
        stage.mkdir(parents=True)
        (stage / "experiment_summary.json").write_text(json.dumps({
            "metrics_summary": {},
            "best_run": {"metrics": best_metrics},
        }))

    def test_detects_consecutive_empty(self, tmp_path: Path) -> None:
        """Two cycles with empty metrics should return True."""
        from researchclaw.pipeline.runner import _consecutive_empty_metrics

        run_dir = tmp_path / "run"
        self._write_cycle(run_dir, "stage-14", {})     # current cycle
        self._write_cycle(run_dir, "stage-14_v1", {})  # previous cycle
        assert _consecutive_empty_metrics(run_dir, pivot_count=1) is True

    def test_not_empty_when_metrics_exist(self, tmp_path: Path) -> None:
        """If any cycle has real metrics, return False."""
        from researchclaw.pipeline.runner import _consecutive_empty_metrics

        run_dir = tmp_path / "run"
        self._write_cycle(run_dir, "stage-14", {"loss": 0.5})
        self._write_cycle(run_dir, "stage-14_v1", {})
        assert _consecutive_empty_metrics(run_dir, pivot_count=1) is False

    def test_false_when_no_previous_cycle(self, tmp_path: Path) -> None:
        """First cycle (no v1) should return False."""
        from researchclaw.pipeline.runner import _consecutive_empty_metrics

        run_dir = tmp_path / "run"
        # Deliberately no stage-14_v1 sibling.
        self._write_cycle(run_dir, "stage-14", {})
        assert _consecutive_empty_metrics(run_dir, pivot_count=1) is False
# ===================================================================
# R7 Tests — Experiment-Paper Quality Alignment
# ===================================================================
class TestMultiConditionEnforcement:
    """R7-1: Code generation prompt must enforce multi-condition experiments."""

    def test_code_generation_prompt_has_multi_condition_block(self) -> None:
        """The code_generation prompt should contain multi-condition instructions."""
        from researchclaw.prompts import PromptManager

        prompt = PromptManager().for_stage(
            "code_generation",
            topic="test topic",
            metric="primary_metric",
            pkg_hint="",
            exp_plan="conditions:\n - echo_chamber\n - bridge_building\n - random",
        )
        for marker in ("MULTI-CONDITION REQUIREMENT", "condition=", "SUMMARY"):
            assert marker in prompt.user

    def test_multi_condition_labels_required(self) -> None:
        """Prompt must mention per-condition labeled output format."""
        from researchclaw.prompts import PromptManager

        prompt = PromptManager().for_stage(
            "code_generation",
            topic="test",
            metric="loss",
            pkg_hint="",
            exp_plan="treatments: [A, B, C]",
        )
        assert "condition=" in prompt.user
class TestEvidenceBoundedWriting:
    """R7-2: Paper draft prompt must enforce evidence-bounded claims."""

    def test_paper_draft_has_evidence_bounding_rules(self) -> None:
        """System prompt should contain evidence-bounding rules."""
        from researchclaw.prompts import PromptManager

        prompt = PromptManager().for_stage(
            "paper_draft",
            preamble="test preamble",
            topic_constraint="",
            exp_metrics_instruction="",
            citation_instruction="",
            outline="# Outline",
        )
        lowered = prompt.system.lower()
        assert "EVIDENCE-BOUNDING RULES" in prompt.system
        assert "title" in lowered
        assert "causal claim" in lowered or "causal claims" in lowered

    def test_hedging_language_guidance(self) -> None:
        """Should suggest hedged alternatives like 'Toward...' for partial data."""
        from researchclaw.prompts import PromptManager

        prompt = PromptManager().for_stage(
            "paper_draft",
            preamble="",
            topic_constraint="",
            exp_metrics_instruction="",
            citation_instruction="",
            outline="",
        )
        assert "Toward" in prompt.system or "Investigating" in prompt.system
class TestConditionCoverageDetection:
    """R7-3: REFINE should detect condition coverage gaps."""

    def test_coverage_hint_injected_when_no_labels(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        """If stdout has no 'condition=' labels, a coverage hint should be injected."""
        # Prior run: metrics present, but stdout carries no condition= labels.
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "run_id": "run-1",
                "status": "completed",
                "metrics": {"primary_metric": 0.5},
                "stdout": "primary_metric: 0.5\nprimary_metric: 0.3\n",
            }),
            encoding="utf-8",
        )
        # Experiment plan declares three conditions the run never labeled.
        exp_plan_dir = run_dir / "stage-09"
        exp_plan_dir.mkdir(parents=True, exist_ok=True)
        (exp_plan_dir / "exp_plan.yaml").write_text(
            "conditions:\n - echo_chamber\n - bridge_building\n - random\n",
            encoding="utf-8",
        )
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("print('primary_metric: 0.5')\n", encoding="utf-8")
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Minimal full config; refine capped at a single iteration.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 1,
                "metric_key": "primary_metric",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        llm = FakeLLMClient("```python\nprint('primary_metric: 0.3')\n```")
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        # The refine prompt sent to the LLM must flag the coverage gap.
        assert len(llm.calls) >= 1
        user_msg = llm.calls[0][-1]["content"]
        assert "CONDITION COVERAGE GAP" in user_msg

    def test_no_hint_when_labels_present(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        """If stdout already has 'condition=' labels, no hint should be injected."""
        # Prior run: stdout labels each metric line with its condition.
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "run_id": "run-1",
                "status": "completed",
                "metrics": {"primary_metric": 0.5},
                "stdout": "condition=echo primary_metric: 0.5\ncondition=bridge primary_metric: 0.3\n",
            }),
            encoding="utf-8",
        )
        exp_plan_dir = run_dir / "stage-09"
        exp_plan_dir.mkdir(parents=True, exist_ok=True)
        (exp_plan_dir / "exp_plan.yaml").write_text(
            "conditions:\n - echo\n - bridge\n",
            encoding="utf-8",
        )
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("print('primary_metric: 0.5')\n", encoding="utf-8")
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Same minimal config as above.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 1,
                "metric_key": "primary_metric",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        llm = FakeLLMClient("```python\nprint('primary_metric: 0.3')\n```")
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        assert len(llm.calls) >= 1
        user_msg = llm.calls[0][-1]["content"]
        assert "CONDITION COVERAGE GAP" not in user_msg
# ===================================================================
# R8 Tests — AutoBench Round 1 Fixes
# ===================================================================
class TestBreadthFirstPrompt:
    """R8-1: Code generation prompt should require breadth-first condition ordering."""

    def test_breadth_first_in_code_generation(self) -> None:
        from researchclaw.prompts import PromptManager

        prompt = PromptManager().for_stage(
            "code_generation",
            topic="test",
            metric="primary_metric",
            pkg_hint="",
            exp_plan="conditions: [A, B, C]",
        )
        # Both the ordering directive and the representative-run wording must appear.
        for marker in ("BREADTH-FIRST", "ONE representative"):
            assert marker in prompt.user
class TestRefineFilePreservation:
    """R8-2: Refine should preserve supporting files when LLM only returns main.py."""

    def test_supporting_files_preserved_in_refine(
        self, tmp_path: Path, run_dir: Path, adapters: AdapterBundle
    ) -> None:
        """When LLM returns only main.py, other project files should be preserved."""
        runs_dir = run_dir / "stage-12" / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)
        (runs_dir / "run-1.json").write_text(
            json.dumps({
                "run_id": "run-1",
                "status": "completed",
                "metrics": {"primary_metric": 0.5},
                "stdout": "primary_metric: 0.5",
            }),
            encoding="utf-8",
        )
        # Multi-file experiment project
        exp_dir = run_dir / "stage-11" / "experiment"
        exp_dir.mkdir(parents=True, exist_ok=True)
        (exp_dir / "main.py").write_text("from helpers import foo\nprint('primary_metric: 0.5')\n")
        (exp_dir / "helpers.py").write_text("def foo(): return 42\n")
        (exp_dir / "utils.py").write_text("def bar(): return 99\n")
        stage_dir = run_dir / "stage-13"
        stage_dir.mkdir(parents=True, exist_ok=True)
        # Minimal full config; refine capped at a single iteration.
        data = {
            "project": {"name": "rc-test", "mode": "docs-first"},
            "research": {"topic": "test", "domains": ["ml"],
                         "daily_paper_count": 2, "quality_threshold": 8.2},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local", "on_stage_start": True,
                              "on_stage_fail": False, "on_gate_required": True},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {"use_memory": True, "use_message": True},
            "llm": {"provider": "openai-compatible", "base_url": "http://localhost:1234/v1",
                    "api_key_env": "RC_TEST_KEY", "api_key": "inline-test-key",
                    "primary_model": "fake-model", "fallback_models": []},
            "security": {"hitl_required_stages": [5, 9, 20]},
            "experiment": {
                "mode": "sandbox",
                "time_budget_sec": 30,
                "max_iterations": 1,
                "metric_key": "primary_metric",
                "metric_direction": "minimize",
            },
        }
        cfg = RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)
        # LLM returns only main.py in multi-file format
        llm = FakeLLMClient("```filename:main.py\nfrom helpers import foo\nprint('primary_metric: 0.3')\n```")
        rc_executor._execute_iterative_refine(
            stage_dir, run_dir, cfg, adapters, llm=llm
        )
        # Check that experiment_v1 has ALL files, not just main.py
        v1_dir = stage_dir / "experiment_v1"
        assert v1_dir.exists()
        v1_files = {f.name for f in v1_dir.glob("*.py")}
        assert "main.py" in v1_files
        assert "helpers.py" in v1_files, "Supporting file helpers.py should be preserved"
        assert "utils.py" in v1_files, "Supporting file utils.py should be preserved"
# ===================================================================
# R9 Tests — AutoBench Round 2 Fixes
# ===================================================================
class TestCodeGenTopicNeutral:
    """R9-1: Code generation prompt should be topic-neutral, not optimization-biased."""

    @staticmethod
    def _render_prompt():
        """Render the code_generation prompt for a non-optimization topic."""
        from researchclaw.prompts import PromptManager

        return PromptManager().for_stage(
            "code_generation",
            topic="multi-agent simulation",
            metric="primary_metric",
            pkg_hint="",
            exp_plan="conditions: [L1, L2, L3, L4]",
        )

    def test_no_gradient_descent_bias(self) -> None:
        prompt = self._render_prompt()
        # Should NOT contain optimization-specific examples as recommended approaches;
        # "gradient descent" may appear as anti-pattern warning but not as example.
        for banned in ("Adam", "SGD", "Rosenbrock", "e.g., gradient descent"):
            assert banned not in prompt.user

    def test_topic_relevant_guidance(self) -> None:
        prompt = self._render_prompt()
        # Should contain generic guidance that works for any topic
        lowered = prompt.user.lower()
        assert "simulation" in lowered or "appropriate" in lowered
        assert "ACTUAL experiment" in prompt.user or "relevant to the TOPIC" in prompt.user
class TestRefineTopicAlignment:
    """R9-2: Refine prompt should include topic-code alignment check."""

    def test_topic_alignment_in_refine_prompt(self) -> None:
        from researchclaw.prompts import PromptManager

        prompt = PromptManager().sub_prompt(
            "iterative_improve",
            metric_key="primary_metric",
            metric_direction="maximize",
            files_context="# main.py\nprint('hello')",
            run_summaries="{}",
            condition_coverage_hint="",
            topic="multi-agent diversity scaling",
            exp_plan_anchor="",
        )
        # Anchor header, the topic itself, and the rename prohibition must all appear.
        expected_markers = (
            "EXPERIMENT PLAN ANCHOR",
            "multi-agent diversity scaling",
            "NEVER rename",
        )
        for marker in expected_markers:
            assert marker in prompt.user
# =====================================================================
# _validate_draft_quality tests
# =====================================================================
def _make_prose(word_count: int) -> str: # noqa: E302
"""Generate flowing prose text of approximately *word_count* words."""
sentence = (
"This is a flowing academic prose sentence "
"that demonstrates our research findings. "
)
words_per = len(sentence.split())
return sentence * (word_count // words_per + 1)
def _make_bullets(word_count: int) -> str:
"""Generate bullet-point text of approximately *word_count* words."""
line = "- This is a bullet point about a research finding\n"
words_per = len(line.split())
return line * (word_count // words_per + 1)
def _make_comparative_prose(word_count: int) -> str:
"""Generate related-work style prose with comparative language."""
sentence = (
"Unlike prior work that focuses on simple baselines, "
"our approach differs by incorporating novel techniques. "
"In contrast to existing methods, we address key limitations. "
"However, while previous approaches rely on heuristics, "
"our method provides theoretical guarantees. "
)
words_per = len(sentence.split())
return sentence * (word_count // words_per + 1)
def _make_results_prose(word_count: int) -> str:
"""Generate results prose with statistical measures."""
sentence = (
"Our method achieves 85.3 ± 1.2 accuracy averaged over 5 seeds. "
"The baseline comparison yields a p-value of 0.003, confirming "
"statistical significance with 95% confidence interval. "
)
words_per = len(sentence.split())
return sentence * (word_count // words_per + 1)
def _build_draft(**section_overrides: str) -> str:
    """Build a paper draft with default prose sections.

    Keyword arguments replace the default body of the matching section
    (e.g. ``Method=_make_bullets(1200)``).
    """
    sections = {
        "Abstract": _make_prose(200),
        "Introduction": _make_prose(900),
        "Related Work": _make_comparative_prose(700),
        "Method": _make_prose(1200),
        "Experiments": _make_prose(1000),
        "Results": _make_results_prose(700),
        "Discussion": _make_prose(500),
        "Limitations": _make_prose(250),
        "Conclusion": _make_prose(250),
    }
    sections.update(section_overrides)
    pieces = ["# My Research Title\n"]
    pieces.extend(f"# {heading}\n{body}\n" for heading, body in sections.items())
    return "\n".join(pieces)
class TestValidateDraftQuality:
    """Tests for _validate_draft_quality()."""

    def test_short_section_triggers_warning(self) -> None:
        """Short Method section triggers expand warning."""
        report = rc_executor._validate_draft_quality(
            _build_draft(Method=_make_prose(200))
        )
        assert any("Method" in warning for warning in report["overall_warnings"])
        assert any(
            "EXPAND" in directive or "Expand" in directive
            for directive in report["revision_directives"]
        )

    def test_bullet_density_triggers_warning(self) -> None:
        """Bullet-heavy Method section triggers rewrite warning."""
        report = rc_executor._validate_draft_quality(
            _build_draft(Method=_make_bullets(1200))
        )
        assert any(
            "bullet" in warning.lower() or "density" in warning.lower()
            for warning in report["overall_warnings"]
        )
        assert any("REWRITE" in directive for directive in report["revision_directives"])

    def test_clean_draft_no_warnings(self) -> None:
        """Balanced prose draft produces zero warnings."""
        report = rc_executor._validate_draft_quality(_build_draft())
        assert not report["overall_warnings"]
        assert not report["revision_directives"]

    def test_balance_warning(self) -> None:
        """Large imbalance between sections triggers balance warning."""
        report = rc_executor._validate_draft_quality(_build_draft(
            Introduction=_make_prose(1500),
            Results=_make_prose(100),
        ))
        imbalance_hits = [
            warning for warning in report["overall_warnings"]
            if "imbalance" in warning.lower()
        ]
        assert len(imbalance_hits) >= 1, (
            f"Expected balance warning, got: {report['overall_warnings']}"
        )

    def test_writes_json_to_stage_dir(self, tmp_path: Path) -> None:
        """Quality report is written as draft_quality.json."""
        rc_executor._validate_draft_quality(
            _build_draft(Method=_make_prose(200)), stage_dir=tmp_path
        )
        report_path = tmp_path / "draft_quality.json"
        assert report_path.exists()
        payload = json.loads(report_path.read_text())
        for key in ("section_analysis", "overall_warnings", "revision_directives"):
            assert key in payload
================================================
FILE: tests/test_rc_hardware.py
================================================
"""Tests for researchclaw.hardware — GPU detection & metric filtering."""
from __future__ import annotations
import subprocess
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.hardware import (
HardwareProfile,
_detect_mps,
_detect_nvidia,
detect_hardware,
ensure_torch_available,
is_metric_name,
)
# ---------------------------------------------------------------------------
# HardwareProfile
# ---------------------------------------------------------------------------
class TestHardwareProfile:
    """Behavior of the HardwareProfile value object."""

    def test_to_dict(self):
        """to_dict() exposes the GPU fields verbatim."""
        profile = HardwareProfile(
            has_gpu=True, gpu_type="cuda", gpu_name="RTX 4090",
            vram_mb=24564, tier="high", warning="",
        )
        as_dict = profile.to_dict()
        assert as_dict["has_gpu"] is True
        assert as_dict["gpu_type"] == "cuda"
        assert as_dict["vram_mb"] == 24564

    def test_cpu_only_profile(self):
        """A CPU-only profile keeps its tier and warning text."""
        profile = HardwareProfile(
            has_gpu=False, gpu_type="cpu", gpu_name="CPU only",
            vram_mb=None, tier="cpu_only", warning="No GPU",
        )
        assert profile.tier == "cpu_only"
        assert profile.warning == "No GPU"
# ---------------------------------------------------------------------------
# NVIDIA detection
# ---------------------------------------------------------------------------
class TestDetectNvidia:
    """nvidia-smi based GPU detection paths."""

    @staticmethod
    def _smi(stdout: str, returncode: int = 0) -> MagicMock:
        # Fake a completed nvidia-smi subprocess invocation.
        proc = MagicMock()
        proc.returncode = returncode
        proc.stdout = stdout
        return proc

    def test_high_vram_nvidia(self):
        fake = self._smi("NVIDIA GeForce RTX 4090, 24564\n")
        with patch("researchclaw.hardware.subprocess.run", return_value=fake):
            profile = _detect_nvidia()
        assert profile is not None
        assert profile.has_gpu is True
        assert profile.gpu_type == "cuda"
        assert profile.gpu_name == "NVIDIA GeForce RTX 4090"
        assert profile.vram_mb == 24564
        assert profile.tier == "high"
        assert profile.warning == ""

    def test_low_vram_nvidia(self):
        fake = self._smi("NVIDIA GeForce GTX 1650, 4096\n")
        with patch("researchclaw.hardware.subprocess.run", return_value=fake):
            profile = _detect_nvidia()
        assert profile is not None
        assert profile.tier == "limited"
        assert "limited memory" in profile.warning

    def test_nvidia_smi_not_found(self):
        with patch(
            "researchclaw.hardware.subprocess.run",
            side_effect=FileNotFoundError,
        ):
            assert _detect_nvidia() is None

    def test_nvidia_smi_failure(self):
        with patch(
            "researchclaw.hardware.subprocess.run",
            return_value=self._smi("", returncode=1),
        ):
            assert _detect_nvidia() is None

    def test_nvidia_smi_timeout(self):
        timeout = subprocess.TimeoutExpired("nvidia-smi", 10)
        with patch("researchclaw.hardware.subprocess.run", side_effect=timeout):
            assert _detect_nvidia() is None
# ---------------------------------------------------------------------------
# MPS detection
# ---------------------------------------------------------------------------
class TestDetectMPS:
    """Apple-Silicon MPS detection paths."""

    def test_apple_silicon(self):
        smi = MagicMock()
        smi.returncode = 0
        smi.stdout = "Apple M3 Pro\n"
        with (
            patch("researchclaw.hardware.platform.system", return_value="Darwin"),
            patch("researchclaw.hardware.platform.machine", return_value="arm64"),
            patch("researchclaw.hardware.subprocess.run", return_value=smi),
        ):
            profile = _detect_mps()
        assert profile is not None
        assert profile.has_gpu is True
        assert profile.gpu_type == "mps"
        assert profile.gpu_name == "Apple M3 Pro"
        assert profile.tier == "limited"
        assert "MPS" in profile.warning

    def test_non_darwin(self):
        # Detection bails out immediately on non-macOS systems.
        with patch("researchclaw.hardware.platform.system", return_value="Linux"):
            assert _detect_mps() is None

    def test_intel_mac(self):
        # Intel Macs (x86_64) have no MPS device.
        with (
            patch("researchclaw.hardware.platform.system", return_value="Darwin"),
            patch("researchclaw.hardware.platform.machine", return_value="x86_64"),
        ):
            assert _detect_mps() is None
# ---------------------------------------------------------------------------
# detect_hardware (integration)
# ---------------------------------------------------------------------------
class TestDetectHardware:
    """detect_hardware() fallback ordering across backends."""

    def test_falls_back_to_cpu(self):
        with (
            patch("researchclaw.hardware._detect_nvidia", return_value=None),
            patch("researchclaw.hardware._detect_mps", return_value=None),
        ):
            profile = detect_hardware()
        assert profile.has_gpu is False
        assert profile.gpu_type == "cpu"
        assert profile.tier == "cpu_only"
        assert "No GPU" in profile.warning

    def test_nvidia_takes_priority(self):
        cuda = HardwareProfile(
            has_gpu=True, gpu_type="cuda", gpu_name="RTX 4090",
            vram_mb=24564, tier="high", warning="",
        )
        mps = HardwareProfile(
            has_gpu=True, gpu_type="mps", gpu_name="M3",
            vram_mb=None, tier="limited", warning="MPS",
        )
        with (
            patch("researchclaw.hardware._detect_nvidia", return_value=cuda),
            patch("researchclaw.hardware._detect_mps", return_value=mps),
        ):
            # CUDA wins when both detectors report a device.
            assert detect_hardware().gpu_type == "cuda"
# ---------------------------------------------------------------------------
# ensure_torch_available
# ---------------------------------------------------------------------------
class TestEnsureTorchAvailable:
    """Torch presence probe and conditional installation."""

    def test_already_installed(self):
        probe = MagicMock()
        probe.returncode = 0
        probe.stdout = "2.3.0\n"
        with patch("researchclaw.hardware.subprocess.run", return_value=probe):
            assert ensure_torch_available("/usr/bin/python3", "cuda") is True

    def test_cpu_only_skips_install(self):
        probe = MagicMock()
        probe.returncode = 1  # import probe fails: torch absent
        probe.stdout = ""
        with patch("researchclaw.hardware.subprocess.run", return_value=probe):
            assert ensure_torch_available("/usr/bin/python3", "cpu") is False

    def test_install_succeeds(self):
        seen: list[int] = []

        def fake_run(*_args, **_kwargs):
            seen.append(1)
            proc = MagicMock()
            proc.stdout = ""
            # First call: import probe fails; subsequent calls succeed.
            proc.returncode = 1 if len(seen) == 1 else 0
            return proc

        with patch("researchclaw.hardware.subprocess.run", side_effect=fake_run):
            assert ensure_torch_available("/usr/bin/python3", "cuda") is True

    def test_install_fails(self):
        failed = MagicMock()
        failed.returncode = 1
        failed.stdout = ""
        failed.stderr = "ERROR: Could not install"
        with patch("researchclaw.hardware.subprocess.run", return_value=failed):
            assert ensure_torch_available("/usr/bin/python3", "mps") is False

    def test_python_not_found(self):
        with patch(
            "researchclaw.hardware.subprocess.run",
            side_effect=FileNotFoundError,
        ):
            assert ensure_torch_available("/nonexistent/python3", "cuda") is False
# ---------------------------------------------------------------------------
# is_metric_name
# ---------------------------------------------------------------------------
class TestIsMetricName:
    """Distinguishing metric names from stray log lines."""

    def test_valid_metrics(self):
        for name in (
            "loss",
            "primary_metric",
            "UCB (Stochastic) cumulative_regret",
            "accuracy",
            "f1_score",
        ):
            assert is_metric_name(name) is True

    def test_log_lines_filtered(self):
        for line in (
            "Running experiments for support set size",
            "Loading model weights",
            "Training epoch 5",
            "Evaluating on test set",
            "Processing batch",
            "Initializing optimizer",
        ):
            assert is_metric_name(line) is False

    def test_too_many_words_filtered(self):
        assert is_metric_name("this is a very long name that has many words") is False

    def test_short_names_pass(self):
        assert is_metric_name("val_loss") is True
        assert is_metric_name("test accuracy score") is True
================================================
FILE: tests/test_rc_health.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnusedCallResult=false, reportAttributeAccessIssue=false, reportUnknownLambdaType=false, reportMissingImports=false, reportUntypedNamedTuple=false, reportMissingTypeArgument=false, reportArgumentType=false
from __future__ import annotations
import json
import socket
import urllib.error
from pathlib import Path
from typing import NamedTuple, cast
from unittest.mock import patch
import pytest
from researchclaw import health
class _VersionInfo(NamedTuple):
major: int
minor: int
micro: int
releaselevel: str
serial: int
class _DummyHTTPResponse:
status: int
_payload: dict[str, object]
def __init__(
self, *, status: int = 200, payload: dict[str, object] | None = None
) -> None:
self.status = status
self._payload = payload if payload is not None else {}
def read(self) -> bytes:
return json.dumps(self._payload).encode("utf-8")
def __enter__(self) -> _DummyHTTPResponse:
return self
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
return None
def _write_valid_config(path: Path) -> None:
_ = path.write_text(
"""
project:
name: demo
research:
topic: Doctor checks
runtime:
timezone: UTC
notifications:
channel: test
knowledge_base:
root: kb
llm:
base_url: https://api.example.com/v1
api_key_env: OPENAI_API_KEY
""".strip()
+ "\n",
encoding="utf-8",
)
def test_check_python_version_pass() -> None:
    """Python 3.11 satisfies the minimum-version check."""
    with patch("sys.version_info", _VersionInfo(3, 11, 0, "final", 0)):
        assert health.check_python_version().status == "pass"

def test_check_python_version_fail() -> None:
    """Python 3.10 is rejected with an actionable fix."""
    with patch("sys.version_info", _VersionInfo(3, 10, 9, "final", 0)):
        outcome = health.check_python_version()
    assert outcome.status == "fail"
    assert outcome.fix == "Install Python 3.11 or newer"

def test_check_yaml_import_pass() -> None:
    """An importable yaml module passes the check."""
    with patch("importlib.import_module", return_value=object()):
        assert health.check_yaml_import().status == "pass"

def test_check_yaml_import_fail() -> None:
    """A missing yaml module fails with an install hint."""
    with patch("importlib.import_module", side_effect=ImportError):
        outcome = health.check_yaml_import()
    assert outcome.status == "fail"
    assert outcome.fix == "pip install pyyaml"
def test_check_config_valid_pass(tmp_path: Path) -> None:
    """A complete config file passes validation."""
    cfg = tmp_path / "config.yaml"
    _write_valid_config(cfg)
    assert health.check_config_valid(cfg).status == "pass"

def test_check_config_invalid(tmp_path: Path) -> None:
    """A config missing required fields fails with a named field."""
    cfg = tmp_path / "config.yaml"
    _ = cfg.write_text("project: {}\n", encoding="utf-8")
    outcome = health.check_config_valid(cfg)
    assert outcome.status == "fail"
    assert "Missing required field:" in outcome.detail

def test_check_config_missing_file(tmp_path: Path) -> None:
    """A nonexistent config path fails with a not-found message."""
    outcome = health.check_config_valid(tmp_path / "missing.yaml")
    assert outcome.status == "fail"
    assert "Config file not found" in outcome.detail
def test_check_llm_connectivity_pass() -> None:
    """A 200 response marks the endpoint reachable."""
    ok = _DummyHTTPResponse(status=200)
    with patch("urllib.request.urlopen", return_value=ok):
        outcome = health.check_llm_connectivity("https://api.example.com/v1")
    assert outcome.status == "pass"

def test_check_llm_connectivity_timeout() -> None:
    """A socket timeout is reported as an unreachable endpoint."""
    err = urllib.error.URLError(socket.timeout("timed out"))
    with patch("urllib.request.urlopen", side_effect=err):
        outcome = health.check_llm_connectivity("https://api.example.com/v1")
    assert outcome.status == "fail"
    assert outcome.detail == "LLM endpoint unreachable"

def test_check_llm_connectivity_http_error() -> None:
    """An HTTP error surfaces its status code in the detail."""
    err = urllib.error.HTTPError(
        "https://api.example.com/v1/models", 503, "unavailable", {}, None
    )
    with patch("urllib.request.urlopen", side_effect=err):
        outcome = health.check_llm_connectivity("https://api.example.com/v1")
    assert outcome.status == "fail"
    assert "503" in outcome.detail
def test_check_api_key_valid() -> None:
    """A 200 model listing confirms the key works."""
    ok = _DummyHTTPResponse(status=200, payload={"data": []})
    with patch("urllib.request.urlopen", return_value=ok):
        outcome = health.check_api_key_valid("https://api.example.com/v1", "sk-test")
    assert outcome.status == "pass"

def test_check_api_key_invalid_401() -> None:
    """A 401 response is reported as an invalid key."""
    err = urllib.error.HTTPError(
        "https://api.example.com/v1/models", 401, "unauthorized", {}, None
    )
    with patch("urllib.request.urlopen", side_effect=err):
        outcome = health.check_api_key_valid("https://api.example.com/v1", "bad")
    assert outcome.status == "fail"
    assert outcome.detail == "Invalid API key"
def test_check_model_available_pass() -> None:
    """A model present in the listing passes the check."""
    listing = {"data": [{"id": "gpt-5.2"}, {"id": "gpt-4o"}]}
    with patch(
        "urllib.request.urlopen",
        return_value=_DummyHTTPResponse(status=200, payload=listing),
    ):
        outcome = health.check_model_available(
            "https://api.example.com/v1", "sk-test", "gpt-5.2"
        )
    assert outcome.status == "pass"

def test_check_model_not_available() -> None:
    """A model absent from the listing fails with its name in the detail."""
    listing = {"data": [{"id": "gpt-4o"}]}
    with patch(
        "urllib.request.urlopen",
        return_value=_DummyHTTPResponse(status=200, payload=listing),
    ):
        outcome = health.check_model_available(
            "https://api.example.com/v1", "sk-test", "gpt-5.2"
        )
    assert outcome.status == "fail"
    assert outcome.detail == "Model gpt-5.2 not available"
def test_check_model_chain_all_available() -> None:
    """Primary and fallback models both present yields a clean pass."""
    listing = {"data": [{"id": "gpt-4o"}, {"id": "gpt-4.1"}]}
    with patch(
        "urllib.request.urlopen",
        return_value=_DummyHTTPResponse(status=200, payload=listing),
    ):
        outcome = health.check_model_chain(
            "https://api.example.com/v1", "sk-test", "gpt-4o", ("gpt-4.1",)
        )
    assert outcome.status == "pass"
    assert "All models available" in outcome.detail

def test_check_model_chain_primary_missing_fallback_ok() -> None:
    """A missing primary still passes when a fallback is available."""
    listing = {"data": [{"id": "gpt-4.1"}, {"id": "gpt-4o-mini"}]}
    with patch(
        "urllib.request.urlopen",
        return_value=_DummyHTTPResponse(status=200, payload=listing),
    ):
        outcome = health.check_model_chain(
            "https://api.example.com/v1", "sk-test",
            "gpt-5.2", ("gpt-4.1", "gpt-4o-mini")
        )
    assert outcome.status == "pass"
    assert "unavailable" in outcome.detail
    assert "gpt-5.2" in outcome.detail

def test_check_model_chain_all_missing() -> None:
    """No configured model available at all is a failure."""
    listing = {"data": [{"id": "gpt-4o"}]}
    with patch(
        "urllib.request.urlopen",
        return_value=_DummyHTTPResponse(status=200, payload=listing),
    ):
        outcome = health.check_model_chain(
            "https://api.example.com/v1", "sk-test", "gpt-5.2", ("gpt-5.1",)
        )
    assert outcome.status == "fail"
    assert "No models available" in outcome.detail

def test_check_model_chain_no_models() -> None:
    """An empty model chain warns instead of failing."""
    outcome = health.check_model_chain(
        "https://api.example.com/v1", "sk-test", "", ()
    )
    assert outcome.status == "warn"
    assert "No models configured" in outcome.detail
def test_check_sandbox_python_exists() -> None:
    """An existing, executable sandbox interpreter passes."""
    with (
        patch.object(Path, "exists", return_value=True),
        patch("os.access", return_value=True),
    ):
        assert health.check_sandbox_python(".venv_arc/bin/python3").status == "pass"

def test_check_sandbox_python_missing() -> None:
    """A missing sandbox interpreter only warns."""
    with (
        patch.object(Path, "exists", return_value=False),
        patch("os.access", return_value=False),
    ):
        assert health.check_sandbox_python(".venv_arc/bin/python3").status == "warn"

def test_check_matplotlib_available() -> None:
    """An importable matplotlib passes the check."""
    with patch("importlib.import_module", return_value=object()):
        assert health.check_matplotlib().status == "pass"

def test_check_matplotlib_missing() -> None:
    """A missing matplotlib warns that charts will be skipped."""
    with patch("importlib.import_module", side_effect=ImportError):
        outcome = health.check_matplotlib()
    assert outcome.status == "warn"
    assert outcome.detail == "Not installed; charts will be skipped"
def test_check_experiment_mode_simulated() -> None:
    """Simulated mode is flagged with a warning."""
    assert health.check_experiment_mode("simulated").status == "warn"

def test_check_experiment_mode_sandbox() -> None:
    """Sandbox mode passes cleanly."""
    assert health.check_experiment_mode("sandbox").status == "pass"
def test_run_doctor_all_pass_openai(tmp_path: Path) -> None:
    """All nine checks passing yields an overall 'pass' report."""
    config_path = tmp_path / "config.yaml"
    _ = config_path.write_text("project: {}\n", encoding="utf-8")

    def passing(check_name: str):
        # Stub check that ignores its arguments and reports a pass.
        return lambda *args, **kwargs: health.CheckResult(check_name, "pass", "ok")

    stubs = {
        f"check_{name}": passing(name)
        for name in (
            "python_version", "yaml_import", "config_valid",
            "llm_connectivity", "api_key_valid", "model_chain",
            "sandbox_python", "matplotlib", "experiment_mode",
        )
    }
    with patch.multiple(health, **stubs):
        report = health.run_doctor(config_path)
    assert report.overall == "pass"
    assert len(report.checks) == 9
def test_run_doctor_with_failures(tmp_path: Path) -> None:
    """A failing check drives overall to 'fail' and surfaces its fix."""
    config_path = tmp_path / "config.yaml"
    _ = config_path.write_text("project: {}\n", encoding="utf-8")

    results = {
        "check_python_version": health.CheckResult("python_version", "pass", "ok"),
        "check_yaml_import": health.CheckResult("yaml_import", "pass", "ok"),
        "check_config_valid": health.CheckResult("config_valid", "fail", "bad", "fix it"),
        "check_llm_connectivity": health.CheckResult("llm_connectivity", "pass", "ok"),
        "check_api_key_valid": health.CheckResult("api_key_valid", "warn", "warn", "later"),
        "check_model_chain": health.CheckResult("model_chain", "pass", "ok"),
        "check_sandbox_python": health.CheckResult("sandbox_python", "pass", "ok"),
        "check_matplotlib": health.CheckResult("matplotlib", "pass", "ok"),
        "check_experiment_mode": health.CheckResult("experiment_mode", "pass", "ok"),
    }
    # Bind each canned result into its own zero-arg-agnostic stub.
    stubs = {
        name: (lambda res: lambda *a, **k: res)(res)
        for name, res in results.items()
    }
    with patch.multiple(health, **stubs):
        report = health.run_doctor(config_path)
    assert report.overall == "fail"
    assert "fix it" in report.actionable_fixes
def test_doctor_report_json_structure(tmp_path: Path) -> None:
    """write_doctor_report emits timestamp, checks, and actionable fixes as JSON."""
    checks = [
        health.CheckResult("python_version", "pass", "ok"),
        health.CheckResult(
            "matplotlib", "warn", "missing", "pip install matplotlib"
        ),
    ]
    report = health.DoctorReport(
        timestamp="2026-01-01T00:00:00+00:00", checks=checks, overall="pass"
    )
    output_path = tmp_path / "reports" / "doctor.json"
    health.write_doctor_report(report, output_path)
    raw = cast(dict[str, object], json.loads(output_path.read_text(encoding="utf-8")))
    assert raw["timestamp"] == "2026-01-01T00:00:00+00:00"
    assert raw["overall"] == "pass"
    assert isinstance(raw["checks"], list)
    assert raw["actionable_fixes"] == ["pip install matplotlib"]

def test_doctor_report_overall_logic() -> None:
    """actionable_fixes collects fixes only from non-passing checks."""
    passing = health.DoctorReport(
        timestamp="2026-01-01T00:00:00+00:00",
        checks=[health.CheckResult("x", "pass", "ok")],
        overall="pass",
    )
    failing = health.DoctorReport(
        timestamp="2026-01-01T00:00:00+00:00",
        checks=[health.CheckResult("x", "fail", "bad", "fix")],
        overall="fail",
    )
    assert passing.overall == "pass"
    assert failing.overall == "fail"
    assert failing.actionable_fixes == ["fix"]
def test_print_doctor_report_pass(capsys: pytest.CaptureFixture[str]) -> None:
    """A passing report prints the check emoji and a PASS summary."""
    report = health.DoctorReport(
        timestamp="2026-01-01T00:00:00+00:00",
        checks=[health.CheckResult("python_version", "pass", "ok")],
        overall="pass",
    )
    health.print_doctor_report(report)
    printed = capsys.readouterr().out
    assert "✅" in printed
    assert "Result: PASS" in printed

def test_print_doctor_report_fail(capsys: pytest.CaptureFixture[str]) -> None:
    """A failing report prints per-status emoji and an error/warning tally."""
    report = health.DoctorReport(
        timestamp="2026-01-01T00:00:00+00:00",
        checks=[
            health.CheckResult("config_valid", "fail", "bad config", "fix config"),
            health.CheckResult(
                "matplotlib", "warn", "missing", "pip install matplotlib"
            ),
        ],
        overall="fail",
    )
    health.print_doctor_report(report)
    printed = capsys.readouterr().out
    for marker in ("❌", "⚠️", "Result: FAIL (1 errors, 1 warnings)"):
        assert marker in printed
# --- ACP agent checks ---
def test_check_acp_agent_found() -> None:
    """An agent binary on PATH passes with its resolved location."""
    with patch("shutil.which", return_value="/usr/local/bin/claude"):
        outcome = health.check_acp_agent("claude")
    assert outcome.status == "pass"
    assert "/usr/local/bin/claude" in outcome.detail

def test_check_acp_agent_missing() -> None:
    """A missing agent binary fails with an install hint."""
    with patch("shutil.which", return_value=None):
        outcome = health.check_acp_agent("claude")
    assert outcome.status == "fail"
    assert "'claude' not found" in outcome.detail
    assert "Install claude" in outcome.fix
def _write_acp_config(path: Path) -> None:
_ = path.write_text(
"""\
project:
name: demo
research:
topic: ACP test
runtime:
timezone: UTC
notifications:
channel: test
knowledge_base:
root: kb
llm:
provider: acp
acp:
agent: claude
""",
encoding="utf-8",
)
def test_run_doctor_acp_skips_http_checks(tmp_path: Path) -> None:
    """With provider=acp the HTTP-based LLM checks are not run."""
    config_path = tmp_path / "config.yaml"
    _write_acp_config(config_path)

    def passing(check_name: str):
        # Stub check reporting a pass regardless of arguments.
        return lambda *args, **kwargs: health.CheckResult(check_name, "pass", "ok")

    stubs = {
        f"check_{name}": passing(name)
        for name in (
            "python_version", "yaml_import", "config_valid", "acp_agent",
            "sandbox_python", "matplotlib", "experiment_mode",
        )
    }
    with patch.multiple(health, **stubs):
        report = health.run_doctor(config_path)
    executed = [check.name for check in report.checks]
    assert "llm_connectivity" not in executed
    assert "api_key_valid" not in executed
    assert "model_chain" not in executed
def test_run_doctor_acp_includes_agent_check(tmp_path: Path) -> None:
    """The ACP provider path runs exactly the seven non-HTTP checks."""
    config_path = tmp_path / "config.yaml"
    _write_acp_config(config_path)

    def passing(check_name: str):
        # Stub check reporting a pass regardless of arguments.
        return lambda *args, **kwargs: health.CheckResult(check_name, "pass", "ok")

    stubs = {
        f"check_{name}": passing(name)
        for name in (
            "python_version", "yaml_import", "config_valid", "acp_agent",
            "sandbox_python", "matplotlib", "experiment_mode",
        )
    }
    with patch.multiple(health, **stubs):
        report = health.run_doctor(config_path)
    names = [check.name for check in report.checks]
    assert "acp_agent" in names
    assert report.overall == "pass"
    assert len(report.checks) == 7
def test_print_doctor_report_ascii_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """On an ASCII-only stdout the report falls back to [OK]-style markers."""
    report = health.DoctorReport(
        timestamp="2026-01-01T00:00:00+00:00",
        checks=[health.CheckResult("python_version", "pass", "ok")],
        overall="pass",
    )

    class _AsciiStdout:
        # Pretends to be a stream that cannot encode non-ASCII characters.
        encoding = "ascii"

        def __init__(self) -> None:
            self.parts: list[str] = []

        def write(self, text: str) -> int:
            text.encode(self.encoding)  # raises UnicodeEncodeError on emoji
            self.parts.append(text)
            return len(text)

        def flush(self) -> None:
            return None

    stream = _AsciiStdout()
    monkeypatch.setattr(health.sys, "stdout", stream)
    health.print_doctor_report(report)
    printed = "".join(stream.parts)
    assert "[OK] python_version: ok" in printed
    assert "Result: PASS" in printed
================================================
FILE: tests/test_rc_kb.py
================================================
from __future__ import annotations
import json
from pathlib import Path
import yaml
from researchclaw.knowledge.base import (
KB_CATEGORY_MAP,
KBEntry,
_markdown_frontmatter,
_obsidian_enhancements,
generate_weekly_report,
write_kb_entry,
write_stage_to_kb,
)
def _kb_root(tmp_path: Path) -> Path:
return tmp_path / "kb"
def test_kb_entry_dataclass_creation():
    """KBEntry keeps its constructor fields verbatim."""
    entry = KBEntry(
        category="findings",
        entry_id="e1",
        title="T",
        content="C",
        source_stage="01-goal_define",
        run_id="run1",
    )
    assert entry.category == "findings"
    assert entry.entry_id == "e1"
    assert entry.run_id == "run1"
def test_write_kb_entry_creates_expected_file_path(tmp_path: Path):
    """The entry lands at <kb>/<category>/<entry_id>.md."""
    root = _kb_root(tmp_path)
    entry = KBEntry("questions", "q-1", "Q", "Body", "01-goal_define", "run-a")
    written = write_kb_entry(root, entry)
    assert written == root / "questions" / "q-1.md"
    assert written.exists()

def test_write_kb_entry_includes_frontmatter_markers(tmp_path: Path):
    """The written file is wrapped in YAML frontmatter fences."""
    entry = KBEntry("findings", "f-1", "Finding", "Body", "14-result_analysis", "run-a")
    body = write_kb_entry(_kb_root(tmp_path), entry).read_text(encoding="utf-8")
    assert body.startswith("---\n")
    assert "\n---\n" in body
def test_write_kb_entry_markdown_backend_has_no_obsidian_extras(tmp_path: Path):
    """The plain markdown backend omits wikilinks and hash-tags."""
    entry = KBEntry(
        "questions",
        "q-2",
        "Question",
        "Body",
        "01-goal_define",
        "run-a",
        tags=["hypothesis"],
        links=["run-run-a"],
    )
    rendered = write_kb_entry(
        _kb_root(tmp_path), entry, backend="markdown"
    ).read_text(encoding="utf-8")
    assert "[[run-run-a]]" not in rendered
    assert "#hypothesis" not in rendered

def test_write_kb_entry_obsidian_backend_includes_tags_and_wikilinks(tmp_path: Path):
    """The obsidian backend renders hash-tags and Related wikilinks."""
    entry = KBEntry(
        "questions",
        "q-3",
        "Question",
        "Body",
        "01-goal_define",
        "run-a",
        tags=["hypothesis", "q1"],
        links=["run-run-a", "topic-a"],
    )
    rendered = write_kb_entry(
        _kb_root(tmp_path), entry, backend="obsidian"
    ).read_text(encoding="utf-8")
    assert "#hypothesis #q1" in rendered
    assert "Related: [[run-run-a]], [[topic-a]]" in rendered
def test_markdown_frontmatter_output_format_and_fields():
    """Frontmatter is fenced YAML whose fields mirror the entry."""
    entry = KBEntry(
        "reviews",
        "r-1",
        "Report",
        "Body",
        "report",
        "run-x",
        tags=["weekly"],
        evidence_refs=["stage-01/goal.md"],
    )
    fm = _markdown_frontmatter(entry)
    assert fm.startswith("---\n")
    assert fm.endswith("\n---\n")
    inner = fm.split("---\n", 1)[1].rsplit("\n---\n", 1)[0]
    parsed = yaml.safe_load(inner)
    expected = {
        "id": "r-1",
        "title": "Report",
        "stage": "report",
        "run_id": "run-x",
        "tags": ["weekly"],
        "evidence": ["stage-01/goal.md"],
    }
    for key, value in expected.items():
        assert parsed[key] == value
def test_obsidian_enhancements_with_tags_and_links():
    """Tags render as hash-tags and links as a Related wikilink list."""
    entry = KBEntry(
        "findings",
        "f-2",
        "Finding",
        "Body",
        "14-result_analysis",
        "run-z",
        tags=["a", "b"],
        links=["run-z", "result-node"],
    )
    extras = _obsidian_enhancements(entry)
    assert "#a #b" in extras
    assert "Related: [[run-z]], [[result-node]]" in extras

def test_obsidian_enhancements_with_no_tags_or_links_returns_empty():
    """An entry without tags or links produces no extra markup."""
    bare = KBEntry("findings", "f-3", "Finding", "Body", "14-result_analysis", "run-z")
    assert _obsidian_enhancements(bare) == ""
def test_kb_category_map_has_exactly_22_stage_entries():
    """The stage→category map covers stages 1 through 22 exactly."""
    assert len(KB_CATEGORY_MAP) == 22
    assert set(KB_CATEGORY_MAP) == set(range(1, 23))

def test_kb_category_map_values_are_valid_categories():
    """Every mapped category is one of the six known KB categories."""
    allowed = {
        "questions",
        "literature",
        "experiments",
        "findings",
        "decisions",
        "reviews",
    }
    assert not set(KB_CATEGORY_MAP.values()) - allowed
def test_write_stage_to_kb_places_entry_in_mapped_category(tmp_path: Path):
    """Stage 10 entries go into the 'experiments' category directory."""
    stage_dir = tmp_path / "stage-10"
    stage_dir.mkdir()
    (stage_dir / "run.md").write_text("exp content", encoding="utf-8")
    written = write_stage_to_kb(
        _kb_root(tmp_path), 10, "experiment_cycle", "run-1", ["run.md"], stage_dir
    )
    assert len(written) == 1
    assert written[0].parent.name == "experiments"

def test_write_stage_to_kb_reads_artifact_file_contents(tmp_path: Path):
    """Artifact file contents and their stage-relative path are embedded."""
    stage_dir = tmp_path / "stage-04"
    stage_dir.mkdir()
    (stage_dir / "lit.md").write_text("paper A\npaper B", encoding="utf-8")
    entry_path = write_stage_to_kb(
        _kb_root(tmp_path), 4, "literature_search", "run-1", ["lit.md"], stage_dir
    )[0]
    body = entry_path.read_text(encoding="utf-8")
    assert "paper A" in body
    assert "stage-04/lit.md" in body

def test_write_stage_to_kb_handles_missing_artifacts_gracefully(tmp_path: Path):
    """A missing artifact still yields a completion entry."""
    stage_dir = tmp_path / "stage-05"
    stage_dir.mkdir()
    entry_path = write_stage_to_kb(
        _kb_root(tmp_path), 5, "literature_extract", "run-2", ["missing.md"], stage_dir
    )[0]
    body = entry_path.read_text(encoding="utf-8")
    assert "Stage 05 (literature_extract) completed" in body
def test_write_stage_to_kb_truncates_large_artifact_content(tmp_path: Path):
    """Oversized artifacts are clipped with a truncation marker."""
    stage_dir = tmp_path / "stage-12"
    stage_dir.mkdir()
    (stage_dir / "big.txt").write_text("x" * 6000, encoding="utf-8")
    entry = write_stage_to_kb(
        _kb_root(tmp_path), 12, "experiment_implement", "run-3", ["big.txt"], stage_dir
    )[0]
    body = entry.read_text(encoding="utf-8")
    assert "... (truncated, see full artifact)" in body
    assert body.count("x") >= 5000

def test_write_stage_to_kb_directory_artifact_records_listing(tmp_path: Path):
    """Directory artifacts are recorded as a file-count listing."""
    stage_dir = tmp_path / "stage-13"
    artifact_dir = stage_dir / "outputs"
    artifact_dir.mkdir(parents=True)
    for name, payload in (("a.txt", "a"), ("b.txt", "b")):
        (artifact_dir / name).write_text(payload, encoding="utf-8")
    entry = write_stage_to_kb(
        _kb_root(tmp_path), 13, "experiment_execute", "run-4", ["outputs/"], stage_dir
    )[0]
    body = entry.read_text(encoding="utf-8")
    assert "Directory with 2 files: a.txt, b.txt" in body
    assert "stage-13/outputs/" in body
def test_generate_weekly_report_creates_file_in_reviews_category(tmp_path: Path):
    """Weekly reports land in 'reviews' with the week label in the name."""
    run_dir = tmp_path / "run-a"
    run_dir.mkdir()
    summary = {"run_id": "run-a", "stages_executed": 10, "stages_done": 10}
    (run_dir / "pipeline_summary.json").write_text(
        json.dumps(summary), encoding="utf-8"
    )
    report = generate_weekly_report(
        _kb_root(tmp_path), [run_dir], week_label="2026-W10"
    )
    assert report.parent.name == "reviews"
    assert report.name == "weekly-report-2026-W10.md"

def test_generate_weekly_report_with_empty_run_dirs(tmp_path: Path):
    """No runs produces a zero-run report with an N/A success rate."""
    report = generate_weekly_report(_kb_root(tmp_path), [], week_label="2026-W11")
    body = report.read_text(encoding="utf-8")
    assert "Pipeline runs: 0" in body
    assert "Success rate: N/A" in body
def test_generate_weekly_report_aggregates_statistics_correctly(tmp_path: Path):
    """Stage counts and success rate are summed across run summaries."""
    def make_run(name: str, summary: dict) -> Path:
        # Create a run directory containing a pipeline_summary.json.
        run_dir = tmp_path / name
        run_dir.mkdir()
        (run_dir / "pipeline_summary.json").write_text(
            json.dumps(summary), encoding="utf-8"
        )
        return run_dir

    run1 = make_run("run-1", {
        "run_id": "run-1",
        "stages_executed": 20,
        "stages_done": 18,
        "stages_failed": 1,
        "stages_blocked": 1,
        "final_status": "failed",
    })
    run2 = make_run("run-2", {
        "run_id": "run-2",
        "stages_executed": 10,
        "stages_done": 10,
        "stages_failed": 0,
        "stages_blocked": 0,
        "final_status": "done",
    })
    report = generate_weekly_report(
        _kb_root(tmp_path), [run1, run2], week_label="2026-W12"
    )
    body = report.read_text(encoding="utf-8")
    for expected in (
        "Pipeline runs: 2",
        "Stages executed: 30",
        "Stages completed: 28",
        "Stages failed: 1",
        "Stages blocked (gate): 1",
        "Success rate: 93.3%",
    ):
        assert expected in body
def test_generate_weekly_report_ignores_missing_summary_files(tmp_path: Path):
    """Run dirs without pipeline_summary.json are excluded from the tally."""
    run_ok = tmp_path / "run-ok"
    run_empty = tmp_path / "run-empty"
    run_ok.mkdir()
    run_empty.mkdir()
    summary = {"run_id": "run-ok", "stages_executed": 5, "stages_done": 5}
    (run_ok / "pipeline_summary.json").write_text(
        json.dumps(summary), encoding="utf-8"
    )
    report = generate_weekly_report(
        _kb_root(tmp_path), [run_ok, run_empty], week_label="2026-W13"
    )
    assert "Pipeline runs: 1" in report.read_text(encoding="utf-8")
================================================
FILE: tests/test_rc_literature.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false
"""Unit tests for researchclaw.literature module.
All network-dependent tests mock HTTP responses via monkeypatch.
"""
from __future__ import annotations
import json
import textwrap
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.literature.models import Author, Paper
from researchclaw.literature.semantic_scholar import (
_parse_s2_paper,
search_semantic_scholar,
)
from researchclaw.literature.arxiv_client import (
_convert_result,
search_arxiv,
)
from researchclaw.literature.search import (
_deduplicate,
_normalise_title,
papers_to_bibtex,
search_papers,
search_papers_multi_query,
)
# ──────────────────────────────────────────────────────────────────────
# Fixtures & helpers
# ──────────────────────────────────────────────────────────────────────
def _make_paper(**kwargs: Any) -> Paper:
    """Build a Paper with sensible defaults; keyword args override any field."""
    fields: dict[str, Any] = {
        "paper_id": "s2-abc",
        "title": "Attention Is All You Need",
        "authors": (Author(name="Ashish Vaswani"),),
        "year": 2017,
        "venue": "NeurIPS",
        "citation_count": 80000,
        "doi": "10.5555/3295222.3295349",
        "arxiv_id": "1706.03762",
        "url": "https://arxiv.org/abs/1706.03762",
        "source": "semantic_scholar",
    }
    fields.update(kwargs)
    return Paper(**fields)
SAMPLE_S2_RESPONSE = {
"total": 1,
"data": [
{
"paperId": "abc123",
"title": "Test Paper on Transformers",
"abstract": "We study transformers for NLP tasks.",
"year": 2024,
"venue": "NeurIPS",
"citationCount": 42,
"authors": [
{"authorId": "1", "name": "Jane Smith"},
{"authorId": "2", "name": "John Doe"},
],
"externalIds": {"DOI": "10.1234/test", "ArXiv": "2401.00001"},
"url": "https://www.semanticscholar.org/paper/abc123",
}
],
}
SAMPLE_ARXIV_ATOM = textwrap.dedent("""\
http://arxiv.org/abs/2401.00001v1
A Novel Approach to Protein Folding
We propose a new method for protein structure prediction.
2024-01-15T00:00:00Z
Alice Researcher
Bob Scientist
10.5678/protein
http://arxiv.org/abs/2402.00002v1
Deep Reinforcement Learning Survey
A comprehensive survey of deep RL methods.
2024-02-20T00:00:00Z
Charlie Expert
""")
# ──────────────────────────────────────────────────────────────────────
# Author tests
# ──────────────────────────────────────────────────────────────────────
class TestAuthor:
    """Author.last_name(): lowercased, accent-folded surname extraction."""

    def test_last_name_simple(self) -> None:
        assert Author(name="Jane Smith").last_name() == "smith"

    def test_last_name_accented(self) -> None:
        # Accents are stripped while base letters are preserved.
        assert Author(name="José García").last_name() == "garcia"

    def test_last_name_single(self) -> None:
        # A mononym is its own surname.
        assert Author(name="Madonna").last_name() == "madonna"

    def test_last_name_empty(self) -> None:
        # Empty names fall back to the sentinel "unknown".
        assert Author(name="").last_name() == "unknown"
# ──────────────────────────────────────────────────────────────────────
# Paper tests
# ──────────────────────────────────────────────────────────────────────
class TestPaper:
    """Paper: cite keys, BibTeX rendering, dict conversion, immutability."""

    def test_cite_key_format(self) -> None:
        assert _make_paper().cite_key == "vaswani2017attention"

    def test_cite_key_no_authors(self) -> None:
        assert _make_paper(authors=()).cite_key.startswith("anon")

    def test_cite_key_no_year(self) -> None:
        assert "0000" in _make_paper(year=0).cite_key

    def test_to_bibtex_contains_fields(self) -> None:
        bib = _make_paper().to_bibtex()
        for fragment in (
            "@inproceedings{vaswani2017attention,",
            "title = {Attention Is All You Need}",
            "author = {Ashish Vaswani}",
            "year = {2017}",
            "doi = {10.5555/3295222.3295349}",
            "eprint = {1706.03762}",
        ):
            assert fragment in bib

    def test_to_bibtex_override(self) -> None:
        # An explicit override is returned verbatim.
        override = "@article{custom, title={Custom}}"
        assert _make_paper(_bibtex_override=override).to_bibtex() == override

    def test_to_bibtex_article_no_venue(self) -> None:
        bib = _make_paper(venue="", arxiv_id="2301.00001").to_bibtex()
        assert "@article{" in bib
        assert "journal = {arXiv preprint arXiv:2301.00001}" in bib

    def test_to_bibtex_arxiv_category_venue(self) -> None:
        """T1.4: arXiv category codes (cs.CL) must not be used as journal names."""
        bib = _make_paper(venue="cs.CL", arxiv_id="2301.00001").to_bibtex()
        assert "journal = {cs.CL}" not in bib
        assert "arXiv preprint" in bib

    def test_to_dict(self) -> None:
        data = _make_paper().to_dict()
        assert data["paper_id"] == "s2-abc"
        assert data["cite_key"] == "vaswani2017attention"
        assert isinstance(data["authors"], list)
        assert data["authors"][0]["name"] == "Ashish Vaswani"

    def test_paper_frozen(self) -> None:
        # Paper is a frozen dataclass; assignment must raise.
        paper = _make_paper()
        with pytest.raises(AttributeError):
            paper.title = "new title"  # type: ignore[misc]
# ──────────────────────────────────────────────────────────────────────
# Semantic Scholar client tests
# ──────────────────────────────────────────────────────────────────────
class TestSemanticScholar:
    """Semantic Scholar client tests — HTTP layer fully mocked."""

    def test_parse_s2_paper(self) -> None:
        paper = _parse_s2_paper(SAMPLE_S2_RESPONSE["data"][0])
        assert paper.paper_id == "s2-abc123"
        assert paper.title == "Test Paper on Transformers"
        assert len(paper.authors) == 2
        assert paper.authors[0].name == "Jane Smith"
        assert paper.year == 2024
        assert paper.doi == "10.1234/test"
        assert paper.arxiv_id == "2401.00001"
        assert paper.source == "semantic_scholar"
        assert paper.citation_count == 42

    def test_search_semantic_scholar_mock(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Mock urllib to return sample S2 response."""
        # Reset S2 circuit breaker (may be tripped from prior test API calls)
        from researchclaw.literature.semantic_scholar import _reset_circuit_breaker

        _reset_circuit_breaker()
        fake_resp = MagicMock()
        fake_resp.read.return_value = json.dumps(SAMPLE_S2_RESPONSE).encode("utf-8")
        fake_resp.__enter__ = lambda s: s
        fake_resp.__exit__ = MagicMock(return_value=False)
        monkeypatch.setattr(
            "researchclaw.literature.semantic_scholar.urllib.request.urlopen",
            lambda *a, **kw: fake_resp,
        )
        results = search_semantic_scholar("transformers", limit=5)
        assert len(results) == 1
        assert results[0].title == "Test Paper on Transformers"

    def test_search_semantic_scholar_network_error(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Should return empty list on network error."""
        import urllib.error

        from researchclaw.literature.semantic_scholar import _reset_circuit_breaker

        _reset_circuit_breaker()

        def _raise(*a: Any, **kw: Any) -> Any:
            raise urllib.error.URLError("timeout")

        monkeypatch.setattr(
            "researchclaw.literature.semantic_scholar.urllib.request.urlopen",
            _raise,
        )
        # Stub sleep so the client's retry loop doesn't slow the suite down.
        monkeypatch.setattr(
            "researchclaw.literature.semantic_scholar.time.sleep", lambda _: None
        )
        assert search_semantic_scholar("test", limit=5) == []
# ──────────────────────────────────────────────────────────────────────
# arXiv client tests
# ──────────────────────────────────────────────────────────────────────
class TestArxiv:
    """arXiv client tests — the `arxiv` library and HTTP are fully mocked.

    Fix: MagicMock is already imported at module level, so the redundant
    local `from unittest.mock import MagicMock` lines are removed, as is
    the unused local `import types` in test_search_arxiv_mock (that test
    uses a MagicMock fake module, not types.ModuleType).
    """

    def test_convert_result(self) -> None:
        """Test converting arxiv.Result to Paper via the new library."""
        from datetime import datetime

        mock_result = MagicMock()
        mock_result.entry_id = "http://arxiv.org/abs/2401.00001v1"
        mock_result.title = "A Novel Approach to Protein Folding"
        mock_result.summary = "We study protein folding."
        mock_result.published = datetime(2024, 1, 15)
        mock_result.doi = "10.5678/protein"
        mock_result.primary_category = "q-bio.BM"
        mock_author1 = MagicMock()
        mock_author1.name = "Alice Researcher"
        mock_author2 = MagicMock()
        mock_author2.name = "Bob Scientist"
        mock_result.authors = [mock_author1, mock_author2]
        paper = _convert_result(mock_result)
        assert paper.title == "A Novel Approach to Protein Folding"
        assert paper.arxiv_id == "2401.00001"
        assert paper.year == 2024
        assert len(paper.authors) == 2
        assert paper.authors[0].name == "Alice Researcher"
        assert paper.source == "arxiv"
        assert paper.doi == "10.5678/protein"
        assert paper.venue == "q-bio.BM"

    def test_search_arxiv_mock(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test search_arxiv with mocked arxiv library."""
        from datetime import datetime

        mock_result = MagicMock()
        mock_result.entry_id = "http://arxiv.org/abs/2401.00001v1"
        mock_result.title = "Test Paper"
        mock_result.summary = "Abstract."
        mock_result.published = datetime(2024, 1, 1)
        mock_result.doi = ""
        mock_result.primary_category = "cs.LG"
        mock_author = MagicMock()
        mock_author.name = "Test Author"
        mock_result.authors = [mock_author]
        mock_client = MagicMock()
        mock_client.results.return_value = iter([mock_result])
        # Mock the module-level `arxiv` so the `if arxiv is None` guard
        # doesn't short-circuit before the mocked _get_client is reached.
        # Use MagicMock so all attributes (Search, SortOrder, etc.) auto-resolve.
        _fake_arxiv = MagicMock()
        monkeypatch.setattr(
            "researchclaw.literature.arxiv_client.arxiv", _fake_arxiv,
        )
        monkeypatch.setattr(
            "researchclaw.literature.arxiv_client._get_client",
            lambda: mock_client,
        )
        from researchclaw.literature.arxiv_client import _reset_circuit_breaker

        _reset_circuit_breaker()
        papers = search_arxiv("test", limit=10)
        assert len(papers) == 1
        assert papers[0].title == "Test Paper"
        assert papers[0].arxiv_id == "2401.00001"

    def test_search_arxiv_error_graceful(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """search_arxiv returns empty list on error, not raise."""
        import types

        # Build a fake arxiv module with real exception classes so
        # `except arxiv.HTTPError` doesn't TypeError.
        _fake_arxiv = types.ModuleType("arxiv")

        class _FakeHTTPError(Exception):
            pass

        class _FakeUnexpectedEmptyPageError(Exception):
            pass

        _fake_arxiv.HTTPError = _FakeHTTPError
        _fake_arxiv.UnexpectedEmptyPageError = _FakeUnexpectedEmptyPageError
        _fake_arxiv.SortCriterion = MagicMock()
        _fake_arxiv.SortOrder = MagicMock()
        _fake_arxiv.Search = MagicMock()
        monkeypatch.setattr(
            "researchclaw.literature.arxiv_client.arxiv", _fake_arxiv,
        )
        mock_client = MagicMock()
        mock_client.results.side_effect = _FakeHTTPError("Simulated arXiv HTTP error")
        monkeypatch.setattr(
            "researchclaw.literature.arxiv_client._get_client",
            lambda: mock_client,
        )
        from researchclaw.literature.arxiv_client import _reset_circuit_breaker

        _reset_circuit_breaker()
        papers = search_arxiv("test", limit=10)
        assert papers == []
# ──────────────────────────────────────────────────────────────────────
# Unified search & deduplication tests
# ──────────────────────────────────────────────────────────────────────
class TestDeduplication:
    """_deduplicate: duplicates collapse by DOI, arXiv id, then title."""

    def test_dedup_by_doi(self) -> None:
        high = _make_paper(paper_id="s2-1", doi="10.1234/a", citation_count=100)
        low = _make_paper(
            paper_id="arxiv-1", doi="10.1234/a", citation_count=50, source="arxiv"
        )
        merged = _deduplicate([high, low])
        assert len(merged) == 1
        assert merged[0].citation_count == 100  # keeps higher

    def test_dedup_by_arxiv_id(self) -> None:
        low = _make_paper(
            paper_id="s2-1", doi="", arxiv_id="2401.00001", citation_count=10
        )
        high = _make_paper(
            paper_id="arxiv-1",
            doi="",
            arxiv_id="2401.00001",
            citation_count=20,
            source="arxiv",
        )
        merged = _deduplicate([low, high])
        assert len(merged) == 1
        assert merged[0].citation_count == 20  # arxiv version had more

    def test_dedup_by_title(self) -> None:
        first = _make_paper(
            paper_id="s2-1",
            doi="",
            arxiv_id="",
            title="My Cool Paper",
            citation_count=5,
        )
        second = _make_paper(
            paper_id="s2-2",
            doi="",
            arxiv_id="",
            title="My Cool Paper",
            citation_count=10,
        )
        merged = _deduplicate([first, second])
        assert len(merged) == 1
        assert merged[0].citation_count == 10

    def test_dedup_no_duplicates(self) -> None:
        # Distinct DOI, arXiv id, and title: nothing should be merged.
        a = _make_paper(paper_id="s2-1", title="Paper A", doi="10.1/a", arxiv_id="1111.11111")
        b = _make_paper(paper_id="s2-2", title="Paper B", doi="10.1/b", arxiv_id="2222.22222")
        assert len(_deduplicate([a, b])) == 2

    def test_normalise_title(self) -> None:
        assert _normalise_title(" The Great Paper!!! ") == "the great paper"
        assert _normalise_title("A/B Testing: Methods") == "ab testing methods"
class TestPapersToBibtex:
    def test_generates_combined(self) -> None:
        """Two papers should produce exactly two BibTeX entries in one string."""
        combined = papers_to_bibtex(
            [
                _make_paper(paper_id="s2-1", title="Paper A"),
                _make_paper(paper_id="s2-2", title="Paper B", venue="ICML 2024"),
            ]
        )
        assert combined.count("@") == 2
        assert "Paper A" in combined
        assert "Paper B" in combined
class TestSearchPapers:
    """search_papers / search_papers_multi_query orchestration.

    Every back-end search function is monkeypatched, so no test here
    touches the network; time.sleep is stubbed to keep the suite fast.
    """

    def test_search_papers_combines_sources(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Mock both S2 and arXiv to verify combined search."""
        s2_paper = _make_paper(
            paper_id="s2-1", source="semantic_scholar", citation_count=100
        )
        arxiv_paper = _make_paper(
            paper_id="arxiv-1",
            title="Different Paper",
            doi="10.2/b",
            arxiv_id="2402.99999",
            source="arxiv",
            citation_count=50,
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_semantic_scholar",
            lambda *a, **kw: [s2_paper],
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_arxiv",
            lambda *a, **kw: [arxiv_paper],
        )
        monkeypatch.setattr("researchclaw.literature.search.time.sleep", lambda _: None)
        papers = search_papers("test", sources=["semantic_scholar", "arxiv"])
        assert len(papers) == 2
        # Should be sorted by citation_count desc
        assert papers[0].citation_count >= papers[1].citation_count

    def test_default_sources_openalex_first(self) -> None:
        """OpenAlex should be the primary (first) source — least restrictive limits."""
        from researchclaw.literature.search import _DEFAULT_SOURCES
        assert _DEFAULT_SOURCES[0] == "openalex"
        assert "semantic_scholar" in _DEFAULT_SOURCES
        assert "arxiv" in _DEFAULT_SOURCES

    def test_s2_failure_does_not_block_others(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When S2 fails, other sources should still return results."""
        arxiv_paper = _make_paper(
            paper_id="arxiv-ok", title="ArXiv Paper", source="arxiv",
            doi="10.1/ax", arxiv_id="2401.99991",
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_openalex",
            lambda *a, **kw: [],
        )
        # Generator-throw trick: a lambda that raises when called.
        monkeypatch.setattr(
            "researchclaw.literature.search.search_semantic_scholar",
            lambda *a, **kw: (_ for _ in ()).throw(RuntimeError("S2 down")),
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_arxiv",
            lambda *a, **kw: [arxiv_paper],
        )
        monkeypatch.setattr("researchclaw.literature.search.time.sleep", lambda _: None)
        papers = search_papers("test")
        assert len(papers) >= 1
        assert papers[0].source == "arxiv"

    def test_search_papers_unknown_source(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # Unknown source names should yield an empty result set, not an error.
        monkeypatch.setattr("researchclaw.literature.search.time.sleep", lambda _: None)
        papers = search_papers("test", sources=["unknown_source"])
        assert papers == []

    def test_search_papers_multi_query(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Each of the three queries yields one unique paper; all must survive dedup.
        call_count = 0

        def mock_search(*a: Any, **kw: Any) -> list[Paper]:
            nonlocal call_count
            call_count += 1
            return [
                _make_paper(
                    paper_id=f"s2-{call_count}",
                    title=f"Unique Paper {call_count}",
                    doi=f"10.{call_count}/unique",
                    arxiv_id=f"240{call_count}.{call_count:05d}",
                )
            ]

        monkeypatch.setattr(
            "researchclaw.literature.search.search_papers",
            mock_search,
        )
        monkeypatch.setattr("researchclaw.literature.search.time.sleep", lambda _: None)
        papers = search_papers_multi_query(["q1", "q2", "q3"])
        assert call_count == 3
        # All unique titles so no dedup
        assert len(papers) == 3
# ──────────────────────────────────────────────────────────────────────
# Edge cases
# ──────────────────────────────────────────────────────────────────────
class TestEdgeCases:
    def test_paper_with_no_meaningful_title_word(self) -> None:
        """cite_key should still work with stopword-only titles."""
        paper = _make_paper(title="The And For With", year=2024)
        # Every title word is a stopword or too short, so no keyword suffix.
        assert paper.cite_key.startswith("vaswani2024")

    def test_paper_multiple_authors_bibtex(self) -> None:
        # BibTeX joins multiple authors with " and ".
        authors = (
            Author(name="Alice One"),
            Author(name="Bob Two"),
            Author(name="Charlie Three"),
        )
        bib = _make_paper(authors=authors).to_bibtex()
        assert "Alice One and Bob Two and Charlie Three" in bib

    def test_empty_s2_response(self) -> None:
        """_parse_s2_paper shouldn't crash on minimal data."""
        paper = _parse_s2_paper({"paperId": "x"})
        assert paper.paper_id == "s2-x"
        assert paper.title == ""
        assert paper.authors == ()
# ──────────────────────────────────────────────────────────────────────
# arXiv circuit breaker tests
# ──────────────────────────────────────────────────────────────────────
class TestArxivCircuitBreaker:
    """State-machine tests for the arXiv client's circuit breaker.

    These manipulate module-private breaker state (_cb_*) directly;
    setup_method resets the breaker so tests don't leak state.
    """

    def setup_method(self) -> None:
        # Start every test from a freshly reset (CLOSED) breaker.
        from researchclaw.literature.arxiv_client import _reset_circuit_breaker
        _reset_circuit_breaker()

    def test_failure_triggers_circuit_breaker(self) -> None:
        """Three consecutive failures should trip the circuit breaker."""
        from researchclaw.literature import arxiv_client
        # Simulate 3 consecutive failures
        for _ in range(3):
            arxiv_client._cb_on_failure()
        assert arxiv_client._cb_state == arxiv_client._CB_OPEN
        assert arxiv_client._cb_trip_count == 1

    def test_breaker_open_skips_requests(self) -> None:
        """When breaker is OPEN, requests should be skipped."""
        import time as time_mod
        from researchclaw.literature import arxiv_client
        # Force OPEN with a cooldown that cannot expire during the test.
        arxiv_client._cb_state = arxiv_client._CB_OPEN
        arxiv_client._cb_open_since = time_mod.monotonic()
        arxiv_client._cb_cooldown_sec = 999
        assert not arxiv_client._cb_should_allow()

    def test_success_resets_breaker(self) -> None:
        """A successful request should reset the circuit breaker."""
        from researchclaw.literature import arxiv_client
        arxiv_client._cb_state = arxiv_client._CB_HALF_OPEN
        arxiv_client._cb_consecutive_429s = 2
        arxiv_client._cb_on_success()
        # Success returns the breaker to CLOSED and clears the 429 counter.
        assert arxiv_client._cb_state == arxiv_client._CB_CLOSED
        assert arxiv_client._cb_consecutive_429s == 0

    def test_half_open_probe_failure_doubles_cooldown(self) -> None:
        """Probe failure in HALF_OPEN should double the cooldown."""
        from researchclaw.literature import arxiv_client
        arxiv_client._cb_state = arxiv_client._CB_HALF_OPEN
        initial_cooldown = arxiv_client._cb_cooldown_sec
        arxiv_client._cb_on_failure()
        assert arxiv_client._cb_state == arxiv_client._CB_OPEN
        # Cooldown doubles on each probe failure, capped at 600 seconds.
        assert arxiv_client._cb_cooldown_sec == min(initial_cooldown * 2, 600)

    def test_search_with_http_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """search_arxiv should return empty list on HTTPError."""
        import types
        # Fake arxiv module with real exception classes so the client's
        # `except arxiv.HTTPError` clause can catch them.
        _fake_arxiv = types.ModuleType("arxiv")
        class _FakeHTTPError(Exception):
            pass
        class _FakeUnexpectedEmptyPageError(Exception):
            pass
        _fake_arxiv.HTTPError = _FakeHTTPError
        _fake_arxiv.UnexpectedEmptyPageError = _FakeUnexpectedEmptyPageError
        _fake_arxiv.SortCriterion = MagicMock()
        _fake_arxiv.SortOrder = MagicMock()
        _fake_arxiv.Search = MagicMock()
        monkeypatch.setattr(
            "researchclaw.literature.arxiv_client.arxiv", _fake_arxiv,
        )
        mock_client = MagicMock()
        mock_client.results.side_effect = _FakeHTTPError("Simulated 429")
        monkeypatch.setattr(
            "researchclaw.literature.arxiv_client._get_client",
            lambda: mock_client,
        )
        from researchclaw.literature.arxiv_client import _reset_circuit_breaker
        _reset_circuit_breaker()
        papers = search_arxiv("test", limit=5)
        assert papers == []
# ──────────────────────────────────────────────────────────────────────
# OpenAlex client tests
# ──────────────────────────────────────────────────────────────────────
# Canonical OpenAlex /works search payload — includes an inverted-index
# abstract — consumed by the TestOpenAlex parsing tests below.
SAMPLE_OPENALEX_RESPONSE = {
    "results": [
        {
            "id": "https://openalex.org/W123456",
            "title": "Attention Is All You Need",
            "authorships": [
                {
                    "author": {"display_name": "Ashish Vaswani"},
                    "institutions": [{"display_name": "Google Brain"}],
                }
            ],
            "publication_year": 2017,
            "primary_location": {
                "source": {"display_name": "NeurIPS"}
            },
            "cited_by_count": 85000,
            # DOI and arXiv ids arrive as full URLs; the parser must strip prefixes.
            "doi": "https://doi.org/10.5555/3295222.3295349",
            "ids": {
                "openalex": "https://openalex.org/W123456",
                "arxiv": "https://arxiv.org/abs/1706.03762",
            },
            "abstract_inverted_index": {
                "The": [0],
                "dominant": [1],
                "models": [2, 6],
                "are": [3],
                "based": [4],
                "on": [5],
            },
        }
    ]
}
class TestOpenAlex:
    """OpenAlex client tests (HTTP layer mocked via monkeypatch).

    Fix: test_openalex_network_error previously relied on a stray
    `import urllib.error` placed at the bottom of the module; it now
    imports urllib.error locally so the test is self-contained.
    """

    def test_parse_openalex_response(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Mock urllib to return sample OpenAlex response."""
        from researchclaw.literature.openalex_client import search_openalex

        response_bytes = json.dumps(SAMPLE_OPENALEX_RESPONSE).encode("utf-8")
        mock_resp = MagicMock()
        mock_resp.read.return_value = response_bytes
        mock_resp.__enter__ = lambda s: s
        mock_resp.__exit__ = MagicMock(return_value=False)
        monkeypatch.setattr(
            "researchclaw.literature.openalex_client.urllib.request.urlopen",
            lambda *a, **kw: mock_resp,
        )
        papers = search_openalex("attention", limit=5)
        assert len(papers) == 1
        p = papers[0]
        assert p.title == "Attention Is All You Need"
        assert p.year == 2017
        assert p.citation_count == 85000
        # URL prefixes must have been stripped from DOI / arXiv ids.
        assert p.doi == "10.5555/3295222.3295349"
        assert p.arxiv_id == "1706.03762"
        assert p.source == "openalex"
        assert p.authors[0].name == "Ashish Vaswani"

    def test_abstract_reconstruction(self) -> None:
        # Words are re-ordered by their inverted-index positions.
        from researchclaw.literature.openalex_client import _reconstruct_abstract

        inv_idx = {"Hello": [0], "world": [1], "foo": [3], "bar": [2]}
        assert _reconstruct_abstract(inv_idx) == "Hello world bar foo"

    def test_abstract_empty(self) -> None:
        # None and an empty index both yield an empty abstract.
        from researchclaw.literature.openalex_client import _reconstruct_abstract

        assert _reconstruct_abstract(None) == ""
        assert _reconstruct_abstract({}) == ""

    def test_openalex_network_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Should return empty list on network error."""
        # Local import: do not depend on the stray module-bottom
        # `import urllib.error` for this test to run.
        import urllib.error

        from researchclaw.literature.openalex_client import search_openalex

        monkeypatch.setattr(
            "researchclaw.literature.openalex_client.urllib.request.urlopen",
            lambda *a, **kw: (_ for _ in ()).throw(urllib.error.URLError("timeout")),
        )
        monkeypatch.setattr(
            "researchclaw.literature.openalex_client.time.sleep", lambda _: None,
        )
        assert search_openalex("test", limit=5) == []
# ──────────────────────────────────────────────────────────────────────
# Multi-source fallback tests
# ──────────────────────────────────────────────────────────────────────
class TestMultiSourceFallback:
    def test_openalex_failure_falls_back_to_s2_and_arxiv(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When OpenAlex fails, S2 and arXiv should still return results."""

        def _openalex_down(*a: Any, **kw: Any) -> list[Paper]:
            raise RuntimeError("OpenAlex down")

        arxiv_paper = _make_paper(
            paper_id="arxiv-ok", title="ArXiv Paper", source="arxiv",
            doi="10.1/ax", arxiv_id="2401.99999",
        )
        s2_paper = _make_paper(
            paper_id="s2-ok", title="S2 Paper", source="semantic_scholar",
            doi="10.1/s2", arxiv_id="2402.99999",
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_openalex", _openalex_down
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_semantic_scholar",
            lambda *a, **kw: [s2_paper],
        )
        monkeypatch.setattr(
            "researchclaw.literature.search.search_arxiv",
            lambda *a, **kw: [arxiv_paper],
        )
        monkeypatch.setattr("researchclaw.literature.search.time.sleep", lambda _: None)
        papers = search_papers("test")
        assert len(papers) >= 1
        # At least one of the surviving sources must have contributed.
        assert {p.source for p in papers} & {"semantic_scholar", "arxiv"}
# ──────────────────────────────────────────────────────────────────────
# Cache TTL tests
# ──────────────────────────────────────────────────────────────────────
class TestCacheTTL:
    """Literature cache: per-source TTLs and basic round-tripping."""

    def test_source_specific_ttl(self, tmp_path: Any) -> None:
        """arXiv cache should expire after 24h, not 7 days."""
        from researchclaw.literature.cache import _SOURCE_TTL, get_cached, put_cache

        assert _SOURCE_TTL["arxiv"] == 86400  # 24h
        assert _SOURCE_TTL["semantic_scholar"] == 86400 * 3
        # A freshly written entry must be readable straight back.
        entry = [{"paper_id": "x", "title": "Y"}]
        put_cache("test", "arxiv", 10, entry, cache_base=tmp_path)
        cached = get_cached("test", "arxiv", 10, cache_base=tmp_path)
        assert cached is not None
        assert len(cached) == 1

    def test_citation_verify_ttl_is_permanent(self) -> None:
        # Verified citations are effectively cached forever (>= 1 year).
        from researchclaw.literature.cache import _SOURCE_TTL

        assert _SOURCE_TTL["citation_verify"] >= 86400 * 365
import urllib.error
================================================
FILE: tests/test_rc_llm.py
================================================
from __future__ import annotations
import json
import urllib.request
from types import SimpleNamespace
from typing import Any, Mapping
import pytest
from researchclaw.llm.client import LLMClient, LLMConfig, LLMResponse, _NEW_PARAM_MODELS
class _DummyHTTPResponse:
def __init__(self, payload: Mapping[str, Any]):
self._payload = payload
def read(self) -> bytes:
return json.dumps(self._payload).encode("utf-8")
def __enter__(self) -> _DummyHTTPResponse:
return self
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
return None
def _make_client(
    *,
    api_key: str = "test-key",
    primary_model: str = "gpt-5.2",
    fallback_models: list[str] | None = None,
    timeout_sec: int = 120,
) -> LLMClient:
    """Construct an LLMClient pointed at a dummy endpoint for unit tests."""
    if fallback_models is None:
        fallback_models = ["gpt-5.1", "gpt-4.1", "gpt-4o"]
    return LLMClient(
        LLMConfig(
            base_url="https://api.example.com/v1",
            api_key=api_key,
            primary_model=primary_model,
            fallback_models=fallback_models,
            timeout_sec=timeout_sec,
        )
    )
def _capture_raw_call(
    monkeypatch: pytest.MonkeyPatch, *, model: str, response_data: Mapping[str, Any]
) -> tuple[dict[str, object], LLMResponse, dict[str, object]]:
    """Drive LLMClient._raw_call once with urlopen mocked out.

    Returns (body, resp, captured): the JSON payload the client posted,
    the LLMResponse parsed from ``response_data``, and a dict holding the
    raw Request object plus the timeout passed to urlopen.
    """
    captured: dict[str, object] = {}

    def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
        # Record what the client sent so callers can assert on it.
        captured["request"] = req
        captured["timeout"] = timeout
        return _DummyHTTPResponse(response_data)

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
    client = _make_client()
    # Positional args: model, messages, max_tokens, temperature, json_mode.
    resp = client._raw_call(
        model, [{"role": "user", "content": "hello"}], 123, 0.2, False
    )
    request = captured["request"]
    assert isinstance(request, urllib.request.Request)
    data = request.data
    assert isinstance(data, bytes)
    body = json.loads(data.decode("utf-8"))
    assert isinstance(body, dict)
    return body, resp, captured
def test_llm_config_defaults():
    """Unspecified LLMConfig fields fall back to their documented defaults."""
    config = LLMConfig(base_url="https://api.example.com/v1", api_key="k")
    assert config.primary_model == "gpt-4o"
    assert config.max_tokens == 4096
    assert config.temperature == 0.7


def test_llm_config_custom_values():
    """Explicitly supplied LLMConfig fields are stored verbatim."""
    config = LLMConfig(
        base_url="https://custom.example/v1",
        api_key="custom",
        primary_model="o3",
        fallback_models=["o3-mini"],
        max_tokens=2048,
        temperature=0.1,
        timeout_sec=30,
    )
    expected = {
        "primary_model": "o3",
        "fallback_models": ["o3-mini"],
        "max_tokens": 2048,
        "temperature": 0.1,
        "timeout_sec": 30,
    }
    for field, value in expected.items():
        assert getattr(config, field) == value


def test_llm_response_dataclass_fields():
    """Constructor arguments land on the matching LLMResponse attributes."""
    response = LLMResponse(content="ok", model="gpt-5.2", completion_tokens=10)
    assert response.content == "ok"
    assert response.model == "gpt-5.2"
    assert response.completion_tokens == 10


def test_llm_response_defaults():
    """Token counters default to 0, strings to empty, raw to {}."""
    response = LLMResponse(content="ok", model="gpt-5.2")
    assert response.prompt_tokens == 0
    assert response.completion_tokens == 0
    assert response.total_tokens == 0
    assert response.finish_reason == ""
    assert response.truncated is False
    assert response.raw == {}
def test_llm_client_initialization_stores_config():
    """The client keeps a reference to the exact config object it was given."""
    config = LLMConfig(base_url="https://api.example.com/v1", api_key="k")
    assert LLMClient(config).config is config


def test_llm_client_model_chain_is_primary_plus_fallbacks():
    """_model_chain is the primary model followed by the fallbacks, in order."""
    client = _make_client(
        primary_model="gpt-5.4", fallback_models=["gpt-4.1", "gpt-4o"]
    )
    assert client._model_chain == ["gpt-5.4", "gpt-4.1", "gpt-4o"]


def test_needs_max_completion_tokens_for_new_models():
    # gpt-5.x must match one of the new-parameter model prefixes.
    assert any("gpt-5.2".startswith(prefix) for prefix in _NEW_PARAM_MODELS)


def test_needs_max_completion_tokens_false_for_old_models():
    # Legacy models keep the classic max_tokens parameter.
    assert not any("gpt-4o".startswith(prefix) for prefix in _NEW_PARAM_MODELS)
def _stop_response() -> dict[str, Any]:
    """A minimal chat-completion payload whose finish_reason is 'stop'."""
    return {"choices": [{"message": {"content": "x"}, "finish_reason": "stop"}]}


def test_build_request_body_structure_via_raw_call(monkeypatch: pytest.MonkeyPatch):
    """The POSTed body carries model, messages, and temperature verbatim."""
    body, _, _ = _capture_raw_call(
        monkeypatch, model="gpt-4o", response_data=_stop_response()
    )
    assert body["model"] == "gpt-4o"
    assert body["messages"] == [{"role": "user", "content": "hello"}]
    assert body["temperature"] == 0.2


def test_build_request_uses_max_completion_tokens_for_new_models(
    monkeypatch: pytest.MonkeyPatch,
):
    """New-style models use max_completion_tokens instead of max_tokens."""
    body, _, _ = _capture_raw_call(
        monkeypatch, model="gpt-5.2", response_data=_stop_response()
    )
    # Reasoning models enforce a minimum of 32768 tokens
    assert body["max_completion_tokens"] == 32768
    assert "max_tokens" not in body


def test_build_request_uses_max_tokens_for_old_models(monkeypatch: pytest.MonkeyPatch):
    """Legacy models keep the classic max_tokens parameter untouched."""
    body, _, _ = _capture_raw_call(
        monkeypatch, model="gpt-4.1", response_data=_stop_response()
    )
    assert body["max_tokens"] == 123
    assert "max_completion_tokens" not in body
def test_parse_response_with_valid_payload_via_raw_call(
    monkeypatch: pytest.MonkeyPatch,
):
    """A complete payload maps cleanly onto LLMResponse fields."""
    payload = {
        "model": "gpt-5.2",
        "choices": [{"message": {"content": "hello"}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    }
    _, parsed, _ = _capture_raw_call(
        monkeypatch, model="gpt-5.2", response_data=payload
    )
    assert parsed.content == "hello"
    assert parsed.model == "gpt-5.2"
    assert parsed.prompt_tokens == 1
    assert parsed.total_tokens == 3


def test_parse_response_truncated_when_finish_reason_length(
    monkeypatch: pytest.MonkeyPatch,
):
    """finish_reason 'length' must flag the response as truncated."""
    payload = {
        "choices": [{"message": {"content": "partial"}, "finish_reason": "length"}],
        "usage": {},
    }
    _, parsed, _ = _capture_raw_call(
        monkeypatch, model="gpt-5.2", response_data=payload
    )
    assert parsed.finish_reason == "length"
    assert parsed.truncated is True


def test_parse_response_missing_optional_fields_graceful(
    monkeypatch: pytest.MonkeyPatch,
):
    """Null content and missing usage fields degrade to empty/zero values."""
    payload = {"choices": [{"message": {"content": None}}]}
    _, parsed, _ = _capture_raw_call(
        monkeypatch, model="gpt-5.2", response_data=payload
    )
    assert parsed.content == ""
    for counter in (parsed.prompt_tokens, parsed.completion_tokens, parsed.total_tokens):
        assert counter == 0
    assert parsed.finish_reason == ""
def test_from_rc_config_builds_expected_llm_config():
    """from_rc_config copies base_url/key/models straight from rc_config.llm."""
    rc_config = SimpleNamespace(
        llm=SimpleNamespace(
            base_url="https://proxy.example/v1",
            api_key="inline-key",
            api_key_env="OPENAI_API_KEY",
            primary_model="o3",
            fallback_models=("o3-mini", "gpt-4o"),
        )
    )
    cfg = LLMClient.from_rc_config(rc_config).config
    assert cfg.base_url == "https://proxy.example/v1"
    assert cfg.api_key == "inline-key"
    assert cfg.primary_model == "o3"
    # The fallback tuple is normalised into a list.
    assert cfg.fallback_models == ["o3-mini", "gpt-4o"]


def test_from_rc_config_reads_api_key_from_env_when_missing(
    monkeypatch: pytest.MonkeyPatch,
):
    """An empty inline api_key falls back to the env var named by api_key_env."""
    monkeypatch.setenv("RC_TEST_API_KEY", "env-key")
    rc_config = SimpleNamespace(
        llm=SimpleNamespace(
            base_url="https://proxy.example/v1",
            api_key="",
            api_key_env="RC_TEST_API_KEY",
            primary_model="gpt-5.2",
            fallback_models=(),
        )
    )
    assert LLMClient.from_rc_config(rc_config).config.api_key == "env-key"


def test_new_param_models_contains_expected_models():
    """Every known new-parameter model family must be registered."""
    expected = {"gpt-5", "gpt-5.1", "gpt-5.2", "gpt-5.4", "o3", "o3-mini", "o4-mini"}
    assert expected.issubset(_NEW_PARAM_MODELS)
def test_raw_call_adds_json_mode_response_format(monkeypatch: pytest.MonkeyPatch):
    """json_mode=True must add response_format={'type': 'json_object'}."""
    captured: dict[str, object] = {}

    def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
        captured["request"] = req
        return _DummyHTTPResponse({"choices": [{"message": {"content": "{}"}}]})

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
    _ = _make_client()._raw_call(
        "gpt-5.2", [{"role": "user", "content": "json"}], 50, 0.1, True
    )
    request = captured["request"]
    assert isinstance(request, urllib.request.Request)
    raw_bytes = request.data
    assert isinstance(raw_bytes, bytes)
    body = json.loads(raw_bytes.decode("utf-8"))
    assert isinstance(body, dict)
    assert body["response_format"] == {"type": "json_object"}


def test_raw_call_sets_auth_and_user_agent_headers(monkeypatch: pytest.MonkeyPatch):
    """_raw_call sends Bearer auth, a User-Agent, and the configured timeout."""
    captured: dict[str, object] = {}

    def fake_urlopen(req: urllib.request.Request, timeout: int) -> _DummyHTTPResponse:
        captured["request"] = req
        captured["timeout"] = timeout
        return _DummyHTTPResponse({"choices": [{"message": {"content": "ok"}}]})

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
    client = _make_client(api_key="secret", timeout_sec=77)
    _ = client._raw_call("gpt-5.2", [{"role": "user", "content": "hi"}], 20, 0.6, False)
    request = captured["request"]
    assert isinstance(request, urllib.request.Request)
    headers = {name.lower(): value for name, value in request.headers.items()}
    assert headers["authorization"] == "Bearer secret"
    assert "user-agent" in headers
    assert captured["timeout"] == 77
def test_chat_prepends_system_message(monkeypatch: pytest.MonkeyPatch):
captured: dict[str, list[dict[str, str]]] = {}
def fake_raw_call(
self: LLMClient,
model: str,
messages: list[dict[str, str]],
max_tokens: int,
temperature: float,
json_mode: bool,
) -> LLMResponse:
captured["messages"] = messages
return LLMResponse(content="ok", model=model)
monkeypatch.setattr(LLMClient, "_raw_call", fake_raw_call)
client = _make_client(primary_model="gpt-5.2", fallback_models=["gpt-4o"])
client.chat([{"role": "user", "content": "q"}], system="sys")
assert captured["messages"][0] == {"role": "system", "content": "sys"}
def test_chat_uses_fallback_after_first_model_error(monkeypatch: pytest.MonkeyPatch):
    """When the primary model raises, chat() must fall back to the next model."""
    attempted_models: list[str] = []

    def flaky_call_with_retry(
        self: LLMClient,
        model: str,
        messages: list[dict[str, str]],
        max_tokens: int,
        temperature: float,
        json_mode: bool,
    ) -> LLMResponse:
        _ = (self, messages, max_tokens, temperature, json_mode)
        attempted_models.append(model)
        # Simulate a hard failure for the primary model only.
        if model == "gpt-5.2":
            raise RuntimeError("first failed")
        return LLMResponse(content="ok", model=model)

    monkeypatch.setattr(LLMClient, "_call_with_retry", flaky_call_with_retry)
    client = _make_client(primary_model="gpt-5.2", fallback_models=["gpt-5.1"])
    response = client.chat([{"role": "user", "content": "x"}])
    assert attempted_models == ["gpt-5.2", "gpt-5.1"]
    assert response.model == "gpt-5.1"
================================================
FILE: tests/test_rc_novelty.py
================================================
"""Tests for researchclaw.literature.novelty — novelty detection module."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.literature.novelty import (
_assess_novelty,
_build_novelty_queries,
_compute_similarity,
_extract_keywords,
_jaccard_keywords,
check_novelty,
)
# ---------------------------------------------------------------------------
# _extract_keywords
# ---------------------------------------------------------------------------
class TestExtractKeywords:
    """Unit tests for _extract_keywords: lowercasing, stop-word and short-token
    filtering, de-duplication, and order preservation."""

    def test_basic_extraction(self) -> None:
        # Keywords come back lower-cased.
        kws = _extract_keywords("Transformer attention mechanisms for NLP")
        assert "transformer" in kws
        assert "attention" in kws
        assert "mechanisms" in kws
        assert "nlp" in kws

    def test_stop_words_removed(self) -> None:
        kws = _extract_keywords("the model is a new approach for data")
        # "the", "is", "a", "new", "approach", "for", "data", "model" are stop words
        assert "the" not in kws
        assert "is" not in kws

    def test_short_tokens_removed(self) -> None:
        kws = _extract_keywords("AI ML RL deep reinforcement learning")
        # "AI", "ML", "RL" are only 2 chars → removed
        assert "ai" not in kws
        assert "deep" in kws
        assert "reinforcement" in kws

    def test_deduplication(self) -> None:
        # Repeated words appear only once in the output.
        kws = _extract_keywords("attention attention attention mechanism")
        assert kws.count("attention") == 1

    def test_empty_input(self) -> None:
        assert _extract_keywords("") == []

    def test_preserves_order(self) -> None:
        # First-occurrence order of the input is kept.
        kws = _extract_keywords("alpha beta gamma delta")
        assert kws == ["alpha", "beta", "gamma", "delta"]
# ---------------------------------------------------------------------------
# _jaccard_keywords
# ---------------------------------------------------------------------------
class TestJaccardKeywords:
    """Unit tests for _jaccard_keywords: |A∩B| / |A∪B| with 0.0 for empty inputs."""

    def test_identical_sets(self) -> None:
        assert _jaccard_keywords(["a", "b", "c"], ["a", "b", "c"]) == 1.0

    def test_disjoint_sets(self) -> None:
        assert _jaccard_keywords(["a", "b"], ["c", "d"]) == 0.0

    def test_partial_overlap(self) -> None:
        # {a, b, c} & {b, c, d} = {b, c} / {a, b, c, d} = 2/4 = 0.5
        assert _jaccard_keywords(["a", "b", "c"], ["b", "c", "d"]) == 0.5

    def test_empty_first(self) -> None:
        assert _jaccard_keywords([], ["a", "b"]) == 0.0

    def test_empty_second(self) -> None:
        assert _jaccard_keywords(["a", "b"], []) == 0.0

    def test_both_empty(self) -> None:
        # Degenerate case: defined as 0.0 rather than raising on division by zero.
        assert _jaccard_keywords([], []) == 0.0
# ---------------------------------------------------------------------------
# _compute_similarity
# ---------------------------------------------------------------------------
class TestComputeSimilarity:
    """Unit tests for _compute_similarity(keywords, title, abstract) -> float in [0, 1]."""

    def test_returns_float_0_to_1(self) -> None:
        sim = _compute_similarity(
            ["transformer", "attention"],
            "Transformer Attention in NLP",
            "We study attention mechanisms in transformer models.",
        )
        assert 0.0 <= sim <= 1.0

    def test_high_similarity_for_matching_content(self) -> None:
        kws = ["transformer", "attention", "mechanisms", "self-attention"]
        sim = _compute_similarity(
            kws,
            "Self-Attention Mechanisms in Transformers",
            "This paper studies transformer self-attention mechanisms in detail.",
        )
        assert sim > 0.1  # should have meaningful overlap

    def test_low_similarity_for_unrelated_content(self) -> None:
        # Keywords completely unrelated to the paper text should score near zero.
        kws = ["quantum", "computing", "entanglement", "qubit"]
        sim = _compute_similarity(
            kws,
            "Deep Learning for Image Classification",
            "We propose a convolutional neural network for classifying images.",
        )
        assert sim < 0.1

    def test_empty_keywords(self) -> None:
        sim = _compute_similarity([], "Some title", "Some abstract")
        assert sim == 0.0
# ---------------------------------------------------------------------------
# _build_novelty_queries
# ---------------------------------------------------------------------------
class TestBuildNoveltyQueries:
    """Unit tests for _build_novelty_queries: the topic is always query[0]; '## Hn:'
    hypothesis titles are appended (skipping very short ones) up to 5 queries total."""

    def test_includes_topic(self) -> None:
        queries = _build_novelty_queries("Reinforcement Learning", "No hypotheses")
        assert queries[0] == "Reinforcement Learning"

    def test_extracts_hypothesis_titles(self) -> None:
        hyp_text = (
            "## H1: Adaptive learning rates improve convergence\n"
            "Details about H1...\n\n"
            "## H2: Curriculum learning reduces sample complexity\n"
            "Details about H2...\n"
        )
        queries = _build_novelty_queries("RL topic", hyp_text)
        assert len(queries) >= 3  # topic + H1 + H2

    def test_caps_at_5(self) -> None:
        # Nine hypotheses provided, but the query list is capped.
        hyp_text = "\n".join(
            f"## H{i}: Hypothesis number {i} with enough text to pass length filter"
            for i in range(1, 10)
        )
        queries = _build_novelty_queries("Topic", hyp_text)
        assert len(queries) <= 5

    def test_skips_short_titles(self) -> None:
        hyp_text = "## H1: Short\n## H2: This is a longer hypothesis title\n"
        queries = _build_novelty_queries("Topic", hyp_text)
        # "Short" is < 10 chars → excluded
        assert not any("Short" in q for q in queries)

    def test_empty_hypotheses(self) -> None:
        # With no hypotheses the topic alone still yields one query.
        queries = _build_novelty_queries("Topic", "")
        assert len(queries) >= 1
        assert queries[0] == "Topic"
# ---------------------------------------------------------------------------
# _assess_novelty
# ---------------------------------------------------------------------------
class TestAssessNovelty:
    """Unit tests for _assess_novelty(papers, threshold) -> (score, assessment).

    Score is in [0, 1]; assessment is one of "high"/"moderate"/"low"/"critical".
    Higher similarity and higher citation counts push the score down.
    """

    def test_no_similar_papers_is_high(self) -> None:
        # Empty overlap list → maximum novelty.
        score, assessment = _assess_novelty([], 0.25)
        assert score == 1.0
        assert assessment == "high"

    def test_moderate_similarity(self) -> None:
        papers = [{"similarity": 0.35, "citation_count": 10}]
        score, assessment = _assess_novelty(papers, 0.25)
        assert 0.45 <= score <= 0.85
        assert assessment in ("high", "moderate")

    def test_high_similarity_low_novelty(self) -> None:
        papers = [{"similarity": 0.8, "citation_count": 200}]
        score, assessment = _assess_novelty(papers, 0.25)
        assert score <= 0.3
        assert assessment in ("low", "critical")

    def test_multiple_high_impact_overlaps_penalize(self) -> None:
        papers = [
            {"similarity": 0.5, "citation_count": 100},
            {"similarity": 0.45, "citation_count": 80},
            {"similarity": 0.42, "citation_count": 60},
        ]
        score, _ = _assess_novelty(papers, 0.25)
        # Should be penalized for multiple high-citation overlaps
        assert score < 0.6

    def test_score_bounded_0_to_1(self) -> None:
        # Even extreme inputs must not push the score out of range.
        papers = [{"similarity": 0.99, "citation_count": 5000}]
        score, _ = _assess_novelty(papers, 0.25)
        assert 0.0 <= score <= 1.0

    def test_critical_assessment(self) -> None:
        papers = [
            {"similarity": 0.9, "citation_count": 200},
            {"similarity": 0.85, "citation_count": 150},
        ]
        score, assessment = _assess_novelty(papers, 0.25)
        assert assessment == "critical"
        assert score < 0.25
# ---------------------------------------------------------------------------
# check_novelty (integration)
# ---------------------------------------------------------------------------
class TestCheckNovelty:
    """Integration tests for check_novelty — mocks the real API calls."""

    @staticmethod
    def _paper_mock(
        *,
        title: str,
        abstract: str,
        paper_id: str,
        citation_count: int = 10,
        year: int = 2024,
        venue: str = "Conf",
        url: str = "https://example.com",
        cite_key: str = "key",
    ) -> MagicMock:
        """Build a MagicMock carrying the paper attributes check_novelty reads.

        Centralises the attribute surface (title/abstract/paper_id/year/venue/
        citation_count/url/cite_key) that was previously duplicated across tests.
        """
        p = MagicMock()
        p.title = title
        p.abstract = abstract
        p.paper_id = paper_id
        p.year = year
        p.venue = venue
        p.citation_count = citation_count
        p.url = url
        p.cite_key = cite_key
        return p

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_basic_flow(self, mock_search: MagicMock) -> None:
        """Smoke test: no similar papers found → high novelty."""
        mock_search.return_value = []
        result = check_novelty(
            topic="Novel quantum-inspired optimization",
            hypotheses_text="## H1: Quantum tunneling improves escape from local minima\n",
        )
        assert isinstance(result, dict)
        assert result["novelty_score"] == 1.0
        assert result["assessment"] in ("high", "insufficient_data")
        assert result["recommendation"] in ("proceed", "proceed_with_caution")
        assert result["topic"] == "Novel quantum-inspired optimization"
        assert "generated" in result

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_with_similar_papers(self, mock_search: MagicMock) -> None:
        """Papers with keyword overlap → lower novelty."""
        # Create a mock paper with overlapping keywords.
        mock_paper = self._paper_mock(
            title="Quantum-Inspired Optimization for Combinatorial Problems",
            abstract=(
                "We propose quantum-inspired optimization methods "
                "using tunneling and superposition analogies to escape local minima."
            ),
            paper_id="abc123",
            citation_count=45,
            venue="NeurIPS",
            url="https://example.com/paper",
            cite_key="abc2024quantum",
        )
        mock_search.return_value = [mock_paper]
        result = check_novelty(
            topic="Quantum-inspired optimization",
            hypotheses_text="## H1: Quantum tunneling improves escape from local minima\n",
        )
        assert result["similar_papers_found"] >= 0
        assert 0.0 <= result["novelty_score"] <= 1.0

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_with_pipeline_papers(self, mock_search: MagicMock) -> None:
        """Papers from candidates.jsonl also checked for overlap."""
        mock_search.return_value = []
        pipeline_papers = [
            {
                "title": "Adaptive Learning Rate Schedules via Meta-Learning",
                "abstract": "We study adaptive learning rate schedules.",
                "paper_id": "p1",
                "year": 2023,
                "venue": "ICML",
                "citation_count": 30,
                "url": "https://example.com",
                "cite_key": "p12023",
            },
        ]
        result = check_novelty(
            topic="Adaptive learning rate schedules",
            hypotheses_text="## H1: Meta-learning adaptive learning rate schedules\n",
            papers_already_seen=pipeline_papers,
        )
        assert isinstance(result, dict)
        assert "similar_papers" in result

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_search_failure_graceful(self, mock_search: MagicMock) -> None:
        """API failure should not crash — falls back to pipeline papers."""
        mock_search.side_effect = RuntimeError("API down")
        result = check_novelty(
            topic="Some topic",
            hypotheses_text="## H1: Some hypothesis with enough text\n",
        )
        assert isinstance(result, dict)
        assert "novelty_score" in result

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_output_keys_complete(self, mock_search: MagicMock) -> None:
        """All expected keys present in output."""
        mock_search.return_value = []
        result = check_novelty(
            topic="Test topic",
            hypotheses_text="Some hypotheses text",
        )
        expected_keys = {
            "topic",
            "hypotheses_checked",
            "search_queries",
            "similar_papers_found",
            "novelty_score",
            "assessment",
            "similar_papers",
            "recommendation",
            "similarity_threshold",
            "search_coverage",
            "total_papers_retrieved",
            "generated",
        }
        assert expected_keys == set(result.keys())

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_recommendation_values(self, mock_search: MagicMock) -> None:
        """Recommendation must be proceed/proceed_with_caution/differentiate/abort."""
        mock_search.return_value = []
        result = check_novelty(
            topic="Test",
            hypotheses_text="## H1: Hypothesis one\n",
        )
        assert result["recommendation"] in ("proceed", "differentiate", "abort", "proceed_with_caution")

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_json_serializable(self, mock_search: MagicMock) -> None:
        """Output must be JSON-serializable for writing to novelty_report.json."""
        mock_search.return_value = []
        result = check_novelty(
            topic="JSON test",
            hypotheses_text="## H1: Test hypothesis title is long enough\n",
        )
        serialized = json.dumps(result)
        assert isinstance(serialized, str)

    @patch("researchclaw.literature.search.search_papers_multi_query")
    def test_similar_papers_capped_at_20(self, mock_search: MagicMock) -> None:
        """Output similar_papers list capped at 20."""
        # Create many mock papers, all matching the topic keywords.
        papers = [
            self._paper_mock(
                title=f"Paper about optimization variant {i}",
                abstract="Optimization variant study",
                paper_id=f"id_{i}",
                url=f"https://example.com/{i}",
                cite_key=f"key{i}",
            )
            for i in range(40)
        ]
        mock_search.return_value = papers
        result = check_novelty(
            topic="optimization",
            hypotheses_text="## H1: Optimization variants improve performance\n",
            similarity_threshold=0.0,  # low threshold → many matches
        )
        assert len(result["similar_papers"]) <= 20
# ---------------------------------------------------------------------------
# Executor integration — _execute_hypothesis_gen with novelty check
# ---------------------------------------------------------------------------
class TestHypothesisGenNoveltyIntegration:
    """Test that _execute_hypothesis_gen integrates novelty check correctly."""

    @staticmethod
    def _prepare_run(tmp_path: Path, synthesis_text: str) -> tuple[Path, Path]:
        """Create run/stage-08 plus a stage-07 synthesis artifact.

        Returns (stage_dir, run_dir). Factors out the directory scaffolding
        previously duplicated in both tests.
        """
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        stage_dir = run_dir / "stage-08"
        stage_dir.mkdir()
        # Create synthesis artifact from prior stage.
        stage_07 = run_dir / "stage-07"
        stage_07.mkdir()
        (stage_07 / "synthesis.md").write_text(synthesis_text)
        return stage_dir, run_dir

    @staticmethod
    def _make_config(tmp_path: Path):
        """Build a minimal docs-first RCConfig with an inline API key."""
        from researchclaw.config import RCConfig

        data = {
            "project": {"name": "novelty-test", "mode": "docs-first"},
            "research": {"topic": "novelty testing"},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local"},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost:1234/v1",
                "api_key_env": "RC_TEST_KEY",
                "api_key": "inline",
            },
        }
        return RCConfig.from_dict(data, project_root=tmp_path, check_paths=False)

    def test_novelty_report_written_when_available(self, tmp_path: Path) -> None:
        """Hypothesis gen should write novelty_report.json when check succeeds."""
        from researchclaw.pipeline.executor import _execute_hypothesis_gen
        from researchclaw.adapters import AdapterBundle

        stage_dir, run_dir = self._prepare_run(
            tmp_path, "## Synthesis\nSome synthesis content."
        )
        config = self._make_config(tmp_path)
        adapters = AdapterBundle()
        with patch(
            "researchclaw.literature.search.search_papers_multi_query"
        ) as mock_search:
            mock_search.return_value = []
            result = _execute_hypothesis_gen(stage_dir, run_dir, config, adapters)
        assert result.stage.name == "HYPOTHESIS_GEN"
        assert result.status.name == "DONE"
        # hypotheses.md always written
        assert (stage_dir / "hypotheses.md").exists()
        # novelty_report.json should be written (API mocked as returning empty)
        assert (stage_dir / "novelty_report.json").exists()
        report = json.loads((stage_dir / "novelty_report.json").read_text())
        assert report["novelty_score"] == 1.0  # no similar papers → max novelty
        assert "novelty_report.json" in result.artifacts

    def test_novelty_failure_does_not_block(self, tmp_path: Path) -> None:
        """If novelty check crashes, hypothesis gen still succeeds."""
        from researchclaw.pipeline.executor import _execute_hypothesis_gen
        from researchclaw.adapters import AdapterBundle

        stage_dir, run_dir = self._prepare_run(tmp_path, "## Synthesis\nContent.")
        config = self._make_config(tmp_path)
        adapters = AdapterBundle()
        with patch(
            "researchclaw.literature.novelty.check_novelty",
            side_effect=RuntimeError("Novelty check exploded"),
        ):
            result = _execute_hypothesis_gen(stage_dir, run_dir, config, adapters)
        assert result.status.name == "DONE"
        assert (stage_dir / "hypotheses.md").exists()
        # novelty_report.json NOT written since check failed
        assert not (stage_dir / "novelty_report.json").exists()
        assert "novelty_report.json" not in result.artifacts
================================================
FILE: tests/test_rc_preflight.py
================================================
from __future__ import annotations
import urllib.error
from email.message import Message
from unittest.mock import patch
from researchclaw.llm.client import LLMClient, LLMConfig, LLMResponse
def _make_client(
    *,
    base_url: str = "https://api.example.com/v1",
    api_key: str = "test-key",
    primary_model: str = "gpt-test",
    fallback_models: list[str] | None = None,
    max_retries: int = 1,
) -> LLMClient:
    """Build an LLMClient with test-friendly defaults; all knobs are keyword-only."""
    config = LLMConfig(
        base_url=base_url,
        api_key=api_key,
        primary_model=primary_model,
        # None means "no fallbacks" — LLMConfig gets an empty list.
        fallback_models=fallback_models or [],
        max_retries=max_retries,
    )
    return LLMClient(config)
class TestPreflight:
    """LLMClient.preflight() maps chat() transport errors to (ok, message) diagnostics."""

    def _failing_preflight(self, err: Exception) -> tuple[bool, str]:
        """Run preflight on a fresh client whose chat() raises `err`.

        Factors out the client+patch boilerplate shared by every failure test.
        """
        client = _make_client()
        with patch.object(client, "chat", side_effect=err):
            return client.preflight()

    def test_preflight_success(self):
        client = _make_client()
        mock_resp = LLMResponse(content="pong", model="gpt-test")
        with patch.object(client, "chat", return_value=mock_resp):
            ok, msg = client.preflight()
        assert ok is True
        assert "OK" in msg
        assert "gpt-test" in msg

    def test_preflight_401_invalid_key(self):
        err = urllib.error.HTTPError("url", 401, "Unauthorized", Message(), None)
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "Invalid API key" in msg

    def test_preflight_403_model_forbidden(self):
        err = urllib.error.HTTPError("url", 403, "Forbidden", Message(), None)
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "not allowed" in msg

    def test_preflight_404_bad_endpoint(self):
        err = urllib.error.HTTPError("url", 404, "Not Found", Message(), None)
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "Endpoint not found" in msg

    def test_preflight_429_rate_limited(self):
        err = urllib.error.HTTPError("url", 429, "Too Many Requests", Message(), None)
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "Rate limited" in msg

    def test_preflight_timeout(self):
        # URLError covers DNS/connect/timeout failures below the HTTP layer.
        err = urllib.error.URLError("timeout")
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "Connection failed" in msg

    def test_preflight_all_models_failed(self):
        err = RuntimeError("All models failed. Last error: ...")
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "All models failed" in msg

    def test_preflight_unknown_http_error(self):
        # Unmapped status codes fall through to a generic "HTTP <code>" message.
        err = urllib.error.HTTPError("url", 500, "Server Error", Message(), None)
        ok, msg = self._failing_preflight(err)
        assert ok is False
        assert "HTTP 500" in msg
================================================
FILE: tests/test_rc_prompts.py
================================================
"""Tests for researchclaw.prompts — PromptManager and template rendering."""
from __future__ import annotations
import textwrap
from pathlib import Path
import pytest
import yaml
from researchclaw.prompts import (
PromptManager,
RenderedPrompt,
_render,
)
# ---------------------------------------------------------------------------
# _render() — template variable substitution
# ---------------------------------------------------------------------------
class TestRender:
    """Unit tests for _render: substitutes {identifier} placeholders while leaving
    JSON-schema-like braces and code braces untouched."""

    def test_simple_substitution(self) -> None:
        assert _render("Hello {name}!", {"name": "World"}) == "Hello World!"

    def test_multiple_variables(self) -> None:
        result = _render(
            "Topic: {topic}, Domain: {domain}", {"topic": "RL", "domain": "ML"}
        )
        assert result == "Topic: RL, Domain: ML"

    def test_missing_variable_left_untouched(self) -> None:
        # Unknown placeholders survive verbatim rather than raising KeyError.
        assert _render("Value: {unknown}", {}) == "Value: {unknown}"

    def test_json_schema_not_substituted(self) -> None:
        # {name: value} schema notation must not be treated as a placeholder.
        template = "Return JSON: {candidates:[...]} with >=8 rows."
        assert _render(template, {"candidates": "SHOULD_NOT_APPEAR"}) == template

    def test_json_schema_complex_not_substituted(self) -> None:
        template = "Schema: {score_1_to_10:number, verdict:string}"
        assert _render(template, {}) == template

    def test_curly_braces_in_code_not_substituted(self) -> None:
        template = "def foo(): { return 1; }"
        assert _render(template, {}) == template

    def test_underscore_variable(self) -> None:
        assert _render("{my_var}", {"my_var": "ok"}) == "ok"

    def test_numeric_suffix(self) -> None:
        assert _render("{score_1}", {"score_1": "9"}) == "9"

    def test_empty_template(self) -> None:
        assert _render("", {"x": "y"}) == ""

    def test_no_placeholders(self) -> None:
        assert _render("No variables here", {"x": "y"}) == "No variables here"
# ---------------------------------------------------------------------------
# PromptManager — defaults
# ---------------------------------------------------------------------------
class TestPromptManagerDefaults:
    """Tests for PromptManager's built-in defaults: stage coverage, system/user
    prompts, json_mode flags, max_tokens limits, blocks, and sub-prompts."""

    def test_all_stages_present(self) -> None:
        """20 stages have for_stage() prompts; iterative_refine uses sub_prompts only."""
        pm = PromptManager()
        names = pm.stage_names()
        assert len(names) >= 20
        for required in [
            "topic_init",
            "problem_decompose",
            "search_strategy",
            "literature_collect",
            "literature_screen",
            "knowledge_extract",
            "synthesis",
            "hypothesis_gen",
            "experiment_design",
            "code_generation",
            "resource_planning",
            "result_analysis",
            "research_decision",
            "paper_outline",
            "paper_draft",
            "peer_review",
            "paper_revision",
            "quality_gate",
            "knowledge_archive",
            "export_publish",
        ]:
            assert pm.has_stage(required), f"Missing stage: {required}"

    def test_system_prompt_nonempty(self) -> None:
        """Every registered stage must ship a non-empty system prompt."""
        pm = PromptManager()
        for name in pm.stage_names():
            assert pm.system(name), f"Empty system prompt for {name}"

    def test_for_stage_returns_rendered_prompt(self) -> None:
        """for_stage() renders variables into the user prompt and returns RenderedPrompt."""
        pm = PromptManager()
        sp = pm.for_stage(
            "topic_init",
            topic="RL",
            domains="ml",
            project_name="test",
            quality_threshold="4.0",
        )
        assert isinstance(sp, RenderedPrompt)
        assert "RL" in sp.user
        assert "ml" in sp.user
        assert sp.system

    def test_json_mode_stages(self) -> None:
        """Stages that emit structured output default to json_mode=True."""
        pm = PromptManager()
        json_stages = [
            "search_strategy",
            "literature_collect",
            "literature_screen",
            "knowledge_extract",
            "resource_planning",
            "quality_gate",
        ]
        for stage in json_stages:
            assert pm.json_mode(stage), f"{stage} should have json_mode=True"

    def test_non_json_stages(self) -> None:
        pm = PromptManager()
        assert not pm.json_mode("topic_init")
        assert not pm.json_mode("synthesis")

    def test_max_tokens(self) -> None:
        """Long-output stages carry explicit limits; others default to None."""
        pm = PromptManager()
        assert pm.max_tokens("code_generation") == 8192
        assert pm.max_tokens("paper_draft") == 16384
        assert pm.max_tokens("topic_init") is None

    def test_block_topic_constraint(self) -> None:
        pm = PromptManager()
        block = pm.block("topic_constraint", topic="Neural Architecture Search")
        assert "Neural Architecture Search" in block
        assert "HARD TOPIC CONSTRAINT" in block

    def test_block_pkg_hint(self) -> None:
        pm = PromptManager()
        block = pm.block("pkg_hint_sandbox")
        assert "numpy" in block
        assert "torch" in block  # mentioned as prohibited

    def test_sub_prompt_code_repair(self) -> None:
        """sub_prompt('code_repair') interpolates filename and issue text."""
        pm = PromptManager()
        rp = pm.sub_prompt(
            "code_repair",
            fname="model.py",
            issues_text="SyntaxError",
            all_files_ctx="...",
        )
        assert "model.py" in rp.user
        assert "SyntaxError" in rp.user
        assert rp.system

    def test_sub_prompt_iterative_improve(self) -> None:
        pm = PromptManager()
        ip = pm.sub_prompt(
            "iterative_improve",
            metric_key="val_loss",
            metric_direction="minimize",
            files_context="...",
            run_summaries="...",
        )
        assert "val_loss" in ip.user
        assert "minimize" in ip.user

    def test_sub_prompt_iterative_repair(self) -> None:
        pm = PromptManager()
        irp = pm.sub_prompt(
            "iterative_repair", issue_text="import error", all_files_ctx="..."
        )
        assert "import error" in irp.user
# ---------------------------------------------------------------------------
# PromptManager — YAML override
# ---------------------------------------------------------------------------
class TestPromptManagerOverrides:
    """Tests for loading a user YAML override file: per-field overrides merge onto
    defaults, and malformed/missing/unknown entries degrade gracefully."""

    def test_override_system_prompt(self, tmp_path: Path) -> None:
        yaml_content = textwrap.dedent("""\
            stages:
              topic_init:
                system: "You are a custom planner."
        """)
        override_file = tmp_path / "custom.yaml"
        override_file.write_text(yaml_content, encoding="utf-8")
        pm = PromptManager(override_file)
        assert pm.system("topic_init") == "You are a custom planner."
        # Other stages should keep defaults
        assert pm.system("problem_decompose") == "You are a senior research strategist."

    def test_override_user_template(self, tmp_path: Path) -> None:
        yaml_content = textwrap.dedent("""\
            stages:
              topic_init:
                user: "Custom prompt for {topic}."
        """)
        override_file = tmp_path / "custom.yaml"
        override_file.write_text(yaml_content, encoding="utf-8")
        pm = PromptManager(override_file)
        result = pm.user("topic_init", topic="GAN")
        assert result == "Custom prompt for GAN."

    def test_override_block(self, tmp_path: Path) -> None:
        yaml_content = textwrap.dedent("""\
            blocks:
              topic_constraint: "Stay focused on {topic}."
        """)
        override_file = tmp_path / "custom.yaml"
        override_file.write_text(yaml_content, encoding="utf-8")
        pm = PromptManager(override_file)
        assert pm.block("topic_constraint", topic="NAS") == "Stay focused on NAS."

    def test_override_json_mode(self, tmp_path: Path) -> None:
        yaml_content = textwrap.dedent("""\
            stages:
              topic_init:
                json_mode: true
        """)
        override_file = tmp_path / "custom.yaml"
        override_file.write_text(yaml_content, encoding="utf-8")
        pm = PromptManager(override_file)
        assert pm.json_mode("topic_init") is True

    def test_missing_file_uses_defaults(self, tmp_path: Path) -> None:
        # A nonexistent override path must not raise — defaults stay in effect.
        pm = PromptManager(tmp_path / "nonexistent.yaml")
        assert pm.has_stage("topic_init")
        assert pm.system("topic_init")

    def test_invalid_yaml_uses_defaults(self, tmp_path: Path) -> None:
        # Unparseable YAML is ignored rather than crashing construction.
        bad_file = tmp_path / "bad.yaml"
        bad_file.write_text(": invalid: yaml: [", encoding="utf-8")
        pm = PromptManager(bad_file)
        assert pm.has_stage("topic_init")

    def test_unknown_stage_in_override_ignored(self, tmp_path: Path) -> None:
        yaml_content = textwrap.dedent("""\
            stages:
              nonexistent_stage:
                system: "Should be ignored."
        """)
        override_file = tmp_path / "custom.yaml"
        override_file.write_text(yaml_content, encoding="utf-8")
        # Should not raise
        pm = PromptManager(override_file)
        assert not pm.has_stage("nonexistent_stage")
# ---------------------------------------------------------------------------
# PromptManager — export_yaml
# ---------------------------------------------------------------------------
class TestExportYaml:
    """Tests for PromptManager.export_yaml: the export parses back cleanly,
    covers every stage, and reflects any active overrides."""

    def test_export_roundtrip(self, tmp_path: Path) -> None:
        pm1 = PromptManager()
        export_path = tmp_path / "exported.yaml"
        pm1.export_yaml(export_path)
        assert export_path.exists()
        # Load it back — should parse cleanly
        data = yaml.safe_load(export_path.read_text(encoding="utf-8"))
        assert "stages" in data
        assert "blocks" in data
        assert "version" in data

    def test_export_contains_all_stages(self, tmp_path: Path) -> None:
        pm = PromptManager()
        export_path = tmp_path / "exported.yaml"
        pm.export_yaml(export_path)
        data = yaml.safe_load(export_path.read_text(encoding="utf-8"))
        for stage in pm.stage_names():
            assert stage in data["stages"], f"Missing {stage} in export"

    def test_export_with_overrides(self, tmp_path: Path) -> None:
        # Overridden values must be exported, not the built-in defaults.
        override_file = tmp_path / "custom.yaml"
        override_file.write_text(
            "stages:\n  topic_init:\n    system: CUSTOM\n",
            encoding="utf-8",
        )
        pm = PromptManager(override_file)
        export_path = tmp_path / "exported.yaml"
        pm.export_yaml(export_path)
        data = yaml.safe_load(export_path.read_text(encoding="utf-8"))
        assert data["stages"]["topic_init"]["system"] == "CUSTOM"
# ---------------------------------------------------------------------------
# RenderedPrompt dataclass
# ---------------------------------------------------------------------------
class TestRenderedPrompt:
    """Tests for the RenderedPrompt dataclass: defaults, options, immutability."""

    def test_defaults(self) -> None:
        rp = RenderedPrompt(system="sys", user="usr")
        assert rp.json_mode is False
        assert rp.max_tokens is None

    def test_with_options(self) -> None:
        rp = RenderedPrompt(system="s", user="u", json_mode=True, max_tokens=4096)
        assert rp.json_mode is True
        assert rp.max_tokens == 4096

    def test_frozen(self) -> None:
        # Frozen dataclass: attribute assignment must raise.
        rp = RenderedPrompt(system="s", user="u")
        with pytest.raises(AttributeError):
            rp.system = "modified"  # type: ignore[misc]
================================================
FILE: tests/test_rc_quality.py
================================================
"""Tests for content quality assessment."""
from __future__ import annotations
# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false
from researchclaw.quality import (
assess_quality,
check_strict_quality,
compute_template_ratio,
detect_template_content,
)
# --- Fixture texts for the quality-assessment tests below ---

# Genuine scientific abstract: expected to trigger zero template-pattern matches.
REAL_ABSTRACT = (
    "We propose a novel method for protein structure prediction using "
    "graph neural networks. Our approach achieves state-of-the-art results "
    "on the CASP14 benchmark with 3.2 GDT-TS improvement over AlphaFold2. "
    "We demonstrate that incorporating side-chain interactions as graph "
    "edges significantly improves local structure accuracy."
)

# Pure placeholder text: "Template ...:", "[INSERT ...]", "will describe",
# and "Replace this text" should all be detected.
TEMPLATE_ABSTRACT = (
    "Template abstract: This section will describe the main contributions "
    "of our work. [INSERT your abstract here]. We will discuss the results "
    "in the following sections. Replace this text with your actual content."
)

# Half real content, half placeholders — used for the mid-range ratio tests.
MIXED_CONTENT = (
    "We propose a novel method for protein structure prediction.\n"
    "[TODO: Add more details about the method]\n"
    "Our experiments show significant improvements over baselines.\n"
    "Template introduction: This section will describe the background."
)

# Realistic multi-paragraph paper section: must stay clean despite markdown
# headers and a numbered contribution list.
REAL_PAPER_SECTION = (
    "## Introduction\n\n"
    "Recent advances in large language models have demonstrated remarkable "
    "capabilities in natural language understanding and generation. However, "
    "these models often struggle with factual consistency and hallucinate "
    "information. In this work, we address this limitation by introducing "
    "a retrieval-augmented generation framework that grounds model outputs "
    "in verified knowledge sources.\n\n"
    "Our key contributions are:\n"
    "1. A novel attention mechanism for integrating retrieved passages\n"
    "2. A training procedure that incentivizes factual consistency\n"
    "3. Comprehensive evaluation on three benchmark datasets"
)
class TestDetectTemplateContent:
    """Tests for detect_template_content: each match records a pattern_desc and
    a 1-based line_number; real prose yields no matches."""

    def test_real_text_no_matches(self):
        matches = detect_template_content(REAL_ABSTRACT)
        assert len(matches) == 0

    def test_template_text_has_matches(self):
        matches = detect_template_content(TEMPLATE_ABSTRACT)
        assert len(matches) >= 3

    def test_detects_insert_placeholder(self):
        text = "The results show [INSERT your results here] improvement."
        matches = detect_template_content(text)
        assert any("Insert placeholder" in m.pattern_desc for m in matches)

    def test_detects_todo_placeholder(self):
        text = "Method description [TODO: complete this section]."
        matches = detect_template_content(text)
        assert any("TODO" in m.pattern_desc for m in matches)

    def test_detects_template_section(self):
        text = "Template introduction: This paper presents our work."
        matches = detect_template_content(text)
        assert any("Template section" in m.pattern_desc for m in matches)

    def test_detects_future_tense_placeholder(self):
        text = "This section will describe the methodology in detail."
        matches = detect_template_content(text)
        assert any("Future-tense" in m.pattern_desc for m in matches)

    def test_detects_lorem_ipsum(self):
        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
        matches = detect_template_content(text)
        assert any("Lorem ipsum" in m.pattern_desc for m in matches)

    def test_match_has_line_number(self):
        # line_number is 1-based; the TODO is on the second line.
        text = "Good line\n[TODO: fix this]\nAnother good line"
        matches = detect_template_content(text)
        assert len(matches) == 1
        assert matches[0].line_number == 2

    def test_real_paper_section_clean(self):
        matches = detect_template_content(REAL_PAPER_SECTION)
        assert len(matches) == 0

    def test_empty_text(self):
        matches = detect_template_content("")
        assert len(matches) == 0
class TestComputeTemplateRatio:
    """Tests for compute_template_ratio: fraction of template-flagged content,
    bounded to [0, 1], with 0.0 for empty text."""

    def test_real_text_low_ratio(self):
        ratio = compute_template_ratio(REAL_ABSTRACT)
        assert ratio < 0.05

    def test_template_text_high_ratio(self):
        ratio = compute_template_ratio(TEMPLATE_ABSTRACT)
        assert ratio > 0.5

    def test_mixed_content_moderate_ratio(self):
        # MIXED_CONTENT is half real, half placeholder → mid-range ratio.
        ratio = compute_template_ratio(MIXED_CONTENT)
        assert 0.1 < ratio < 0.9

    def test_empty_text_zero_ratio(self):
        ratio = compute_template_ratio("")
        assert ratio == 0.0

    def test_ratio_bounded_0_1(self):
        ratio = compute_template_ratio(TEMPLATE_ABSTRACT)
        assert 0.0 <= ratio <= 1.0

    def test_real_paper_section_low_ratio(self):
        ratio = compute_template_ratio(REAL_PAPER_SECTION)
        assert ratio < 0.05
class TestAssessQuality:
    """Tests for assess_quality() report structure and flags."""

    def test_report_has_all_fields(self):
        report = assess_quality(REAL_ABSTRACT)
        assert report.total_lines > 0
        assert report.total_chars > 0
        assert isinstance(report.template_ratio, float)
        assert isinstance(report.template_matches, tuple)

    def test_report_to_dict(self):
        payload = assess_quality(MIXED_CONTENT).to_dict()
        expected_keys = (
            "template_ratio",
            "template_matches",
            "has_template_content",
            "match_count",
        )
        for key in expected_keys:
            assert key in payload

    def test_report_has_template_flag(self):
        assert assess_quality(TEMPLATE_ABSTRACT).has_template_content is True
        assert assess_quality(REAL_ABSTRACT).has_template_content is False
class TestCheckStrictQuality:
    """Tests for check_strict_quality() pass/fail semantics."""

    def test_real_text_passes(self):
        ok, _ = check_strict_quality(REAL_ABSTRACT)
        assert ok is True

    def test_template_text_fails(self):
        ok, message = check_strict_quality(TEMPLATE_ABSTRACT)
        assert ok is False
        assert "Template content detected" in message

    def test_custom_threshold(self):
        # A threshold of 1.0 tolerates any amount of template content.
        ok, _ = check_strict_quality(TEMPLATE_ABSTRACT, threshold=1.0)
        assert ok is True

    def test_failure_message_includes_examples(self):
        _, message = check_strict_quality(TEMPLATE_ABSTRACT)
        assert "L" in message
================================================
FILE: tests/test_rc_report.py
================================================
# pyright: basic, reportMissingImports=false, reportUnusedCallResult=false
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.report import generate_report
class TestReport:
    """End-to-end tests for generate_report() against synthetic run dirs.

    The pipeline_summary.json write and the one-stage "all done" summary
    were duplicated verbatim across tests; both are factored into helpers.
    """

    @staticmethod
    def _write_summary(run_dir: Path, payload: dict) -> None:
        """Serialize *payload* as run_dir/pipeline_summary.json."""
        (run_dir / "pipeline_summary.json").write_text(json.dumps(payload))

    @staticmethod
    def _trivial_done_summary() -> dict:
        """Minimal single-stage successful summary shared by several tests."""
        return {
            "run_id": "test",
            "stages_executed": 1,
            "stages_done": 1,
            "stages_failed": 0,
            "final_status": "done",
            "generated": "now",
        }

    def test_report_missing_run_dir(self, tmp_path: Path):
        # A run directory that does not exist must raise immediately.
        with pytest.raises(FileNotFoundError):
            generate_report(tmp_path / "nonexistent")

    def test_report_no_summary(self, tmp_path: Path):
        # An existing directory without pipeline_summary.json is rejected.
        with pytest.raises(ValueError, match="pipeline_summary"):
            generate_report(tmp_path)

    def test_report_minimal_run(self, tmp_path: Path):
        self._write_summary(
            tmp_path,
            {
                "run_id": "rc-test-123",
                "stages_executed": 23,
                "stages_done": 23,
                "stages_blocked": 0,
                "stages_failed": 0,
                "final_status": "done",
                "generated": "2026-03-10T12:00:00Z",
            },
        )
        report = generate_report(tmp_path)
        assert "# ResearchClaw Run Report" in report
        assert "rc-test-123" in report
        assert "✅" in report

    def test_report_with_paper(self, tmp_path: Path):
        self._write_summary(tmp_path, self._trivial_done_summary())
        draft_dir = tmp_path / "stage-17"
        draft_dir.mkdir()
        (draft_dir / "paper_draft.md").write_text(
            "This is a paper with some words in it."
        )
        report = generate_report(tmp_path)
        assert "Paper" in report
        assert "words" in report

    def test_report_with_citations(self, tmp_path: Path):
        self._write_summary(tmp_path, self._trivial_done_summary())
        verify_dir = tmp_path / "stage-23"
        verify_dir.mkdir()
        (verify_dir / "verification_report.json").write_text(
            json.dumps(
                {
                    "total_references": 10,
                    "verified_count": 8,
                    "suspicious_count": 1,
                    "hallucinated_count": 1,
                }
            )
        )
        report = generate_report(tmp_path)
        assert "Citations" in report
        assert "8/10" in report
        assert "Suspicious: 1" in report

    def test_report_with_failures(self, tmp_path: Path):
        self._write_summary(
            tmp_path,
            {
                "run_id": "test",
                "stages_executed": 5,
                "stages_done": 3,
                "stages_failed": 2,
                "final_status": "failed",
                "generated": "now",
            },
        )
        report = generate_report(tmp_path)
        assert "❌" in report
        assert "Warnings" in report
        assert "2 stage(s) failed" in report

    def test_report_with_experiment_results(self, tmp_path: Path):
        self._write_summary(tmp_path, self._trivial_done_summary())
        exp_dir = tmp_path / "stage-12"
        exp_dir.mkdir()
        (exp_dir / "experiment_results.json").write_text(
            json.dumps(
                {
                    "iterations": [{"loss": 0.5}, {"loss": 0.3}],
                    "best_metric": 0.3,
                }
            )
        )
        report = generate_report(tmp_path)
        assert "Experiments" in report
        assert "2 iterations" in report
================================================
FILE: tests/test_rc_runner.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnusedCallResult=false, reportAttributeAccessIssue=false, reportUnknownLambdaType=false
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, cast
import pytest
from researchclaw.adapters import AdapterBundle
from researchclaw.config import RCConfig
from researchclaw.pipeline import runner as rc_runner
from researchclaw.pipeline.executor import StageResult
from researchclaw.pipeline.stages import STAGE_SEQUENCE, Stage, StageStatus
@pytest.fixture()
def rc_config(tmp_path: Path) -> RCConfig:
    """Minimal runner-test config; path existence checks are disabled."""
    llm_settings = {
        "provider": "openai-compatible",
        "base_url": "http://localhost:1234/v1",
        "api_key_env": "RC_TEST_KEY",
        "api_key": "inline",
    }
    raw = {
        "project": {"name": "rc-runner-test", "mode": "docs-first"},
        "research": {"topic": "pipeline testing"},
        "runtime": {"timezone": "UTC"},
        "notifications": {"channel": "local"},
        "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
        "openclaw_bridge": {},
        "llm": llm_settings,
    }
    return RCConfig.from_dict(raw, project_root=tmp_path, check_paths=False)
@pytest.fixture()
def adapters() -> AdapterBundle:
    """Default adapter bundle; tests monkeypatch the stage executor anyway."""
    bundle = AdapterBundle()
    return bundle
@pytest.fixture()
def run_dir(tmp_path: Path) -> Path:
    """Fresh, empty run directory under pytest's tmp_path."""
    target = tmp_path / "run"
    target.mkdir()
    return target
def _done(stage: Stage, artifacts: tuple[str, ...] = ("out.md",)) -> StageResult:
    """Build a successful (DONE) StageResult for *stage*."""
    return StageResult(
        stage=stage,
        status=StageStatus.DONE,
        artifacts=artifacts,
    )
def _failed(stage: Stage, msg: str = "boom") -> StageResult:
    """Build a FAILED StageResult for *stage* carrying error message *msg*."""
    return StageResult(
        stage=stage,
        status=StageStatus.FAILED,
        artifacts=(),
        error=msg,
    )
def _blocked(stage: Stage) -> StageResult:
    """Build a gate StageResult blocked pending human approval."""
    return StageResult(
        stage=stage,
        status=StageStatus.BLOCKED_APPROVAL,
        decision="block",
        artifacts=("gate.md",),
    )
def test_execute_pipeline_runs_stages_in_sequence(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """All 23 stages execute in STAGE_SEQUENCE order when each one succeeds."""
    executed: list[Stage] = []

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        executed.append(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-seq",
        config=rc_config,
        adapters=adapters,
    )
    assert executed == list(STAGE_SEQUENCE)
    assert len(outcome) == 23
    assert all(item.status == StageStatus.DONE for item in outcome)
def test_execute_pipeline_stops_on_failed_stage(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """A FAILED stage result halts the pipeline at that stage."""
    fail_stage = Stage.SEARCH_STRATEGY

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        return _failed(stage, "forced failure") if stage == fail_stage else _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-fail",
        config=rc_config,
        adapters=adapters,
    )
    last = outcome[-1]
    assert last.stage == fail_stage
    assert last.status == StageStatus.FAILED
    assert len(outcome) == int(fail_stage)
def test_execute_pipeline_stops_on_gate_when_stop_on_gate_enabled(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """With stop_on_gate=True the run halts at the first blocked gate."""
    gate_stage = Stage.LITERATURE_SCREEN

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        return _blocked(stage) if stage == gate_stage else _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-gate-stop",
        config=rc_config,
        adapters=adapters,
        stop_on_gate=True,
    )
    last = outcome[-1]
    assert last.stage == gate_stage
    assert last.status == StageStatus.BLOCKED_APPROVAL
    assert len(outcome) == int(gate_stage)
def test_execute_pipeline_continues_after_gate_when_stop_on_gate_disabled(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """With stop_on_gate=False a blocked gate is recorded but the run continues."""
    gate_stage = Stage.LITERATURE_SCREEN

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        return _blocked(stage) if stage == gate_stage else _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-gate-continue",
        config=rc_config,
        adapters=adapters,
        stop_on_gate=False,
    )
    assert len(outcome) == 23
    statuses = [item.status for item in outcome]
    assert StageStatus.BLOCKED_APPROVAL in statuses
def test_execute_pipeline_writes_pipeline_summary_json(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """A pipeline_summary.json file is persisted inside the run directory."""
    monkeypatch.setattr(
        rc_runner, "execute_stage", lambda stage, **kwargs: _done(stage)
    )
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-summary",
        config=rc_config,
        adapters=adapters,
    )
    assert (run_dir / "pipeline_summary.json").exists()
def test_pipeline_summary_has_expected_fields_and_values(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """Summary fields reflect a mixture of done/blocked/failed results."""

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        if stage == Stage.LITERATURE_SCREEN:
            return _blocked(stage)
        if stage == Stage.HYPOTHESIS_GEN:
            return _failed(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-summary-fields",
        config=rc_config,
        adapters=adapters,
    )
    raw = (run_dir / "pipeline_summary.json").read_text(encoding="utf-8")
    summary = cast(dict[str, Any], json.loads(raw))
    done_total = sum(1 for item in outcome if item.status == StageStatus.DONE)
    assert summary["run_id"] == "run-summary-fields"
    assert summary["stages_executed"] == len(outcome)
    assert summary["stages_done"] == done_total
    assert summary["stages_blocked"] == 1
    assert summary["stages_failed"] == 1
    assert summary["from_stage"] == 1
    assert summary["final_stage"] == int(Stage.HYPOTHESIS_GEN)
    assert summary["final_status"] == "failed"
    assert "generated" in summary
def test_execute_pipeline_from_stage_skips_earlier_stages(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """from_stage=PAPER_OUTLINE starts there, skipping all prior stages."""
    executed: list[Stage] = []

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        executed.append(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-from-stage",
        config=rc_config,
        adapters=adapters,
        from_stage=Stage.PAPER_OUTLINE,
    )
    skipped = int(Stage.PAPER_OUTLINE) - 1
    assert executed[0] == Stage.PAPER_OUTLINE
    assert len(executed) == len(STAGE_SEQUENCE) - skipped
    assert len(outcome) == len(executed)
def test_execute_pipeline_writes_kb_entries_when_kb_root_provided(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
    tmp_path: Path,
) -> None:
    """Every completed stage triggers one knowledge-base write when kb_root is set."""
    kb_calls: list[tuple[int, str, str]] = []

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        stage_dir = run_dir / f"stage-{int(stage):02d}"
        stage_dir.mkdir(parents=True, exist_ok=True)
        (stage_dir / "out.md").write_text(f"stage {int(stage)}", encoding="utf-8")
        return _done(stage)

    def fake_kb_write(
        kb_root: Path,
        stage_id: int,
        stage_name: str,
        run_id: str,
        artifacts: list[str],
        stage_dir: Path,
        **kwargs,
    ):
        _ = kb_root, artifacts, stage_dir, kwargs
        kb_calls.append((stage_id, stage_name, run_id))
        return []

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    monkeypatch.setattr(rc_runner, "write_stage_to_kb", fake_kb_write)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-kb",
        config=rc_config,
        adapters=adapters,
        kb_root=tmp_path / "kb-out",
    )
    assert len(outcome) == 23
    assert len(kb_calls) == 23
    assert kb_calls[0] == (1, "topic_init", "run-kb")
def test_execute_pipeline_passes_auto_approve_flag_to_execute_stage(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """auto_approve_gates=True is forwarded to every execute_stage call."""
    flags: list[bool] = []

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        flags.append(kwargs["auto_approve_gates"])
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-auto-approve",
        config=rc_config,
        adapters=adapters,
        auto_approve_gates=True,
    )
    assert flags
    assert all(flags)
@pytest.mark.parametrize(
    ("candidate", "already_started", "should"),
    [
        (Stage.TOPIC_INIT, False, True),
        (Stage.PROBLEM_DECOMPOSE, False, False),
        (Stage.PAPER_DRAFT, True, True),
    ],
)
def test_should_start_logic(candidate: Stage, already_started: bool, should: bool) -> None:
    """_should_start matches from_stage unless execution has already begun."""
    assert rc_runner._should_start(candidate, Stage.TOPIC_INIT, already_started) is should
@pytest.mark.parametrize(
    ("stage_results", "want_status", "want_final_stage"),
    [
        ([], "no_stages", int(Stage.TOPIC_INIT)),
        ([_done(Stage.TOPIC_INIT)], "done", int(Stage.TOPIC_INIT)),
        (
            [_done(Stage.TOPIC_INIT), _failed(Stage.PROBLEM_DECOMPOSE)],
            "failed",
            int(Stage.PROBLEM_DECOMPOSE),
        ),
    ],
)
def test_build_pipeline_summary_core_fields(
    stage_results, want_status: str, want_final_stage: int
) -> None:
    """Core summary fields track the final result (or defaults when empty)."""
    summary = rc_runner._build_pipeline_summary(
        run_id="run-core",
        results=stage_results,
        from_stage=Stage.TOPIC_INIT,
    )
    assert summary["run_id"] == "run-core"
    assert summary["final_status"] == want_status
    assert summary["final_stage"] == want_final_stage
def test_pipeline_prints_stage_progress(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Per-stage progress (running/done/FAILED + error text) is printed to stdout."""
    scripted = [
        StageResult(
            stage=Stage.TOPIC_INIT, status=StageStatus.DONE, artifacts=("topic.json",)
        ),
        StageResult(
            stage=Stage.PROBLEM_DECOMPOSE,
            status=StageStatus.DONE,
            artifacts=("tree.json",),
        ),
        StageResult(
            stage=Stage.SEARCH_STRATEGY,
            status=StageStatus.FAILED,
            artifacts=(),
            error="LLM timeout",
        ),
    ]

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = stage, kwargs
        # Hand out scripted results in order, repeating the last one.
        if len(scripted) > 1:
            return scripted.pop(0)
        return scripted[0]

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    monkeypatch.setattr(rc_runner, "write_stage_to_kb", lambda *args, **kwargs: [])
    _ = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="rc-test-001",
        config=rc_config,
        adapters=adapters,
    )
    out = capsys.readouterr().out
    assert "TOPIC_INIT — running..." in out
    assert "TOPIC_INIT — done" in out
    assert "SEARCH_STRATEGY — FAILED" in out
    assert "LLM timeout" in out
def test_pipeline_prints_elapsed_time(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Each stage progress line carries an elapsed-time suffix like '(0.01s)'."""
    import re

    scripted = iter(
        [
            StageResult(
                stage=Stage.TOPIC_INIT,
                status=StageStatus.DONE,
                artifacts=("topic.json",),
            ),
            StageResult(
                stage=Stage.PROBLEM_DECOMPOSE,
                status=StageStatus.FAILED,
                artifacts=(),
                error="test",
            ),
        ]
    )
    monkeypatch.setattr(
        rc_runner, "execute_stage", lambda *args, **kwargs: next(scripted)
    )
    monkeypatch.setattr(rc_runner, "write_stage_to_kb", lambda *args, **kwargs: [])
    _ = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="rc-test-002",
        config=rc_config,
        adapters=adapters,
    )
    out = capsys.readouterr().out
    assert re.search(r"\d+\.\d+s\)", out), f"No elapsed time found in: {out}"
# ── PIVOT/PROCEED/REFINE decision loop tests ──
def _pivot_result(stage: Stage) -> StageResult:
    """DONE decision-stage result carrying a 'pivot' decision."""
    return StageResult(
        stage=stage,
        status=StageStatus.DONE,
        artifacts=("decision.md",),
        decision="pivot",
    )
def _refine_result(stage: Stage) -> StageResult:
    """DONE decision-stage result carrying a 'refine' decision."""
    return StageResult(
        stage=stage,
        status=StageStatus.DONE,
        artifacts=("decision.md",),
        decision="refine",
    )
def test_pivot_decision_triggers_rollback_to_hypothesis_gen(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """A single 'pivot' decision rolls back to HYPOTHESIS_GEN and is
    recorded in decision_history.json."""
    seen: list[Stage] = []
    pivot_count = 0

    def mock_execute_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        seen.append(stage)
        nonlocal pivot_count
        # Pivot exactly once; subsequent decision stages proceed normally.
        if stage == Stage.RESEARCH_DECISION and pivot_count == 0:
            pivot_count += 1
            return _pivot_result(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", mock_execute_stage)
    # Return value intentionally discarded: only side effects are asserted.
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-pivot",
        config=rc_config,
        adapters=adapters,
    )
    # Should have seen HYPOTHESIS_GEN at least twice (original + rollback)
    hyp_gen_count = sum(1 for s in seen if s == Stage.HYPOTHESIS_GEN)
    assert hyp_gen_count >= 2
    # Decision history should be recorded
    history_path = run_dir / "decision_history.json"
    assert history_path.exists()
    history = json.loads(history_path.read_text())
    assert len(history) == 1
    assert history[0]["decision"] == "pivot"
def test_refine_decision_triggers_rollback_to_iterative_refine(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """A single 'refine' decision rolls back to ITERATIVE_REFINE once."""
    seen: list[Stage] = []
    refine_count = 0

    def mock_execute_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        seen.append(stage)
        nonlocal refine_count
        # Refine exactly once; subsequent decision stages proceed normally.
        if stage == Stage.RESEARCH_DECISION and refine_count == 0:
            refine_count += 1
            return _refine_result(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", mock_execute_stage)
    # Return value intentionally discarded: only the execution trace matters.
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-refine",
        config=rc_config,
        adapters=adapters,
    )
    # Should have seen ITERATIVE_REFINE at least twice
    refine_stage_count = sum(1 for s in seen if s == Stage.ITERATIVE_REFINE)
    assert refine_stage_count >= 2
def test_max_pivot_count_prevents_infinite_loop(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """Perpetual 'pivot' decisions are capped at MAX_DECISION_PIVOTS rollbacks."""
    from researchclaw.pipeline.stages import MAX_DECISION_PIVOTS

    seen: list[Stage] = []

    def mock_execute_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        seen.append(stage)
        # Always PIVOT — should be limited by MAX_DECISION_PIVOTS
        if stage == Stage.RESEARCH_DECISION:
            return _pivot_result(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", mock_execute_stage)
    # Return value intentionally discarded: only the decision count matters.
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-max-pivot",
        config=rc_config,
        adapters=adapters,
    )
    # RESEARCH_DECISION should appear at most MAX_DECISION_PIVOTS + 1 times
    decision_count = sum(1 for s in seen if s == Stage.RESEARCH_DECISION)
    assert decision_count <= MAX_DECISION_PIVOTS + 1
def test_proceed_decision_does_not_trigger_rollback(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
) -> None:
    """An all-'proceed' run executes each stage exactly once and writes no
    decision history file."""
    seen: list[Stage] = []

    def mock_execute_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        seen.append(stage)
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", mock_execute_stage)
    # Return value intentionally discarded: assertions use trace + filesystem.
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-proceed",
        config=rc_config,
        adapters=adapters,
    )
    # Should be exactly 23 stages, no rollback
    assert len(seen) == 23
    assert not (run_dir / "decision_history.json").exists()
def test_read_pivot_count_returns_zero_for_no_history(run_dir: Path) -> None:
    """With no decision_history.json present the pivot count defaults to 0."""
    count = rc_runner._read_pivot_count(run_dir)
    assert count == 0
def test_record_decision_history_appends(run_dir: Path) -> None:
    """Successive decisions accumulate in decision_history.json in order."""
    rc_runner._record_decision_history(run_dir, "pivot", Stage.HYPOTHESIS_GEN, 1)
    rc_runner._record_decision_history(run_dir, "refine", Stage.ITERATIVE_REFINE, 2)
    history = json.loads((run_dir / "decision_history.json").read_text())
    recorded = [entry["decision"] for entry in history]
    assert recorded == ["pivot", "refine"]
# ── Deliverables packaging tests ──
def _setup_stage_artifacts(run_dir: Path) -> None:
"""Create typical stage-22 and stage-23 output files for testing."""
s22 = run_dir / "stage-22"
s22.mkdir(parents=True, exist_ok=True)
(s22 / "paper_final.md").write_text("# My Paper\nContent here.", encoding="utf-8")
(s22 / "paper.tex").write_text("\\documentclass{article}\n\\begin{document}\nHello\n\\end{document}", encoding="utf-8")
(s22 / "references.bib").write_text("@article{smith2024,\n title={Test}\n}", encoding="utf-8")
code_dir = s22 / "code"
code_dir.mkdir()
(code_dir / "main.py").write_text("print('hello')", encoding="utf-8")
(code_dir / "requirements.txt").write_text("numpy\n", encoding="utf-8")
(code_dir / "README.md").write_text("# Code\n", encoding="utf-8")
s23 = run_dir / "stage-23"
s23.mkdir(parents=True, exist_ok=True)
(s23 / "paper_final_verified.md").write_text("# My Paper (verified)\nContent.", encoding="utf-8")
(s23 / "references_verified.bib").write_text("@article{smith2024,\n title={Test}\n}", encoding="utf-8")
(s23 / "verification_report.json").write_text(
json.dumps({"summary": {"total": 5, "verified": 4}}), encoding="utf-8"
)
def test_package_deliverables_collects_all_artifacts(
    run_dir: Path, rc_config: RCConfig
) -> None:
    """Paper, code, and verification artifacts all land in run_dir/deliverables."""
    _setup_stage_artifacts(run_dir)
    dest = rc_runner._package_deliverables(run_dir, "run-pkg-test", rc_config)
    assert dest is not None
    assert dest == run_dir / "deliverables"
    for name in (
        "paper_final.md",
        "paper.tex",
        "references.bib",
        "verification_report.json",
        "manifest.json",
    ):
        assert (dest / name).exists()
    assert (dest / "code" / "main.py").exists()
    manifest = json.loads((dest / "manifest.json").read_text())
    assert manifest["run_id"] == "run-pkg-test"
    assert "paper_final.md" in manifest["files"]
def test_package_deliverables_prefers_verified_versions(
    run_dir: Path, rc_config: RCConfig
) -> None:
    """Stage-23 verified outputs win over the base stage-22 versions."""
    _setup_stage_artifacts(run_dir)
    rc_runner._package_deliverables(run_dir, "run-verified", rc_config)
    dest = run_dir / "deliverables"
    # Should contain verified content (from stage 23), not base (from stage 22)
    assert "verified" in (dest / "paper_final.md").read_text(encoding="utf-8")
    assert "smith2024" in (dest / "references.bib").read_text(encoding="utf-8")
def test_package_deliverables_falls_back_to_stage22(
    run_dir: Path, rc_config: RCConfig
) -> None:
    """When stage 23 outputs are missing, falls back to stage 22 versions."""
    stage22 = run_dir / "stage-22"
    stage22.mkdir(parents=True, exist_ok=True)
    (stage22 / "paper_final.md").write_text("# Base Paper", encoding="utf-8")
    (stage22 / "references.bib").write_text("@article{a,title={A}}", encoding="utf-8")
    dest = rc_runner._package_deliverables(run_dir, "run-fallback", rc_config)
    assert dest is not None
    assert "Base Paper" in (dest / "paper_final.md").read_text(encoding="utf-8")
def test_package_deliverables_returns_none_when_no_stage_artifacts(
    run_dir: Path, tmp_path: Path,
) -> None:
    """Returns None when no stage artifacts exist and no style files found."""
    # Use a config with an unknown conference so style files aren't bundled
    raw = {
        "project": {"name": "empty-test", "mode": "docs-first"},
        "research": {"topic": "empty"},
        "runtime": {"timezone": "UTC"},
        "notifications": {"channel": "local"},
        "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
        "openclaw_bridge": {},
        "llm": {
            "provider": "openai-compatible",
            "base_url": "http://localhost:1234/v1",
            "api_key_env": "RC_TEST_KEY",
            "api_key": "inline",
        },
        "export": {"target_conference": "unknown_conf_9999"},
    }
    cfg = RCConfig.from_dict(raw, project_root=tmp_path, check_paths=False)
    outcome = rc_runner._package_deliverables(run_dir, "run-empty", cfg)
    assert outcome is None
    assert not (run_dir / "deliverables").exists()
def test_package_deliverables_includes_style_files(
    run_dir: Path, rc_config: RCConfig
) -> None:
    """Style files (.sty, .bst) for the target conference are bundled."""
    _setup_stage_artifacts(run_dir)
    dest = rc_runner._package_deliverables(run_dir, "run-styles", rc_config)
    assert dest is not None
    # Default config uses neurips_2025 → should have neurips_2025.sty
    assert (dest / "neurips_2025.sty").exists()
    manifest = json.loads((dest / "manifest.json").read_text())
    assert "neurips_2025.sty" in manifest["files"]
# ── Atomic checkpoint write tests ──
def test_write_checkpoint_uses_atomic_rename(run_dir: Path) -> None:
    """Checkpoint must be written via temp file + rename, not direct write"""
    rc_runner._write_checkpoint(run_dir, Stage.TOPIC_INIT, "run-atomic")
    checkpoint = run_dir / "checkpoint.json"
    assert checkpoint.exists()
    payload = json.loads(checkpoint.read_text(encoding="utf-8"))
    assert payload["last_completed_stage"] == int(Stage.TOPIC_INIT)
    assert payload["run_id"] == "run-atomic"
def test_write_checkpoint_leaves_no_temp_files(run_dir: Path) -> None:
    """Atomic write must clean up temp files on success"""
    rc_runner._write_checkpoint(run_dir, Stage.TOPIC_INIT, "run-clean")
    leftovers = list(run_dir.glob("*.tmp"))
    assert leftovers == [], f"Leftover temp files: {leftovers}"
def test_write_checkpoint_preserves_old_on_write_failure(
    run_dir: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """If the temp-file write fails, the existing checkpoint must survive"""
    import builtins

    # Seed a valid checkpoint first so there is something to preserve.
    rc_runner._write_checkpoint(run_dir, Stage.TOPIC_INIT, "run-ok")
    original_open = builtins.open

    def _exploding_open(path, *args, **kwargs):
        # After os.close(fd), _write_checkpoint opens via path string —
        # intercept temp-file opens (checkpoint_*.tmp)
        if isinstance(path, (str, Path)) and "checkpoint_" in str(path):
            raise OSError("disk full")
        # Also fail fd-based opens, in case the implementation writes
        # through the raw file descriptor instead of a path.
        if isinstance(path, int):
            raise OSError("disk full")
        return original_open(path, *args, **kwargs)

    # Patch via monkeypatch so builtins.open is restored even on failure.
    monkeypatch.setattr(builtins, "open", _exploding_open)
    with pytest.raises(OSError):
        rc_runner._write_checkpoint(run_dir, Stage.PROBLEM_DECOMPOSE, "run-ok")
    # Original checkpoint must be intact
    data = json.loads((run_dir / "checkpoint.json").read_text(encoding="utf-8"))
    assert data["last_completed_stage"] == int(Stage.TOPIC_INIT)
    # Temp file must be cleaned up
    assert list(run_dir.glob("checkpoint_*.tmp")) == []
def test_write_checkpoint_overwrites_previous(run_dir: Path) -> None:
    """A second checkpoint call must fully replace the first"""
    for stage in (Stage.TOPIC_INIT, Stage.PROBLEM_DECOMPOSE):
        rc_runner._write_checkpoint(run_dir, stage, "run-1")
    payload = json.loads((run_dir / "checkpoint.json").read_text(encoding="utf-8"))
    assert payload["last_completed_stage"] == int(Stage.PROBLEM_DECOMPOSE)
    assert payload["last_completed_name"] == Stage.PROBLEM_DECOMPOSE.name
def _degraded(stage: Stage) -> StageResult:
    """DONE result whose quality-gate decision is 'degraded'."""
    return StageResult(
        stage=stage,
        artifacts=("quality_report.json",),
        status=StageStatus.DONE,
        decision="degraded",
    )
def test_degraded_quality_gate_continues_pipeline(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """When quality gate returns decision='degraded', pipeline continues to completion."""
    executed: list[Stage] = []

    def fake_stage(stage: Stage, **kwargs) -> StageResult:
        _ = kwargs
        executed.append(stage)
        return _degraded(stage) if stage == Stage.QUALITY_GATE else _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", fake_stage)
    outcome = rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-degraded",
        config=rc_config,
        adapters=adapters,
    )
    # All 23 stages should execute (not stopped at quality gate)
    assert len(outcome) == 23
    assert executed == list(STAGE_SEQUENCE)
    # Quality gate result should have decision="degraded"
    gate_results = [r for r in outcome if r.stage == Stage.QUALITY_GATE]
    assert gate_results[0].decision == "degraded"
    assert gate_results[0].status == StageStatus.DONE
    # Pipeline summary should have degraded=True
    summary = json.loads((run_dir / "pipeline_summary.json").read_text())
    assert summary["degraded"] is True
    # Output should show DEGRADED message
    assert "DEGRADED" in capsys.readouterr().out
def test_package_deliverables_called_after_pipeline(
    monkeypatch: pytest.MonkeyPatch,
    run_dir: Path,
    rc_config: RCConfig,
    adapters: AdapterBundle,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """Deliverables packaging is called at end of execute_pipeline."""
    _setup_stage_artifacts(run_dir)

    def mock_execute_stage(stage: Stage, **kwargs) -> StageResult:
        # Explicitly discard kwargs, matching every other mock in this file.
        _ = kwargs
        return _done(stage)

    monkeypatch.setattr(rc_runner, "execute_stage", mock_execute_stage)
    rc_runner.execute_pipeline(
        run_dir=run_dir,
        run_id="run-with-deliverables",
        config=rc_config,
        adapters=adapters,
    )
    captured = capsys.readouterr()
    assert "Deliverables packaged" in captured.out
    assert (run_dir / "deliverables" / "manifest.json").exists()
# ---------------------------------------------------------------------------
# BUG-223: _promote_best_stage14 must always write experiment_summary_best.json
# ---------------------------------------------------------------------------
def _make_stage14_summary(run_dir: Path, suffix: str, pm_value: float) -> None:
"""Helper: create a stage-14{suffix}/experiment_summary.json."""
d = run_dir / f"stage-14{suffix}"
d.mkdir(parents=True, exist_ok=True)
data = {
"metrics_summary": {
"primary_metric": {"min": pm_value, "max": pm_value, "mean": pm_value, "count": 1}
},
"condition_summaries": {"cond_a": {"metrics": {"primary_metric": pm_value}}},
}
(d / "experiment_summary.json").write_text(json.dumps(data), encoding="utf-8")
class TestPromoteBestStage14BestJson:
    """BUG-223: experiment_summary_best.json must be written even when
    stage-14/ already has the best result (early-return path)."""

    @pytest.fixture()
    def max_config(self, rc_config: RCConfig) -> RCConfig:
        """Config with metric_direction=maximize (accuracy-like metrics)."""
        object.__setattr__(rc_config.experiment, "metric_direction", "maximize")
        return rc_config

    @staticmethod
    def _promoted_mean(run_dir: Path) -> float:
        """Parse experiment_summary_best.json and return the primary-metric mean."""
        payload = json.loads(
            (run_dir / "experiment_summary_best.json").read_text(encoding="utf-8")
        )
        return payload["metrics_summary"]["primary_metric"]["mean"]

    def test_best_json_written_when_current_is_best(
        self, run_dir: Path, max_config: RCConfig
    ) -> None:
        """stage-14/ already best → should still write best.json."""
        for suffix, value in (("", 90.0), ("_v1", 80.0), ("_v2", 70.0)):
            _make_stage14_summary(run_dir, suffix, value)
        rc_runner._promote_best_stage14(run_dir, max_config)  # type: ignore[attr-defined]
        best_path = run_dir / "experiment_summary_best.json"
        assert best_path.exists(), "experiment_summary_best.json must always be written"
        assert self._promoted_mean(run_dir) == 90.0

    def test_best_json_written_when_promotion_needed(
        self, run_dir: Path, max_config: RCConfig
    ) -> None:
        """stage-14/ is NOT best → promote + write best.json."""
        for suffix, value in (("", 70.0), ("_v1", 95.0)):
            _make_stage14_summary(run_dir, suffix, value)
        rc_runner._promote_best_stage14(run_dir, max_config)  # type: ignore[attr-defined]
        assert (run_dir / "experiment_summary_best.json").exists()
        assert self._promoted_mean(run_dir) == 95.0

    def test_best_json_written_with_equal_values(
        self, run_dir: Path, max_config: RCConfig
    ) -> None:
        """BUG-223 exact scenario: stage-14 and stage-14_v1 have equal
        metrics, stage-14_v2 is regressed."""
        for suffix, value in (("", 64.46), ("_v1", 64.46), ("_v2", 26.80)):
            _make_stage14_summary(run_dir, suffix, value)
        rc_runner._promote_best_stage14(run_dir, max_config)  # type: ignore[attr-defined]
        best_path = run_dir / "experiment_summary_best.json"
        assert best_path.exists(), "BUG-223: best.json missing when current is tied-best"
        assert self._promoted_mean(run_dir) == 64.46
class TestPromoteBestStage14AnalysisBest:
    """BUG-225: analysis_best.md must be written from best stage-14 iteration."""

    @pytest.fixture()
    def max_config(self, rc_config: RCConfig) -> RCConfig:
        # Force accuracy-style comparison so the highest mean wins.
        object.__setattr__(rc_config.experiment, "metric_direction", "maximize")
        return rc_config

    def test_analysis_best_written_from_best_iteration(
        self, run_dir: Path, max_config: RCConfig
    ) -> None:
        """analysis_best.md should be copied from the winning iteration (_v1)."""
        _make_stage14_summary(run_dir, "", 70.0)
        _make_stage14_summary(run_dir, "_v1", 95.0)
        # Give each iteration a distinct analysis.md so the source is provable.
        (run_dir / "stage-14" / "analysis.md").write_text("Degenerate analysis", encoding="utf-8")
        (run_dir / "stage-14_v1" / "analysis.md").write_text("Best analysis v1", encoding="utf-8")
        rc_runner._promote_best_stage14(run_dir, max_config)  # type: ignore[attr-defined]
        promoted = run_dir / "analysis_best.md"
        assert promoted.exists(), "BUG-225: analysis_best.md must be written"
        assert promoted.read_text(encoding="utf-8") == "Best analysis v1"

    def test_analysis_best_written_when_current_is_best(
        self, run_dir: Path, max_config: RCConfig
    ) -> None:
        """Even when stage-14 itself is best, analysis_best.md should appear."""
        _make_stage14_summary(run_dir, "", 90.0)
        _make_stage14_summary(run_dir, "_v1", 80.0)
        (run_dir / "stage-14" / "analysis.md").write_text("Best analysis current", encoding="utf-8")
        (run_dir / "stage-14_v1" / "analysis.md").write_text("Worse analysis", encoding="utf-8")
        rc_runner._promote_best_stage14(run_dir, max_config)  # type: ignore[attr-defined]
        promoted = run_dir / "analysis_best.md"
        assert promoted.exists()
        assert promoted.read_text(encoding="utf-8") == "Best analysis current"

    def test_no_analysis_best_when_no_analysis_md(
        self, run_dir: Path, max_config: RCConfig
    ) -> None:
        """Without an analysis.md in the best iteration, nothing is promoted."""
        _make_stage14_summary(run_dir, "", 90.0)
        rc_runner._promote_best_stage14(run_dir, max_config)  # type: ignore[attr-defined]
        assert not (run_dir / "analysis_best.md").exists()
class TestPromoteBestStage14DegenerateDetection:
    """BUG-226: Degenerate near-zero metrics must not be promoted as best."""

    def test_degenerate_minimize_skipped(self, run_dir: Path, rc_config: RCConfig) -> None:
        """Under minimize, a value ~1000x below the runner-up is degenerate."""
        # metric_direction defaults to "minimize"
        _make_stage14_summary(run_dir, "", 7.26e-8)  # degenerate (broken normalization)
        _make_stage14_summary(run_dir, "_v2", 0.37)  # valid
        rc_runner._promote_best_stage14(run_dir, rc_config)  # type: ignore[attr-defined]
        best_file = run_dir / "experiment_summary_best.json"
        assert best_file.exists()
        summary = json.loads(best_file.read_text(encoding="utf-8"))
        metric = summary["metrics_summary"]["primary_metric"]
        assert metric["mean"] == 0.37, "Degenerate value should be skipped, valid v2 promoted"

    def test_legitimate_minimize_not_skipped(self, run_dir: Path, rc_config: RCConfig) -> None:
        """Values in a normal range: the smaller one legitimately wins."""
        _make_stage14_summary(run_dir, "", 0.15)
        _make_stage14_summary(run_dir, "_v1", 0.37)
        rc_runner._promote_best_stage14(run_dir, rc_config)  # type: ignore[attr-defined]
        best_file = run_dir / "experiment_summary_best.json"
        summary = json.loads(best_file.read_text(encoding="utf-8"))
        metric = summary["metrics_summary"]["primary_metric"]
        assert metric["mean"] == 0.15, "Legitimate lower value should be promoted"

    def test_single_candidate_not_affected(self, run_dir: Path, rc_config: RCConfig) -> None:
        """With only one candidate, degenerate detection never skips it."""
        _make_stage14_summary(run_dir, "", 1e-10)
        rc_runner._promote_best_stage14(run_dir, rc_config)  # type: ignore[attr-defined]
        best_file = run_dir / "experiment_summary_best.json"
        summary = json.loads(best_file.read_text(encoding="utf-8"))
        metric = summary["metrics_summary"]["primary_metric"]
        assert metric["mean"] == 1e-10
================================================
FILE: tests/test_rc_sanitization.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.pipeline.executor import _sanitize_fabricated_data
from researchclaw.pipeline.stage_impls._code_generation import _check_rl_compatibility
@pytest.fixture()
def run_dir(tmp_path: Path) -> Path:
    """Provide an empty run directory under pytest's tmp_path."""
    run_path = tmp_path / "run"
    run_path.mkdir()
    return run_path


def _write_experiment_summary(run_dir: Path, data: dict) -> None:
    """Serialize *data* as stage-14/experiment_summary.json inside run_dir."""
    stage_dir = run_dir / "stage-14"
    stage_dir.mkdir(parents=True, exist_ok=True)
    summary_path = stage_dir / "experiment_summary.json"
    summary_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
def test_sanitize_replaces_unverified_numbers(run_dir: Path) -> None:
    """Verified metric values survive; fabricated table cells are scrubbed."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85, "f1": 0.82},
        "best_run": {"metrics": {"accuracy": 0.85}},
    })
    paper = (
        "## Results\n\n"
        "| Method | Accuracy | F1 | Precision |\n"
        "| --- | --- | --- | --- |\n"
        "| Ours | 0.85 | 0.82 | 0.91 |\n"
        "| Baseline | 0.73 | 0.65 | 0.78 |\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # 0.85 / 0.82 match the summary and are kept; the rest are replaced.
    for verified in ("0.85", "0.82"):
        assert verified in sanitized
    for fabricated in ("0.91", "0.73"):
        assert fabricated not in sanitized
    assert "---" in sanitized
    assert report["sanitized"] is True
    assert report["numbers_replaced"] == 4
    assert report["numbers_kept"] == 2


def test_sanitize_preserves_table_structure(run_dir: Path) -> None:
    """Sanitization must not add or drop any table pipes."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"loss": 0.12},
    })
    paper = (
        "| Model | Loss |\n"
        "| --- | --- |\n"
        "| A | 0.12 |\n"
        "| B | 0.8765 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # Same number of pipes before and after → structure intact.
    assert sanitized.count("|") == paper.count("|")
    assert "0.12" in sanitized
    assert "0.8765" not in sanitized


def test_sanitize_no_experiment_summary(run_dir: Path) -> None:
    """With no experiment summary on disk, the paper passes through untouched."""
    paper = "| A | 0.5 |\n| --- | --- |\n| B | 0.6 |\n"
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    assert report["sanitized"] is False
    assert sanitized == paper  # unchanged


def test_sanitize_tolerance_within_1_percent(run_dir: Path) -> None:
    """Values within 1% of a verified metric are treated as matching."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 100.0},
    })
    paper = (
        "| Method | Acc |\n"
        "| --- | --- |\n"
        "| Ours | 100.5 |\n"  # within 1% of 100.0
        "| Other | 110.0 |\n"  # outside 1%
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    assert "100.5" in sanitized  # kept (within tolerance)
    assert "110.0" not in sanitized  # replaced


def test_sanitize_header_row_preserved(run_dir: Path) -> None:
    """The header row of a table is never rewritten."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"val": 5.0},
    })
    paper = (
        "| Col1 | Col2 |\n"
        "| --- | --- |\n"
        "| data | 99.9 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # Header row should be untouched
    assert "| Col1 | Col2 |" in sanitized
def test_sanitize_hp_columns_preserved_in_mixed_table(run_dir: Path) -> None:
    """BUG-184: HP columns in mixed tables should not be sanitized."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
        "best_run": {"metrics": {"accuracy": 0.85}},
    })
    paper = (
        "## Results\n\n"
        "| Method | LR | Batch Size | Accuracy | F1 |\n"
        "| --- | --- | --- | --- | --- |\n"
        "| Ours | 0.0007 | 48 | 0.85 | 0.91 |\n"
        "| Baseline | 0.0001 | 24 | 0.73 | 0.78 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # Hyperparameter columns (LR, Batch Size) are exempt from verification.
    assert "0.0007" in sanitized, "HP column 'LR' value should not be sanitized"
    assert "0.0001" in sanitized, "HP column 'LR' value should not be sanitized"
    # Result columns: verified 0.85 stays, unverified values are scrubbed.
    assert "0.85" in sanitized
    assert "0.91" not in sanitized
    assert "0.73" not in sanitized


def test_sanitize_pure_hp_table_skipped(run_dir: Path) -> None:
    """BUG-192: Pure HP tables (header keywords) should be fully skipped."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
    })
    paper = (
        "| Hyperparameter | Value |\n"
        "| --- | --- |\n"
        "| Learning Rate | 0.0007 |\n"
        "| Batch Size | 48 |\n"
        "| Weight Decay | 0.0005 |\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # A table whose header marks it as hyperparameters is skipped entirely.
    assert "0.0007" in sanitized
    assert "0.0005" in sanitized
    assert report["tables_processed"] == 0


def test_prose_sanitization_replaces_unverified(run_dir: Path) -> None:
    """Unverified numbers in the Results section prose are sanitized."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
        "best_run": {"metrics": {"accuracy": 0.85}},
    })
    paper = (
        "# Introduction\n"
        "Prior work achieved 92.3% accuracy on this task.\n\n"
        "# Results\n"
        "Our method achieved 85.0% accuracy, which is significantly better.\n"
        "The baseline obtained 72.4% accuracy on the same benchmark.\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # 85.0 matches 0.85 × 100 → verified and kept.
    assert "85.0" in sanitized
    # 72.4 is unverified inside Results → replaced with a placeholder.
    assert "72.4" not in sanitized
    assert "[value removed]" in sanitized
    # 92.3 lives in the Introduction, which is never sanitized.
    assert "92.3" in sanitized
    assert report["prose_numbers_replaced"] >= 1
def test_sanitize_model_name_numbers_preserved(run_dir: Path) -> None:
    """BUG-206: Numbers in model names (ResNet-34) must not be replaced."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
        "best_run": {"metrics": {"accuracy": 0.85}},
    })
    # Table with model variant numbers in the first column (ci=1, skipped)
    paper = (
        "## Results\n\n"
        "| Method | Accuracy |\n"
        "| --- | --- |\n"
        "| ResNet-34 (baseline) | 0.85 |\n"
        "| ResNet-50 (teacher) | 0.91 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # The method-name column must survive intact, "34" and "50" included.
    assert "ResNet-34" in sanitized, "Model name 'ResNet-34' should not be sanitized"
    assert "ResNet-50" in sanitized, "Model name 'ResNet-50' should not be sanitized"


def test_sanitize_unicode_hyphen_model_names_preserved(run_dir: Path) -> None:
    """BUG-206: Unicode non-breaking hyphen in model names must not be replaced."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
        "best_run": {"metrics": {"accuracy": 0.85}},
    })
    # U+2011 non-breaking hyphen (common LLM output)
    paper = (
        "## Results\n\n"
        "| Method | Accuracy |\n"
        "| --- | --- |\n"
        "| ResNet\u201134 (baseline) | 0.85 |\n"
        "| ResNet\u201150 (teacher) | 0.91 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    assert "ResNet\u201134" in sanitized, "Model name with U+2011 hyphen should not be sanitized"
    assert "ResNet\u201150" in sanitized, "Model name with U+2011 hyphen should not be sanitized"


def test_prose_sanitization_preserves_introduction(run_dir: Path) -> None:
    """Numbers outside Results/Experiments sections must never be touched."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"val": 0.50},
    })
    paper = (
        "# Introduction\n"
        "Previous methods achieved 94.2% accuracy.\n\n"
        "# Related Work\n"
        "Smith et al. reported 88.7% on the benchmark.\n\n"
        "# Conclusion\n"
        "We demonstrated 50.0% accuracy.\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # None of these sections are Results/Experiments → everything preserved.
    assert "94.2" in sanitized
    assert "88.7" in sanitized
    assert report["prose_numbers_replaced"] == 0
# ---------------------------------------------------------------------------
# RL compatibility check (Improvement G)
# ---------------------------------------------------------------------------
def test_rl_compatibility_dqn_continuous_detected() -> None:
    """DQN paired with a continuous-action env must be reported as an error."""
    code = """
import gymnasium as gym
from stable_baselines3 import DQN
env = gym.make('Pendulum-v1')
model = DQN('MlpPolicy', env)
model.learn(total_timesteps=10000)
"""
    errors = _check_rl_compatibility(code)
    assert len(errors) >= 1
    first = errors[0]
    # The message should name both the algorithm and the offending env.
    assert "DQN" in first
    assert "pendulum" in first.lower()


def test_rl_compatibility_ppo_continuous_ok() -> None:
    """PPO supports continuous actions, so no errors are expected."""
    code = """
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make('HalfCheetah-v5')
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000)
"""
    errors = _check_rl_compatibility(code)
    assert len(errors) == 0
def test_sanitize_reads_promoted_best_data(run_dir: Path) -> None:
    """BUG-222: Sanitizer uses experiment_summary_best.json (promoted best).
    After REFINE, the pipeline promotes the best iteration's data to
    experiment_summary_best.json. The sanitizer should validate against
    that file, not scan all refinement logs.
    """
    # Stale stage-14 data (from a regressed iteration)
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"primary_metric": {"min": 8.42, "max": 8.91, "mean": 8.6467, "count": 3}},
        "best_run": {"metrics": {"primary_metric": 8.65}},
    })
    # Promoted best data (from the winning iteration)
    promoted_payload = {
        "metrics_summary": {"primary_metric": {"min": 73.07, "max": 78.93, "mean": 75.56, "count": 3}},
        "best_run": {"metrics": {"primary_metric": 78.93}},
        "condition_summaries": {
            "Ours": {"metrics": {"primary_metric": 78.93}},
            "SGD": {"metrics": {"primary_metric": 73.07}},
            "AdamW": {"metrics": {"primary_metric": 68.67}},
        },
    }
    (run_dir / "experiment_summary_best.json").write_text(
        json.dumps(promoted_payload, indent=2), encoding="utf-8"
    )
    # Paper uses values from promoted best
    paper = (
        "## Results\n\n"
        "| Method | Accuracy |\n"
        "| --- | --- |\n"
        "| Ours | 78.93 |\n"
        "| SGD | 73.07 |\n"
        "| AdamW | 68.67 |\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # All three values match the promoted best → kept, nothing replaced.
    for value in ("78.93", "73.07", "68.67"):
        assert value in sanitized
    assert report["numbers_kept"] == 3
    assert report["numbers_replaced"] == 0


def test_sanitize_rejects_regressed_refine_data(run_dir: Path) -> None:
    """BUG-222: Regressed REFINE iteration data must NOT pass sanitizer.
    Reproduces the Run 75 fabrication bypass: v1 had 74.52%, v3 regressed
    to 69.30%. Paper cited v3 numbers. The sanitizer should reject them.
    """
    # v1 (best) promoted to experiment_summary_best.json
    v1_payload = {
        "best_run": {"metrics": {"FeatureKD/0/metric": 0.7452}},
        "condition_summaries": {
            "FeatureKD": {"metrics": {"metric": 0.7452}},
            "Teacher": {"metrics": {"metric": 0.7431}},
        },
        "metrics_summary": {"metric": {"mean": 0.7442, "min": 0.7431, "max": 0.7452}},
    }
    (run_dir / "experiment_summary_best.json").write_text(
        json.dumps(v1_payload, indent=2), encoding="utf-8"
    )
    # v3 (regressed) in stage-14 (stale)
    _write_experiment_summary(run_dir, {
        "best_run": {"metrics": {"FeatureKD/0/metric": 0.6930}},
        "condition_summaries": {
            "FeatureKD": {"metrics": {"metric": 0.6930}},
            "Teacher": {"metrics": {"metric": 0.7292}},
        },
        "metrics_summary": {"metric": {"mean": 0.7111, "min": 0.6930, "max": 0.7292}},
    })
    # v3 sandbox data in refinement_log
    refine_dir = run_dir / "stage-13_v2"
    refine_dir.mkdir(parents=True, exist_ok=True)
    (refine_dir / "refinement_log.json").write_text(json.dumps({
        "iterations": [{"sandbox": {"metrics": {"primary_metric": 0.6930}}}]
    }), encoding="utf-8")
    # Paper fabricates v3 numbers
    paper = (
        "## Results\n\n"
        "| Method | Accuracy |\n"
        "| --- | --- |\n"
        "| FeatureKD | 69.30 |\n"
        "| Teacher | 72.92 |\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # 69.30 comes from the regressed v3, not the promoted v1 → must go.
    assert "69.30" not in sanitized
    assert report["numbers_replaced"] >= 1
    # But 74.52 or 74.31 (v1 best) would pass if cited
    paper_v1 = (
        "## Results\n\n"
        "| Method | Accuracy |\n"
        "| --- | --- |\n"
        "| FeatureKD | 74.52 |\n"
        "| Teacher | 74.31 |\n"
    )
    sanitized_v1, report_v1 = _sanitize_fabricated_data(paper_v1, run_dir)
    assert "74.52" in sanitized_v1
    assert "74.31" in sanitized_v1
    assert report_v1["numbers_replaced"] == 0
def test_sanitize_condition_names_with_decimals_preserved(run_dir: Path) -> None:
    """BUG-210: Condition names with decimal params (ema_decay_0.9) must not be damaged."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 73.07},
        "best_run": {"metrics": {"accuracy": 73.07}},
    })
    paper = (
        "## Results\n\n"
        "| Condition | Accuracy |\n"
        "| --- | --- |\n"
        "| ema_decay_0.9 | 73.07 |\n"
        "| ema_decay_0.99 | 69.33 |\n"
        "| swa_start_0.75 | 68.67 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # First column (condition names) must come through completely untouched.
    # Trailing space distinguishes "ema_decay_0.9" from "ema_decay_0.99".
    assert "ema_decay_0.9 " in sanitized, "Condition name 'ema_decay_0.9' damaged"
    assert "ema_decay_0.99" in sanitized, "Condition name 'ema_decay_0.99' damaged"
    assert "swa_start_0.75" in sanitized, "Condition name 'swa_start_0.75' damaged"
    # 73.07 matches the summary → kept.
    assert "73.07" in sanitized


def test_rl_compatibility_dqn_discrete_ok() -> None:
    """DQN with a discrete-action env (CartPole) raises no errors."""
    code = """
import gymnasium as gym
from stable_baselines3 import DQN
env = gym.make('CartPole-v1')
model = DQN('MlpPolicy', env)
"""
    errors = _check_rl_compatibility(code)
    assert len(errors) == 0
# ---------------------------------------------------------------------------
# BUG-211: LaTeX tabular sanitization
# ---------------------------------------------------------------------------
def test_sanitize_latex_tabular_replaces_unverified(run_dir: Path) -> None:
    """BUG-211: Numbers inside \\begin{tabular} must be sanitized."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.4816},
        "best_run": {"metrics": {"accuracy": 0.4816}},
    })
    paper = (
        "## Results\n\n"
        "```latex\n"
        "\\begin{table}[htbp]\n"
        "\\centering\n"
        "\\caption{Test accuracy for all configurations.}\n"
        "\\begin{tabular}{l c}\n"
        "\\toprule\n"
        "Method & Accuracy \\\\\n"
        "\\midrule\n"
        "baseline\\_resnet18 & \\textbf{0.4816} \\\\\n"
        "baseline\\_resnet50 & 0.4451 \\\\\n"
        "dropout\\_standard & 0.3243 \\\\\n"
        "\\bottomrule\n"
        "\\end{tabular}\n"
        "\\end{table}\n"
        "```\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # Verified 0.4816 survives; the two fabricated cells become ---.
    assert "0.4816" in sanitized
    for fabricated in ("0.4451", "0.3243"):
        assert fabricated not in sanitized
    assert "---" in sanitized
    assert report["tables_processed"] >= 1
    assert report["numbers_replaced"] >= 2


def test_sanitize_latex_tabular_hp_table_skipped(run_dir: Path) -> None:
    """BUG-211: LaTeX HP tables should be skipped just like markdown ones."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
    })
    paper = (
        "\\begin{table}[htbp]\n"
        "\\centering\n"
        "\\caption{Training hyperparameters.}\n"
        "\\begin{tabular}{l c}\n"
        "\\toprule\n"
        "Hyperparameter & Value \\\\\n"
        "\\midrule\n"
        "Learning Rate & 0.001 \\\\\n"
        "Batch Size & 128 \\\\\n"
        "Weight Decay & 0.0005 \\\\\n"
        "\\bottomrule\n"
        "\\end{tabular}\n"
        "\\end{table}\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # HP table — every value preserved, table never processed.
    assert "0.001" in sanitized
    assert "0.0005" in sanitized


def test_sanitize_latex_tabular_with_pm(run_dir: Path) -> None:
    """BUG-211: Numbers with ± in LaTeX cells must be individually checked."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 48.16, "accuracy_std": 0.35},
        "best_run": {"metrics": {"accuracy": 48.16}},
        "condition_summaries": {
            "method_a": {"primary_metric_mean": 48.16, "primary_metric_std": 0.35},
        },
    })
    paper = (
        "\\begin{tabular}{l c}\n"
        "\\toprule\n"
        "Method & Accuracy (mean $\\pm$ std) \\\\\n"
        "\\midrule\n"
        "method\\_a & 48.16 $\\pm$ 0.35 \\\\\n"
        "method\\_b & 32.43 $\\pm$ 0.45 \\\\\n"
        "\\bottomrule\n"
        "\\end{tabular}\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # Verified mean and std are kept.
    assert "48.16" in sanitized
    assert "0.35" in sanitized
    # Unverified mean and std are scrubbed.
    assert "32.43" not in sanitized
    assert "0.45" not in sanitized


def test_sanitize_latex_tabular_preserves_first_column(run_dir: Path) -> None:
    """BUG-211: First column (method names) must be preserved."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": 0.85},
        "best_run": {"metrics": {"accuracy": 0.85}},
    })
    paper = (
        "\\begin{tabular}{l r r r r}\n"
        "\\toprule\n"
        "Method & Seed 0 & Seed 1 & Seed 2 & Mean \\\\\n"
        "\\midrule\n"
        "resnet\\_18 & 0.4861 & 0.4809 & 0.4777 & 0.4816 \\\\\n"
        "resnet\\_50 & 0.4455 & 0.4459 & 0.4438 & 0.4451 \\\\\n"
        "\\bottomrule\n"
        "\\end{tabular}\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # Method names in the leading column must survive.
    assert "resnet\\_18" in sanitized
    assert "resnet\\_50" in sanitized
# ---------------------------------------------------------------------------
# BUG-224: Statistical analysis tables should NOT be sanitized
# ---------------------------------------------------------------------------
def test_sanitize_skips_statistical_analysis_table(run_dir: Path) -> None:
    """BUG-224: Tables with t-statistics, p-values, and effect sizes are
    derived from experiment data and should not be sanitized.

    Two tables are present: a Results table (subject to sanitization) and a
    Statistical Analysis table (exempt).
    """
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": {"mean": 64.26}},
        "condition_summaries": {"ce": {"metrics": {"accuracy": 64.26}}},
    })
    paper = (
        "## Results\n\n"
        "| Method | Accuracy |\n"
        "|--------|----------|\n"
        "| CE | 64.26 |\n"
        "| SCE | 56.93 |\n\n"
        "## Statistical Analysis\n\n"
        "| Comparison | t-statistic | p-value |\n"
        "|-----------|------------|--------|\n"
        "| CE vs SCE | 7.3267 | 0.0123 |\n"
        "| CE vs GCE | 1.7100 | 0.0569 |\n"
    )
    sanitized, report = _sanitize_fabricated_data(paper, run_dir)
    # Results table: 64.26 is verified and kept; 56.93 is NOT verified and
    # must be replaced.  (The previous assertion
    # `... or "---" in sanitized` was vacuous: the markdown separator row
    # "|--------|" always contains "---", so the test could never fail.)
    assert "64.26" in sanitized
    assert "56.93" not in sanitized, "Unverified 56.93 must be sanitized"
    # Statistical table: t-statistics and p-values are derived → preserved.
    assert "7.3267" in sanitized, "BUG-224: t-statistic was sanitized"
    assert "0.0123" in sanitized, "BUG-224: p-value was sanitized"
    assert "1.7100" in sanitized, "BUG-224: t-statistic was sanitized"
    assert "0.0569" in sanitized, "BUG-224: p-value was sanitized"
def test_sanitize_preserves_common_hp_values(run_dir: Path) -> None:
    """BUG-224: Common HP values like 0.7 should be in the always-allowed set."""
    _write_experiment_summary(run_dir, {
        "metrics_summary": {"accuracy": {"mean": 64.26}},
        "condition_summaries": {"ce": {"metrics": {"accuracy": 64.26}}},
    })
    paper = (
        "| Method | q | Accuracy |\n"
        "|--------|---|----------|\n"
        "| GCE | 0.7 | 64.26 |\n"
        "| GCE-05 | 0.5 | 66.77 |\n"
    )
    sanitized, _ = _sanitize_fabricated_data(paper, run_dir)
    # Both q values belong to the always-allowed hyperparameter set.
    assert "0.7" in sanitized, "BUG-224: q=0.7 was incorrectly sanitized"
    assert "0.5" in sanitized
================================================
FILE: tests/test_rc_sentinel.py
================================================
# pyright: reportPrivateUsage=false
"""Tests for the sentinel watchdog and heartbeat system."""
from __future__ import annotations
import json
import os
import subprocess
from pathlib import Path
import pytest
from researchclaw.pipeline import runner as rc_runner
from researchclaw.pipeline.stages import Stage
# ── Heartbeat writing tests ──
class TestHeartbeatWriting:
    """Unit tests for rc_runner._write_heartbeat."""

    def test_write_heartbeat_creates_file(self, tmp_path: Path) -> None:
        rc_runner._write_heartbeat(tmp_path, Stage.TOPIC_INIT, "run-hb-1")
        assert (tmp_path / "heartbeat.json").exists()

    def test_heartbeat_contains_required_fields(self, tmp_path: Path) -> None:
        rc_runner._write_heartbeat(tmp_path, Stage.HYPOTHESIS_GEN, "run-hb-2")
        payload = json.loads((tmp_path / "heartbeat.json").read_text())
        assert payload["pid"] == os.getpid()
        assert payload["last_stage"] == 8
        assert payload["last_stage_name"] == "HYPOTHESIS_GEN"
        assert payload["run_id"] == "run-hb-2"
        assert "timestamp" in payload

    def test_heartbeat_updates_on_each_stage(self, tmp_path: Path) -> None:
        # Two consecutive writes: the file must reflect the latest stage.
        rc_runner._write_heartbeat(tmp_path, Stage.TOPIC_INIT, "run-1")
        first = json.loads((tmp_path / "heartbeat.json").read_text())
        rc_runner._write_heartbeat(tmp_path, Stage.PAPER_DRAFT, "run-1")
        second = json.loads((tmp_path / "heartbeat.json").read_text())
        assert second["last_stage"] == 17
        assert first["last_stage"] == 1
class TestHeartbeatInPipeline:
    """End-to-end check: execute_pipeline writes heartbeat.json per stage."""

    def test_pipeline_writes_heartbeat_after_each_stage(
        self,
        monkeypatch: pytest.MonkeyPatch,
        tmp_path: Path,
    ) -> None:
        from researchclaw.adapters import AdapterBundle
        from researchclaw.config import RCConfig
        from researchclaw.pipeline.executor import StageResult
        from researchclaw.pipeline.stages import StageStatus

        # Minimal valid config: local KB, dummy LLM endpoint.
        raw_config = {
            "project": {"name": "hb-test", "mode": "docs-first"},
            "research": {"topic": "heartbeat testing"},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "local"},
            "knowledge_base": {"backend": "markdown", "root": str(tmp_path / "kb")},
            "openclaw_bridge": {},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost/v1",
                "api_key_env": "K",
                "api_key": "k",
            },
        }
        config = RCConfig.from_dict(raw_config, project_root=tmp_path, check_paths=False)
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        calls = 0

        def mock_execute_stage(stage: Stage, **kwargs) -> StageResult:
            # Succeed twice, then fail to stop the pipeline early.
            nonlocal calls
            calls += 1
            if calls >= 3:
                return StageResult(
                    stage=stage, status=StageStatus.FAILED, artifacts=(), error="stop"
                )
            return StageResult(stage=stage, status=StageStatus.DONE, artifacts=("x.md",))

        monkeypatch.setattr(rc_runner, "execute_stage", mock_execute_stage)
        rc_runner.execute_pipeline(
            run_dir=run_dir,
            run_id="hb-test",
            config=config,
            adapters=AdapterBundle(),
        )
        heartbeat = run_dir / "heartbeat.json"
        assert heartbeat.exists()
        payload = json.loads(heartbeat.read_text())
        assert payload["run_id"] == "hb-test"
# ── Sentinel script syntax check ──
class TestSentinelScript:
    """Sanity checks on the repository-root sentinel.sh watchdog script.

    The script path was previously rebuilt inline in every test; it is now
    computed once by `_script()` (DRY — behavior unchanged).
    """

    @staticmethod
    def _script() -> Path:
        """Return the path to sentinel.sh two levels above this test file."""
        return Path(__file__).parent.parent / "sentinel.sh"

    def test_sentinel_script_exists(self) -> None:
        assert self._script().exists()

    def test_sentinel_script_is_valid_bash(self) -> None:
        # `bash -n` parses the script without executing it.
        result = subprocess.run(
            ["bash", "-n", str(self._script())],
            capture_output=True,
            text=True,
        )
        assert result.returncode == 0, f"Bash syntax error: {result.stderr}"

    def test_sentinel_script_is_executable(self) -> None:
        assert os.access(self._script(), os.X_OK)

    def test_sentinel_script_has_shebang(self) -> None:
        first_line = self._script().read_text().splitlines()[0]
        assert first_line.startswith("#!/")

    def test_sentinel_prints_usage_on_no_args(self) -> None:
        result = subprocess.run(
            ["bash", str(self._script())],
            capture_output=True,
            text=True,
        )
        # Should fail because no run_dir argument provided
        assert result.returncode != 0
================================================
FILE: tests/test_rc_stages.py
================================================
import pytest
from researchclaw.pipeline.stages import (
DECISION_ROLLBACK,
GATE_ROLLBACK,
GATE_STAGES,
MAX_DECISION_PIVOTS,
NEXT_STAGE,
PHASE_MAP,
PREVIOUS_STAGE,
STAGE_SEQUENCE,
TRANSITION_MAP,
Stage,
StageStatus,
TransitionEvent,
TransitionOutcome,
advance,
default_rollback_stage,
gate_required,
)
def test_stage_enum_has_exactly_23_members():
    """The Stage enum defines all 23 pipeline stages."""
    assert len(Stage) == 23


@pytest.mark.parametrize(
    "index,stage", list(enumerate(STAGE_SEQUENCE, start=1))
)
def test_stage_values_follow_sequence_order(index: int, stage: Stage):
    """Each stage's integer value equals its 1-based sequence position."""
    assert int(stage) == index


def test_stage_sequence_contains_all_23_stages_in_order():
    assert len(STAGE_SEQUENCE) == 23
    assert STAGE_SEQUENCE[0] is Stage.TOPIC_INIT
    assert STAGE_SEQUENCE[-1] is Stage.CITATION_VERIFY
    # Declaration order of the enum matches the canonical sequence.
    assert tuple(Stage) == STAGE_SEQUENCE


def test_next_stage_boundary_values():
    # First and last forward transitions.
    assert NEXT_STAGE[Stage.TOPIC_INIT] is Stage.PROBLEM_DECOMPOSE
    assert NEXT_STAGE[Stage.EXPORT_PUBLISH] is Stage.CITATION_VERIFY


def test_previous_stage_boundary_values():
    # The first stage has no predecessor.
    assert PREVIOUS_STAGE[Stage.TOPIC_INIT] is None
    assert PREVIOUS_STAGE[Stage.PROBLEM_DECOMPOSE] is Stage.TOPIC_INIT


def test_gate_stages_matches_expected_set():
    expected_gates = {Stage.LITERATURE_SCREEN, Stage.EXPERIMENT_DESIGN, Stage.QUALITY_GATE}
    assert GATE_STAGES == frozenset(expected_gates)


def test_gate_rollback_map_matches_expected_targets():
    expected_targets = {
        Stage.LITERATURE_SCREEN: Stage.LITERATURE_COLLECT,
        Stage.EXPERIMENT_DESIGN: Stage.HYPOTHESIS_GEN,
        Stage.QUALITY_GATE: Stage.PAPER_OUTLINE,
    }
    assert GATE_ROLLBACK == expected_targets
def test_phase_map_has_8_phases_with_expected_membership():
    """PHASE_MAP groups the 23 stages into the 8 documented phases."""
    expected = {
        "A: Research Scoping": (
            Stage.TOPIC_INIT,
            Stage.PROBLEM_DECOMPOSE,
        ),
        "B: Literature Discovery": (
            Stage.SEARCH_STRATEGY,
            Stage.LITERATURE_COLLECT,
            Stage.LITERATURE_SCREEN,
            Stage.KNOWLEDGE_EXTRACT,
        ),
        "C: Knowledge Synthesis": (
            Stage.SYNTHESIS,
            Stage.HYPOTHESIS_GEN,
        ),
        "D: Experiment Design": (
            Stage.EXPERIMENT_DESIGN,
            Stage.CODE_GENERATION,
            Stage.RESOURCE_PLANNING,
        ),
        "E: Experiment Execution": (
            Stage.EXPERIMENT_RUN,
            Stage.ITERATIVE_REFINE,
        ),
        "F: Analysis & Decision": (
            Stage.RESULT_ANALYSIS,
            Stage.RESEARCH_DECISION,
        ),
        "G: Paper Writing": (
            Stage.PAPER_OUTLINE,
            Stage.PAPER_DRAFT,
            Stage.PEER_REVIEW,
            Stage.PAPER_REVISION,
        ),
        "H: Finalization": (
            Stage.QUALITY_GATE,
            Stage.KNOWLEDGE_ARCHIVE,
            Stage.EXPORT_PUBLISH,
            Stage.CITATION_VERIFY,
        ),
    }
    assert len(PHASE_MAP) == 8
    for phase_name, stages in expected.items():
        assert PHASE_MAP[phase_name] == stages


def test_phase_map_covers_all_stages_exactly_once():
    # Flatten every phase's stage tuple and confirm full, non-overlapping cover.
    flattened = tuple(stage for stages in PHASE_MAP.values() for stage in stages)
    assert len(flattened) == 23
    assert set(flattened) == set(Stage)
@pytest.mark.parametrize(
    "status",
    [StageStatus.PENDING, StageStatus.RETRYING, StageStatus.PAUSED],
)
def test_start_event_transitions_to_running_from_allowed_states(status: StageStatus):
    """START moves any startable status to RUNNING on the same stage."""
    result = advance(Stage.EXPERIMENT_RUN, status, TransitionEvent.START)
    assert result.status is StageStatus.RUNNING
    assert result.next_stage is Stage.EXPERIMENT_RUN


def test_succeed_event_on_non_gate_stage_transitions_to_done():
    result = advance(
        Stage.SEARCH_STRATEGY,
        StageStatus.RUNNING,
        TransitionEvent.SUCCEED,
        hitl_required_stages=(5, 9, 20),
    )
    assert result.status is StageStatus.DONE
    assert result.next_stage is Stage.LITERATURE_COLLECT
    assert result.checkpoint_required is True
    assert result.decision == "proceed"


def test_succeed_event_on_gate_stage_transitions_to_blocked_approval():
    # Stage 5 (LITERATURE_SCREEN) is in the HITL set → approval required.
    result = advance(
        Stage.LITERATURE_SCREEN,
        StageStatus.RUNNING,
        TransitionEvent.SUCCEED,
        hitl_required_stages=(5, 20),
    )
    assert result.status is StageStatus.BLOCKED_APPROVAL
    assert result.next_stage is Stage.LITERATURE_SCREEN
    assert result.checkpoint_required is False
    assert result.decision == "block"


def test_approve_event_transitions_blocked_stage_to_done():
    result = advance(
        Stage.EXPERIMENT_DESIGN,
        StageStatus.BLOCKED_APPROVAL,
        TransitionEvent.APPROVE,
        hitl_required_stages=(5, 9, 20),
    )
    assert result.status is StageStatus.DONE
    assert result.next_stage is Stage.CODE_GENERATION
    assert result.checkpoint_required is True


def test_reject_event_rolls_back_to_default_gate_mapping():
    # QUALITY_GATE's default rollback target is PAPER_OUTLINE.
    result = advance(
        Stage.QUALITY_GATE,
        StageStatus.BLOCKED_APPROVAL,
        TransitionEvent.REJECT,
        hitl_required_stages=(5, 9, 20),
    )
    assert result.status is StageStatus.PENDING
    assert result.stage is Stage.PAPER_OUTLINE
    assert result.next_stage is Stage.PAPER_OUTLINE
    assert result.rollback_stage is Stage.PAPER_OUTLINE
    assert result.checkpoint_required is True
    assert result.decision == "pivot"


def test_reject_event_uses_explicit_rollback_stage_when_provided():
    result = advance(
        Stage.PAPER_REVISION,
        StageStatus.BLOCKED_APPROVAL,
        TransitionEvent.REJECT,
        rollback_stage=Stage.PAPER_OUTLINE,
    )
    assert result.status is StageStatus.PENDING
    assert result.stage is Stage.PAPER_OUTLINE
    assert result.next_stage is Stage.PAPER_OUTLINE
    assert result.rollback_stage is Stage.PAPER_OUTLINE


def test_timeout_event_transitions_to_paused_with_block_decision():
    result = advance(
        Stage.LITERATURE_SCREEN,
        StageStatus.BLOCKED_APPROVAL,
        TransitionEvent.TIMEOUT,
    )
    assert result.status is StageStatus.PAUSED
    assert result.next_stage is Stage.LITERATURE_SCREEN
    assert result.checkpoint_required is True
    assert result.decision == "block"
def test_fail_event_transitions_running_to_failed_with_retry_decision():
    """FAIL during RUNNING lands in FAILED with a retry recommendation."""
    result = advance(Stage.EXPERIMENT_RUN, StageStatus.RUNNING, TransitionEvent.FAIL)
    assert result.status is StageStatus.FAILED
    assert result.next_stage is Stage.EXPERIMENT_RUN
    assert result.checkpoint_required is True
    assert result.decision == "retry"


def test_retry_event_transitions_failed_to_retrying():
    """RETRY moves a FAILED stage into RETRYING without changing the stage."""
    result = advance(Stage.EXPERIMENT_RUN, StageStatus.FAILED, TransitionEvent.RETRY)
    assert result.status is StageStatus.RETRYING
    assert result.next_stage is Stage.EXPERIMENT_RUN
    assert result.decision == "retry"


def test_resume_event_transitions_paused_to_running():
    """RESUME re-activates a PAUSED stage in place."""
    result = advance(Stage.EXPERIMENT_RUN, StageStatus.PAUSED, TransitionEvent.RESUME)
    assert result.status is StageStatus.RUNNING
    assert result.next_stage is Stage.EXPERIMENT_RUN


def test_pause_event_transitions_failed_to_paused():
    """PAUSE parks a FAILED stage with a checkpoint and a block decision."""
    result = advance(Stage.EXPERIMENT_RUN, StageStatus.FAILED, TransitionEvent.PAUSE)
    assert result.status is StageStatus.PAUSED
    assert result.next_stage is Stage.EXPERIMENT_RUN
    assert result.checkpoint_required is True
    assert result.decision == "block"


def test_invalid_transition_raises_value_error():
    """START from a DONE stage is not a legal transition."""
    with pytest.raises(ValueError, match="Unsupported transition"):
        _ = advance(Stage.TOPIC_INIT, StageStatus.DONE, TransitionEvent.START)


def test_advance_rejects_unknown_transition_event_string():
    """A raw string that is not a TransitionEvent member is rejected."""
    with pytest.raises(ValueError, match="not a valid TransitionEvent"):
        _ = advance(Stage.TOPIC_INIT, StageStatus.PENDING, "unknown")
@pytest.mark.parametrize("stage", tuple(GATE_STAGES))
def test_gate_required_for_gate_stages_with_default_config(stage: Stage):
    """With no HITL config supplied, every gate stage requires a gate."""
    assert gate_required(stage, None) is True


@pytest.mark.parametrize("stage", tuple(GATE_STAGES))
def test_gate_required_respects_hitl_stage_subset(stage: Stage):
    """A gate stage is required iff its numeric value is in the HITL subset."""
    required = (5, 20)
    # Fixed: compare with `==`, not `is` — these are two independently computed
    # booleans, and identity comparison only works because CPython interns
    # True/False. Equality expresses the intended check.
    assert gate_required(stage, required) == (int(stage) in required)


@pytest.mark.parametrize("stage", tuple(s for s in Stage if s not in GATE_STAGES))
def test_gate_required_is_false_for_non_gate_stages(stage: Stage):
    """Non-gate stages never require a HITL gate, regardless of config."""
    assert gate_required(stage, (5, 9, 20)) is False
@pytest.mark.parametrize(
    "stage,expected",
    [
        (Stage.LITERATURE_SCREEN, Stage.LITERATURE_COLLECT),
        (Stage.EXPERIMENT_DESIGN, Stage.HYPOTHESIS_GEN),
        (Stage.QUALITY_GATE, Stage.PAPER_OUTLINE),
    ],
)
def test_default_rollback_stage_for_known_gate_mappings(stage: Stage, expected: Stage):
    """Each known gate stage maps to its documented rollback target."""
    assert default_rollback_stage(stage) is expected


def test_default_rollback_stage_for_unknown_stage_uses_previous_stage():
    """Stages without an explicit mapping roll back to their predecessor."""
    assert default_rollback_stage(Stage.PAPER_DRAFT) is Stage.PAPER_OUTLINE


def test_default_rollback_stage_for_first_stage_returns_self():
    """The very first stage has no predecessor, so it rolls back to itself."""
    assert default_rollback_stage(Stage.TOPIC_INIT) is Stage.TOPIC_INIT
def test_transition_outcome_field_values_are_exposed():
    """TransitionOutcome exposes the field values it was constructed with."""
    built = TransitionOutcome(
        stage=Stage.TOPIC_INIT,
        status=StageStatus.RUNNING,
        next_stage=Stage.TOPIC_INIT,
        rollback_stage=Stage.TOPIC_INIT,
        checkpoint_required=True,
        decision="block",
    )
    assert built.checkpoint_required is True
    assert built.decision == "block"


def test_sequence_and_neighbor_maps_are_consistent_for_all_stages():
    """PREVIOUS_STAGE/NEXT_STAGE must mirror positions in STAGE_SEQUENCE."""
    last_index = len(STAGE_SEQUENCE) - 1
    for position, stage in enumerate(STAGE_SEQUENCE):
        want_prev = STAGE_SEQUENCE[position - 1] if position > 0 else None
        want_next = STAGE_SEQUENCE[position + 1] if position < last_index else None
        assert PREVIOUS_STAGE[stage] is want_prev
        assert NEXT_STAGE[stage] is want_next


def test_transition_map_covers_all_stage_status_values():
    """Every StageStatus has an entry; DONE is terminal (no outgoing targets)."""
    assert set(TRANSITION_MAP.keys()) == set(StageStatus)
    for source_status, targets in TRANSITION_MAP.items():
        assert isinstance(targets, frozenset)
        assert all(target in StageStatus for target in targets)
        if source_status is StageStatus.DONE:
            assert targets == frozenset()
# ── DECISION_ROLLBACK tests ──
def test_decision_rollback_has_pivot_and_refine():
    """Both supported decision keys are present in the rollback table."""
    assert "pivot" in DECISION_ROLLBACK
    assert "refine" in DECISION_ROLLBACK


def test_decision_rollback_pivot_targets_hypothesis_gen():
    """A 'pivot' decision restarts from hypothesis generation."""
    assert DECISION_ROLLBACK["pivot"] is Stage.HYPOTHESIS_GEN


def test_decision_rollback_refine_targets_iterative_refine():
    """A 'refine' decision re-enters the iterative refinement stage."""
    assert DECISION_ROLLBACK["refine"] is Stage.ITERATIVE_REFINE


def test_max_decision_pivots_is_positive():
    """At least one pivot must always be allowed."""
    assert MAX_DECISION_PIVOTS >= 1
================================================
FILE: tests/test_rc_templates.py
================================================
"""Unit tests for researchclaw.templates — conference templates + MD→LaTeX converter."""
from __future__ import annotations
import threading
import pytest
from researchclaw.templates.conference import (
CONFERENCE_REGISTRY,
ConferenceTemplate,
get_template,
list_conferences,
NEURIPS_2024,
NEURIPS_2025,
ICLR_2025,
ICLR_2026,
ICML_2025,
ICML_2026,
)
from researchclaw.templates.converter import (
markdown_to_latex,
_parse_sections,
_extract_title,
_extract_abstract,
_convert_inline,
_escape_latex,
_escape_algo_line,
_render_code_block,
_build_body,
_render_table,
_parse_table_row,
_parse_alignments,
_render_itemize,
_render_enumerate,
_reset_render_counters,
_next_table_num,
_next_figure_num,
check_paper_completeness, # noqa: F401
)
# =====================================================================
# conference.py tests
# =====================================================================
class TestConferenceTemplate:
    """Field-level checks on the bundled ConferenceTemplate instances."""

    def test_neurips_basic_fields(self) -> None:
        """NeurIPS 2024 template carries the expected metadata."""
        tpl = NEURIPS_2024
        assert tpl.name == "neurips_2024"
        assert tpl.display_name == "NeurIPS 2024"
        assert tpl.year == 2024
        assert tpl.document_class == "article"
        assert tpl.style_package == "neurips_2024"
        assert tpl.columns == 1
        assert tpl.author_format == "neurips"
        assert tpl.bib_style == "plainnat"

    def test_iclr_basic_fields(self) -> None:
        """ICLR 2025 template carries the expected metadata."""
        tpl = ICLR_2025
        assert tpl.name == "iclr_2025"
        assert tpl.year == 2025
        assert tpl.style_package == "iclr2025_conference"
        assert tpl.bib_style == "iclr2025_conference"
        assert tpl.columns == 1
        assert tpl.author_format == "iclr"

    def test_icml_basic_fields(self) -> None:
        """ICML 2025 template carries the expected metadata (two-column)."""
        tpl = ICML_2025
        assert tpl.name == "icml_2025"
        assert tpl.year == 2025
        assert tpl.style_package == "icml2025"
        assert tpl.columns == 2
        assert tpl.author_format == "icml"
        assert tpl.bib_style == "icml2025"

    def test_frozen(self) -> None:
        """Templates are frozen dataclasses: attribute writes raise."""
        with pytest.raises(AttributeError):
            NEURIPS_2024.name = "hacked"  # type: ignore[misc]
class TestRenderPreamble:
    """Behavioral checks for ConferenceTemplate.render_preamble()."""

    def test_neurips_preamble_structure(self) -> None:
        """All structural pieces of the NeurIPS preamble must be present."""
        rendered = NEURIPS_2024.render_preamble("My Title", "J. Doe", "An abstract.")
        for fragment in (
            r"\documentclass{article}",
            r"\usepackage[preprint]{neurips_2024}",
            r"\title{My Title}",
            r"\author{J. Doe}",
            r"\begin{abstract}",
            "An abstract.",
            r"\end{abstract}",
            r"\begin{document}",
            r"\maketitle",
        ):
            assert fragment in rendered

    def test_iclr_preamble_no_options(self) -> None:
        """ICLR uses a plain documentclass (no class options)."""
        rendered = ICLR_2025.render_preamble("Title", "Author", "Abstract")
        assert r"\documentclass{article}" in rendered  # no options
        assert r"\usepackage{iclr2025_conference}" in rendered

    def test_icml_author_block(self) -> None:
        """ICML renders the icmlauthorlist environment with an affiliation."""
        rendered = ICML_2025.render_preamble("Title", "Alice", "Abstract")
        for fragment in (
            r"\begin{icmlauthorlist}",
            r"\icmlauthor{Alice}{aff1}",
            r"\end{icmlauthorlist}",
            r"\icmlaffiliation{aff1}{Affiliation}",
        ):
            assert fragment in rendered

    def test_icml_preamble_extra(self) -> None:
        """ICML adds a running-title command."""
        rendered = ICML_2025.render_preamble("Title", "Author", "Abstract")
        assert r"\icmltitlerunning{Title}" in rendered
class TestRenderFooter:
    """Behavioral checks for ConferenceTemplate.render_footer()."""

    def test_neurips_footer(self) -> None:
        """Footer carries the bib style, the bib file, and the document end."""
        rendered = NEURIPS_2024.render_footer("refs")
        assert r"\bibliographystyle{plainnat}" in rendered
        assert r"\bibliography{refs}" in rendered
        assert r"\end{document}" in rendered

    def test_icml_footer(self) -> None:
        """ICML footer uses the icml2025 bib style and default bib file."""
        rendered = ICML_2025.render_footer()
        assert r"\bibliographystyle{icml2025}" in rendered
        assert r"\bibliography{references}" in rendered

    def test_default_bib_file(self) -> None:
        """Without an argument the bib file defaults to 'references'."""
        rendered = NEURIPS_2024.render_footer()
        assert r"\bibliography{references}" in rendered
class TestGetTemplate:
    """Lookup semantics of get_template(): aliases, casing, separators."""

    def test_full_name(self) -> None:
        """A canonical name resolves to the exact template object."""
        assert get_template("neurips_2024") is NEURIPS_2024

    def test_short_alias(self) -> None:
        """Bare conference aliases resolve to the latest year."""
        assert get_template("neurips") is NEURIPS_2025
        assert get_template("iclr") is ICLR_2026
        assert get_template("icml") is ICML_2026

    def test_case_insensitive(self) -> None:
        """Lookup ignores letter case."""
        assert get_template("NeurIPS") is NEURIPS_2025
        assert get_template("ICML_2026") is ICML_2026

    def test_dash_and_space_normalization(self) -> None:
        """Dashes and spaces are normalized to underscores before lookup."""
        assert get_template("neurips-2025") is NEURIPS_2025
        assert get_template("icml 2026") is ICML_2026

    def test_unknown_raises(self) -> None:
        """An unregistered conference raises KeyError with a clear message."""
        with pytest.raises(KeyError, match="Unknown conference"):
            get_template("aaai_2025")
class TestListConferences:
    """Checks for list_conferences() output."""

    def test_returns_canonical_names(self) -> None:
        """Only deduplicated canonical names are listed (no aliases)."""
        names = list_conferences()
        for expected in ("neurips_2025", "iclr_2026", "icml_2026"):
            assert expected in names
        # Should be deduplicated — no aliases (6 conference + 1 generic)
        assert len(names) == 7

    def test_sorted(self) -> None:
        """The listing is alphabetically sorted."""
        names = list_conferences()
        assert names == sorted(names)
class TestConferenceRegistry:
    """Sanity checks on the CONFERENCE_REGISTRY mapping."""

    def test_all_aliases_resolve(self) -> None:
        """Every registry entry maps to a ConferenceTemplate with a non-empty name.

        Fixed: the original iterated ``.items()`` but never used the key,
        leaving an unused loop variable; iterate the values directly.
        """
        for tpl in CONFERENCE_REGISTRY.values():
            assert isinstance(tpl, ConferenceTemplate)
            assert tpl.name  # not empty
# =====================================================================
# converter.py tests
# =====================================================================
class TestParseSections:
    """Tests for the Markdown heading splitter _parse_sections()."""

    def test_empty(self) -> None:
        """Empty input yields a single empty level-1 section."""
        parsed = _parse_sections("")
        assert len(parsed) == 1
        assert parsed[0].level == 1
        assert parsed[0].body == ""

    def test_single_heading(self) -> None:
        """One H1 heading produces one section with its body text."""
        parsed = _parse_sections("# Introduction\nHello world")
        assert len(parsed) == 1
        assert parsed[0].level == 1
        assert parsed[0].heading == "Introduction"
        assert "Hello world" in parsed[0].body

    def test_multiple_headings(self) -> None:
        """H1/H2/H3 headings each open a new section, in document order."""
        parsed = _parse_sections("# Title\nfoo\n## Method\nbar\n### Details\nbaz")
        assert [section.heading for section in parsed] == ["Title", "Method", "Details"]

    def test_preamble_before_heading(self) -> None:
        """Text before the first heading becomes a level-0 preamble section."""
        parsed = _parse_sections("Some text before\n\n# First\nBody")
        assert len(parsed) == 2
        assert parsed[0].level == 0
        assert "Some text before" in parsed[0].body

    def test_heading_lower(self) -> None:
        """heading_lower exposes a lowercased copy of the heading."""
        parsed = _parse_sections("# Abstract\nContent")
        assert parsed[0].heading_lower == "abstract"
class TestExtractTitle:
    """Tests for the title heuristic _extract_title()."""

    def test_bold_title_after_heading(self) -> None:
        """A bold line under '# Title' is taken as the paper title."""
        md = "# Title\n**My Paper**\n\n# Abstract\nblah"
        assert _extract_title(_parse_sections(md), md) == "My Paper"

    def test_first_non_meta_h1(self) -> None:
        """Without a Title block, the first non-meta H1 heading is used."""
        md = "# Introduction\nSome text"
        assert _extract_title(_parse_sections(md), md) == "Introduction"

    def test_fallback(self) -> None:
        """With no usable heading at all, a default title is returned."""
        assert _extract_title(_parse_sections(""), "") == "Untitled Paper"
class TestExtractAbstract:
    """Tests for _extract_abstract()."""

    def test_from_h1(self) -> None:
        """An '# Abstract' H1 section supplies the abstract text."""
        md = "# Abstract\nThis is the abstract.\n\n# Intro\nBody"
        assert "This is the abstract." in _extract_abstract(_parse_sections(md))

    def test_from_h2(self) -> None:
        """An '## Abstract' H2 section also works."""
        md = "# Title\nfoo\n## Abstract\nAbstract text.\n## Intro"
        assert "Abstract text." in _extract_abstract(_parse_sections(md))

    def test_missing_abstract(self) -> None:
        """Documents without an Abstract section yield an empty string."""
        md = "# Introduction\nNo abstract here"
        assert _extract_abstract(_parse_sections(md)) == ""
class TestConvertInline:
    """Tests for the inline Markdown-to-LaTeX converter _convert_inline()."""

    def test_bold(self) -> None:
        """Double-asterisk emphasis becomes \\textbf{}."""
        assert r"\textbf{bold}" in _convert_inline("**bold**")

    def test_italic(self) -> None:
        """Single-asterisk emphasis becomes \\textit{}."""
        assert r"\textit{italic}" in _convert_inline("*italic*")

    def test_inline_code(self) -> None:
        """Backtick spans become \\texttt{}."""
        assert r"\texttt{code}" in _convert_inline("`code`")

    def test_link(self) -> None:
        """Markdown links become \\href{url}{text}."""
        converted = _convert_inline("[text](http://example.com)")
        assert r"\href{http://example.com}{text}" in converted

    def test_special_chars_escaped(self) -> None:
        """Bare %, &, # are escaped for LaTeX."""
        converted = _convert_inline("100% done & 5# items")
        assert r"100\% done \& 5\# items" in converted

    def test_math_preserved(self) -> None:
        """Inline \\(...\\) math passes through untouched."""
        converted = _convert_inline(r"where \(x + y\) is given")
        assert r"\(x + y\)" in converted

    def test_cite_preserved(self) -> None:
        """\\cite{} commands pass through untouched."""
        converted = _convert_inline(r"as shown by \cite{doe2024}")
        assert r"\cite{doe2024}" in converted

    def test_dollar_math_preserved(self) -> None:
        """Dollar-delimited math passes through untouched."""
        converted = _convert_inline("the value $x^2$ is")
        assert "$x^2$" in converted

    def test_pre_escaped_underscore_not_doubled(self) -> None:
        """BUG-182: LLM pre-escapes underscores → must NOT double-escape to \\\\_."""
        result = _convert_inline(r"RawObservation\_PPO\_WithNorm")
        assert r"\\_" not in result, f"Double-escaped: {result}"
        assert r"\_" in result

    def test_pre_escaped_underscore_near_math(self) -> None:
        """BUG-182: Pre-escaped underscore adjacent to math must not break."""
        result = _convert_inline(
            r"RawObs\_PPO. Statistics \(\mu_t\) are given"
        )
        assert r"\\_" not in result
        assert r"\_" in result
        assert r"\(\mu_t\)" in result

    def test_pre_escaped_hash_not_doubled(self) -> None:
        """BUG-182: Pre-escaped hash should not be double-escaped."""
        result = _convert_inline(r"Section \#3 details")
        assert r"\\#" not in result
        assert r"\#" in result
class TestEscapeLatex:
    """Tests for the plain-text escaper _escape_latex()."""

    def test_special_chars(self) -> None:
        """Each LaTeX special character gets a backslash escape."""
        for raw, escaped in (("#", r"\#"), ("%", r"\%"), ("&", r"\&"), ("_", r"\_")):
            assert escaped in _escape_latex(raw)

    def test_math_not_escaped(self) -> None:
        """Characters inside \\(...\\) math are left alone."""
        escaped = _escape_latex(r"value \(x_1\) here")
        assert r"\(x_1\)" in escaped  # underscore inside math preserved
class TestBuildBody:
    """Tests for _build_body() section assembly."""

    def test_skips_title_and_abstract(self) -> None:
        """Title and Abstract sections go to the preamble, not the body."""
        md = "# Title\nfoo\n# Abstract\nbar\n# Introduction\nbaz"
        body = _build_body(_parse_sections(md))
        assert r"\section{Introduction}" in body
        assert "baz" in body
        # Title and abstract should not appear as sections
        assert r"\section{Title}" not in body
        assert r"\section{Abstract}" not in body

    def test_subsection_promoted_when_all_h2(self) -> None:
        """T1.3: When all body sections are H2, they should be promoted to \\section."""
        body = _build_body(_parse_sections("## Method\ntext"))
        # All-H2 document → auto-promoted to \section
        assert r"\section{Method}" in body

    def test_h2_promoted_under_h1_title(self) -> None:
        """When title occupies H1, H2 body sections promote to \\section."""
        md = "# My Paper\ntitle body\n## Method\ntext"
        body = _build_body(_parse_sections(md), title="My Paper")
        assert r"\section{Method}" in body

    def test_subsubsection(self) -> None:
        """H3 follows its parent up one level to \\subsection."""
        body = _build_body(_parse_sections("## Intro\nintro\n### Details\ntext"))
        # H2 promoted to \section, H3 promoted to \subsection
        assert r"\subsection{Details}" in body
class TestListRendering:
    """Tests for bullet and numbered list rendering helpers."""

    def test_bullet_list(self) -> None:
        """_render_itemize wraps items in an itemize environment."""
        rendered = _render_itemize(["First item", "Second item"])
        for fragment in (
            r"\begin{itemize}",
            r"\item First item",
            r"\item Second item",
            r"\end{itemize}",
        ):
            assert fragment in rendered

    def test_numbered_list(self) -> None:
        """_render_enumerate wraps items in an enumerate environment."""
        rendered = _render_enumerate(["Step one", "Step two"])
        for fragment in (r"\begin{enumerate}", r"\item Step one", r"\end{enumerate}"):
            assert fragment in rendered
class TestTableRendering:
    """Tests for Markdown table → LaTeX tabular conversion."""

    def test_parse_table_row(self) -> None:
        """Pipe-delimited cells are split and stripped."""
        assert _parse_table_row("| a | b | c |") == ["a", "b", "c"]

    def test_parse_alignments(self) -> None:
        """Markdown alignment markers map to l/c/r column specs."""
        assert _parse_alignments("| --- | :---: | ---: |", 3) == ["l", "c", "r"]

    def test_render_simple_table(self) -> None:
        """A two-column table renders as a booktabs-style LaTeX table."""
        rendered = _render_table([
            "| Name | Value |",
            "| --- | --- |",
            "| A | 1 |",
            "| B | 2 |",
        ])
        for fragment in (
            r"\begin{table}",
            r"\begin{tabular}{ll}",
            r"\toprule",
            r"\textbf{Name}",
            r"\midrule",
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{table}",
        ):
            assert fragment in rendered

    def test_render_counters_are_thread_local(self) -> None:
        """Each thread gets its own table/figure counters after a reset."""
        observed: list[tuple[int, int, int]] = []
        guard = threading.Lock()

        def sample_counters() -> None:
            _reset_render_counters()
            triple = (_next_table_num(), _next_table_num(), _next_figure_num())
            with guard:
                observed.append(triple)

        workers = [threading.Thread(target=sample_counters) for _ in range(4)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        assert observed == [(1, 2, 1)] * 4
# =====================================================================
# markdown_to_latex integration tests
# =====================================================================
class TestMarkdownToLatex:
    """Integration tests for the full conversion pipeline."""

    # Shared fixture document exercising every converter feature.
    SAMPLE_MD = (
        "# Title\n"
        "**My Great Paper**\n\n"
        "# Abstract\n"
        "This is the abstract.\n\n"
        "# Introduction\n"
        "We study the problem of RL.\n\n"
        "## Related Work\n"
        "Prior work includes **many** approaches.\n\n"
        "# Method\n"
        "Our method uses \\(f(x) = x^2\\) as the objective.\n\n"
        "# Results\n"
        "- Result 1\n"
        "- Result 2\n\n"
        "# Conclusion\n"
        "We conclude.\n\n"
        "# References\n"
        "1. Doe et al. (2024)\n"
    )

    def test_neurips_full(self) -> None:
        """End-to-end NeurIPS render contains every structural fragment."""
        tex = markdown_to_latex(self.SAMPLE_MD, NEURIPS_2024)
        for fragment in (
            r"\documentclass{article}",
            r"\usepackage[preprint]{neurips_2024}",
            r"\title{My Great Paper}",
            r"\begin{abstract}",
            "This is the abstract.",
            r"\section{Introduction}",
            r"\subsection{Related Work}",
            r"\section{Method}",
            r"\begin{itemize}",
            r"\bibliographystyle{plainnat}",
            r"\end{document}",
        ):
            assert fragment in tex

    def test_iclr_full(self) -> None:
        """ICLR render uses the conference style package and bib style."""
        tex = markdown_to_latex(self.SAMPLE_MD, ICLR_2025)
        assert r"\usepackage{iclr2025_conference}" in tex
        assert r"\bibliographystyle{iclr2025_conference}" in tex

    def test_icml_full(self) -> None:
        """ICML render produces the icmlauthorlist author block."""
        tex = markdown_to_latex(self.SAMPLE_MD, ICML_2025, authors="Alice")
        assert r"\begin{icmlauthorlist}" in tex
        assert r"\icmlauthor{Alice}{aff1}" in tex
        assert r"\bibliographystyle{icml2025}" in tex

    def test_custom_title_override(self) -> None:
        """An explicit title argument overrides the extracted one."""
        tex = markdown_to_latex(
            "# Abstract\nblah\n# Intro\nbody",
            NEURIPS_2024,
            title="Override Title",
        )
        assert r"\title{Override Title}" in tex

    def test_custom_authors(self) -> None:
        """An explicit authors argument is rendered into \\author{}."""
        tex = markdown_to_latex(self.SAMPLE_MD, NEURIPS_2024, authors="Jane Doe")
        assert r"\author{Jane Doe}" in tex

    def test_custom_bib_file(self) -> None:
        """A custom bib_file name flows through to \\bibliography{}."""
        tex = markdown_to_latex(self.SAMPLE_MD, NEURIPS_2024, bib_file="my_refs")
        assert r"\bibliography{my_refs}" in tex

    def test_math_preserved_in_output(self) -> None:
        """Inline and display math survive the pipeline untouched."""
        md = "# Abstract\nabs\n# Method\n\\(f(x)\\) and \\[E = mc^2\\]"
        tex = markdown_to_latex(md, NEURIPS_2024, title="T")
        assert r"\(f(x)\)" in tex
        assert r"\[E = mc^2\]" in tex

    def test_empty_paper(self) -> None:
        """Even an empty source yields a complete document skeleton."""
        tex = markdown_to_latex("", NEURIPS_2024, title="Empty")
        assert r"\begin{document}" in tex
        assert r"\end{document}" in tex

    def test_display_math_block(self) -> None:
        """A multi-line \\[...\\] block keeps its contents."""
        md = "# Abstract\nabs\n# Method\n\\[\nx = y + z\n\\]"
        tex = markdown_to_latex(md, NEURIPS_2024, title="T")
        assert "x = y + z" in tex

    def test_code_block(self) -> None:
        """Fenced code becomes a verbatim environment."""
        md = "# Abstract\nabs\n# Method\n```python\nprint('hello')\n```"
        tex = markdown_to_latex(md, NEURIPS_2024, title="T")
        assert r"\begin{verbatim}" in tex
        assert "print('hello')" in tex
        assert r"\end{verbatim}" in tex

    def test_table_in_paper(self) -> None:
        """A Markdown table renders as tabular with bold headers."""
        md = (
            "# Abstract\nabs\n"
            "# Results\n"
            "| Model | Score |\n"
            "| --- | --- |\n"
            "| Ours | 95.0 |\n"
        )
        tex = markdown_to_latex(md, NEURIPS_2024, title="T")
        assert r"\begin{tabular}" in tex
        assert r"\textbf{Model}" in tex
# =====================================================================
# ExportConfig tests
# =====================================================================
class TestExportConfig:
    """Tests for ExportConfig in config.py."""

    def test_default_values(self) -> None:
        """ExportConfig defaults match the documented conventions."""
        from researchclaw.config import ExportConfig

        ec = ExportConfig()
        assert ec.target_conference == "neurips_2025"
        assert ec.authors == "Anonymous"
        assert ec.bib_file == "references"

    def test_frozen(self) -> None:
        """ExportConfig is immutable after construction."""
        from researchclaw.config import ExportConfig

        with pytest.raises(AttributeError):
            ExportConfig().target_conference = "icml"  # type: ignore[misc]

    def test_rcconfig_has_export(self) -> None:
        """RCConfig.load() exposes an export section with defaults."""
        from researchclaw.config import RCConfig

        cfg = RCConfig.load("config.researchclaw.example.yaml", check_paths=False)
        assert hasattr(cfg, "export")
        assert cfg.export.target_conference == "neurips_2025"

    def test_rcconfig_export_from_dict(self) -> None:
        """Export values supplied in the raw dict override the defaults."""
        from pathlib import Path

        import yaml

        from researchclaw.config import RCConfig

        raw = yaml.safe_load(Path("config.researchclaw.example.yaml").read_text())
        raw["export"] = {
            "target_conference": "icml_2025",
            "authors": "Test Author",
            "bib_file": "mybib",
        }
        cfg = RCConfig.from_dict(raw, check_paths=False)
        assert cfg.export.target_conference == "icml_2025"
        assert cfg.export.authors == "Test Author"
        assert cfg.export.bib_file == "mybib"
# =====================================================================
# hitl_required_stages validation update test
# =====================================================================
class TestHitlStageValidation:
    """hitl_required_stages bounds, plus bundled style/bst file checks.

    NOTE(review): the two style-file tests live in this class in the original
    suite even though they concern templates, not HITL validation; they are
    kept here so test IDs are unchanged.
    """

    def test_stage_23_valid(self) -> None:
        """Stages up to 23 are accepted by validate_config."""
        from pathlib import Path

        import yaml

        from researchclaw.config import validate_config

        raw = yaml.safe_load(Path("config.researchclaw.example.yaml").read_text())
        raw.setdefault("security", {})["hitl_required_stages"] = [1, 22, 23]
        result = validate_config(raw, check_paths=False)
        assert result.ok, f"Errors: {result.errors}"

    def test_get_style_files_returns_bundled_sty(self) -> None:
        """Each conference template bundles at least one .sty file."""
        for name in ["neurips_2025", "neurips_2024", "iclr_2026", "iclr_2025", "icml_2026", "icml_2025"]:
            files = get_template(name).get_style_files()
            assert len(files) >= 1, f"No style files for {name}"
            sty_names = [f.name for f in files]
            assert any(f.endswith(".sty") for f in sty_names), f"No .sty file for {name}"

    def test_iclr_icml_have_bst_files(self) -> None:
        """ICLR and ICML templates bundle custom .bst files."""
        for name in ["iclr_2026", "iclr_2025", "icml_2026", "icml_2025"]:
            files = get_template(name).get_style_files()
            bst_names = [f.name for f in files if f.suffix == ".bst"]
            assert len(bst_names) >= 1, f"No .bst file for {name}"

    def test_stage_24_invalid(self) -> None:
        """Stage 24 is out of range and is rejected with a message naming it."""
        from pathlib import Path

        import yaml

        from researchclaw.config import validate_config

        raw = yaml.safe_load(Path("config.researchclaw.example.yaml").read_text())
        raw.setdefault("security", {})["hitl_required_stages"] = [24]
        result = validate_config(raw, check_paths=False)
        assert not result.ok
        assert any("24" in e for e in result.errors)
# =====================================================================
# check_paper_completeness — section word count + bullet density checks
# =====================================================================
class TestCompletenessWordCountAndBullets:
    """Tests for per-section word count and bullet density checks."""

    @staticmethod
    def _make_sections(section_specs: list[tuple[str, int, bool]]) -> list:
        """Build _Section-like objects from (heading, word_count, use_bullets) specs."""
        built = []
        for heading, word_count, use_bullets in section_specs:
            if use_bullets:
                # Roughly word_count words spread over three-word bullet lines.
                body = "\n".join(
                    f"- Point number {i}" for i in range(word_count // 3)
                )
            else:
                body = " ".join(["word"] * word_count)
            built.append(
                type("_Section", (), {
                    "level": 1,
                    "heading": heading,
                    "heading_lower": heading.lower(),
                    "body": body,
                })()
            )
        return built

    def test_completeness_section_word_count_short(self) -> None:
        """A Method section with only 100 words triggers a warning."""
        sections = self._make_sections([
            ("Title", 5, False),
            ("Abstract", 200, False),
            ("Introduction", 900, False),
            ("Related Work", 700, False),
            ("Method", 100, False),
            ("Experiments", 1000, False),
            ("Results", 700, False),
            ("Conclusion", 250, False),
        ])
        warns = check_paper_completeness(sections)
        method_warns = [w for w in warns if "Method" in w and "words" in w]
        assert len(method_warns) >= 1, f"Expected word count warning, got: {warns}"

    def test_completeness_bullet_density(self) -> None:
        """A Method section full of bullet points triggers a warning."""
        sections = self._make_sections([
            ("Title", 5, False),
            ("Abstract", 200, False),
            ("Introduction", 900, False),
            ("Related Work", 700, False),
            ("Method", 300, True),
            ("Experiments", 1000, False),
            ("Results", 700, False),
            ("Conclusion", 250, False),
        ])
        warns = check_paper_completeness(sections)
        bullet_warns = [w for w in warns if "bullet" in w.lower() and "Method" in w]
        assert len(bullet_warns) >= 1, f"Expected bullet warning, got: {warns}"
# =====================================================================
# BUG-177: Algorithm pseudocode escaping tests
# =====================================================================
class TestAlgorithmEscaping:
    """BUG-177: escaping rules for algorithm pseudocode rendering."""

    def test_escape_underscore(self) -> None:
        """Underscores in identifiers are backslash-escaped."""
        assert r"psi\_1" in _escape_algo_line("psi_1")

    def test_escape_hash_comment(self) -> None:
        """A trailing '# comment' becomes \\COMMENT{...}, code kept."""
        escaped = _escape_algo_line("x = y # update rule")
        assert r"\COMMENT{update rule}" in escaped
        assert "x = y" in escaped

    def test_fullline_hash_comment(self) -> None:
        """A whole-line comment becomes only a \\COMMENT{...}."""
        assert _escape_algo_line("# Initialize buffer") == r"\COMMENT{Initialize buffer}"

    def test_escape_percent(self) -> None:
        """Percent signs are escaped."""
        assert r"\%" in _escape_algo_line("accuracy 95%")

    def test_escape_ampersand(self) -> None:
        """Ampersands are escaped."""
        assert r"\&" in _escape_algo_line("x & y")

    def test_preserve_latex_commands(self) -> None:
        """Math spans with LaTeX commands are preserved verbatim."""
        escaped = _escape_algo_line(r"Set $x = \alpha$ and update")
        assert r"$x = \alpha$" in escaped

    def test_render_code_block_algo_escapes(self) -> None:
        """'algorithm' blocks render as algorithmic environments with escaping."""
        pseudocode = (
            "Initialize theta_1, theta_2\n"
            "for t = 1 to T do\n"
            " Sample batch B # prioritized\n"
        )
        rendered = _render_code_block("algorithm", pseudocode)
        assert r"\begin{algorithm}" in rendered
        assert r"\begin{algorithmic}" in rendered
        assert r"theta\_1" in rendered
        assert r"\COMMENT{prioritized}" in rendered

    def test_render_code_block_verbatim_no_escape(self) -> None:
        """Non-algorithm code blocks should use verbatim (no escaping)."""
        rendered = _render_code_block("python", "x_1 = y_2 # comment")
        assert r"\begin{verbatim}" in rendered
        assert "x_1" in rendered  # NOT escaped in verbatim
================================================
FILE: tests/test_rc_validator.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportUnusedCallResult=false, reportAttributeAccessIssue=false, reportUnknownLambdaType=false
from __future__ import annotations
import pytest
from researchclaw.experiment.validator import (
BANNED_MODULES,
DANGEROUS_BUILTINS,
DANGEROUS_CALLS,
CodeValidation,
ValidationIssue,
check_filename_collisions,
extract_imports,
format_issues_for_llm,
validate_code,
validate_imports,
validate_security,
validate_syntax,
)
def _call_source(name: str) -> str:
top = name.split(".")[0]
lines: list[str] = []
if top in {"os", "subprocess", "shutil"}:
lines.append(f"import {top}")
lines.append(f"{name}()")
return "\n".join(lines)
def test_validate_syntax_accepts_valid_code():
    """Well-formed code passes with no issues."""
    report = validate_syntax("x = 1\nif x > 0:\n x += 1")
    assert report.ok is True
    assert report.issues == []


def test_validate_syntax_reports_syntax_error_with_location():
    """A syntax error yields exactly one located error issue."""
    report = validate_syntax("def bad(:\n pass")
    assert report.ok is False
    assert len(report.issues) == 1
    first = report.issues[0]
    assert first.severity == "error"
    assert first.category == "syntax"
    assert first.line == 1
    assert first.col is not None
    assert first.message


@pytest.mark.parametrize("code", ["", " \n\t ", "# comment only\n# still comment"])
def test_validate_syntax_accepts_empty_whitespace_and_comment_only(code: str):
    """Empty, whitespace-only and comment-only sources are all valid."""
    report = validate_syntax(code)
    assert report.ok is True
    assert report.issues == []
def test_validate_security_accepts_safe_code():
    """Benign os.path usage raises no security issues."""
    safe = 'import os\nvalue = os.path.join("a", "b")\nprint(value)'
    report = validate_security(safe)
    assert report.ok is True
    assert report.issues == []


def test_validate_security_skips_when_code_has_syntax_error():
    """Unparseable code is skipped; syntax errors are reported elsewhere."""
    report = validate_security("def broken(:\n pass")
    assert report.ok is True
    assert report.issues == []


@pytest.mark.parametrize("builtin_name", sorted(DANGEROUS_BUILTINS))
def test_validate_security_flags_every_dangerous_builtin_call(builtin_name: str):
    """Each dangerous builtin is flagged as exactly one security error."""
    if builtin_name == "__import__":
        snippet = '__import__("os")'
    elif builtin_name == "compile":
        snippet = 'compile("x = 1", "", "exec")'
    else:
        snippet = f'{builtin_name}("print(1)")'
    report = validate_security(snippet)
    assert len(report.issues) == 1
    flagged = report.issues[0]
    assert flagged.severity == "error"
    assert flagged.category == "security"
    assert flagged.message == f"Dangerous built-in call: {builtin_name}()"
@pytest.mark.parametrize("call_name", sorted(DANGEROUS_CALLS))
def test_validate_security_flags_every_dangerous_call(call_name: str):
result = validate_security(_call_source(call_name))
messages = [issue.message for issue in result.issues]
assert f"Dangerous call: {call_name}()" in messages
assert all(issue.severity == "error" for issue in result.issues)
assert all(issue.category == "security" for issue in result.issues)
@pytest.mark.parametrize("module_name", sorted(BANNED_MODULES))
def test_validate_security_flags_every_banned_import(module_name: str):
result = validate_security(f"import {module_name}")
assert len(result.issues) == 1
issue = result.issues[0]
assert issue.severity == "error"
assert issue.category == "security"
assert issue.message == f"Banned module import: {module_name}"
@pytest.mark.parametrize("module_name", sorted(BANNED_MODULES))
def test_validate_security_flags_every_banned_from_import(module_name: str):
result = validate_security(f"from {module_name} import x")
assert len(result.issues) == 1
issue = result.issues[0]
assert issue.severity == "error"
assert issue.category == "security"
assert issue.message == f"Banned module import: from {module_name}"
def test_validate_imports_recognizes_stdlib_modules_by_default():
    """Stdlib imports produce no warnings when no explicit set is given."""
    report = validate_imports("import json\nfrom math import sqrt")
    assert report.ok is True
    assert report.warnings == []


def test_validate_imports_warns_for_unavailable_package():
    """An unknown module yields a single availability warning, not an error."""
    report = validate_imports("import totally_missing_pkg")
    assert report.ok is True
    assert len(report.warnings) == 1
    only = report.warnings[0]
    assert only.severity == "warning"
    assert only.category == "import"
    assert (
        only.message
        == "Module 'totally_missing_pkg' may not be available in sandbox"
    )


def test_validate_imports_respects_custom_available_set():
    """Only modules outside the supplied availability set are warned about."""
    report = validate_imports(
        "import alpha\nimport beta\nimport gamma",
        available={"alpha", "gamma"},
    )
    assert [w.message for w in report.warnings] == [
        "Module 'beta' may not be available in sandbox",
    ]


def test_validate_imports_returns_no_warnings_for_syntax_error_input():
    """Unparseable code yields no import warnings."""
    report = validate_imports("def bad(:\n pass", available=set())
    assert report.ok is True
    assert report.warnings == []


@pytest.mark.parametrize("code", ["", " \n\t ", "# comment only"])
def test_validate_imports_handles_empty_like_inputs(code: str):
    """Empty-like sources yield no warnings."""
    report = validate_imports(code, available=set())
    assert report.ok is True
    assert report.warnings == []
def test_validate_code_combines_security_and_import_issues_in_order():
code = 'import os\nos.system("echo hi")\nimport unknown_mod'
result = validate_code(code, available_packages={"os"})
assert result.ok is False
assert [i.category for i in result.issues] == ["security", "import"]
assert result.issues[0].message == "Dangerous call: os.system()"
assert (
result.issues[1].message
== "Module 'unknown_mod' may not be available in sandbox"
)
def test_validate_code_short_circuits_after_syntax_error():
    """A syntax error is the sole reported issue; later checks are skipped."""
    report = validate_code("def bad(:\n pass")
    assert len(report.issues) == 1
    assert report.issues[0].category == "syntax"
def test_validate_code_skip_security_excludes_security_issues():
    """skip_security=True drops security findings but keeps import warnings."""
    code = 'import os\nos.system("echo hi")\nimport unknown_mod'
    report = validate_code(code, available_packages={"os"}, skip_security=True)
    assert [issue.category for issue in report.issues] == ["import"]
def test_validate_code_skip_imports_excludes_import_warnings():
    """skip_imports=True leaves only the single security finding."""
    code = 'import os\nos.system("echo hi")\nimport unknown_mod'
    report = validate_code(code, available_packages={"os"}, skip_imports=True)
    assert all(found.category == "security" for found in report.issues)
    assert len(report.issues) == 1
def test_validate_code_skip_both_returns_clean_for_safe_code():
    """With both checks disabled, harmless code validates with no issues."""
    report = validate_code("x = 1", skip_security=True, skip_imports=True)
    assert report.ok is True
    assert report.issues == []
def test_validate_code_uses_available_packages_for_import_validation():
    """available_packages is forwarded to the import check."""
    code = "import alpha\nimport beta"
    report = validate_code(code, available_packages={"alpha"})
    messages = [issue.message for issue in report.issues]
    assert messages == ["Module 'beta' may not be available in sandbox"]
def test_extract_imports_supports_import_and_from_import_styles():
    """Both plain and from-imports contribute their top-level module name."""
    snippet = (
        "import os\nimport numpy as np\nfrom pandas import DataFrame\nfrom x.y import z"
    )
    assert extract_imports(snippet) == {"os", "numpy", "pandas", "x"}
def test_extract_imports_supports_multiple_aliases_and_dedupes():
    """Comma-separated and aliased imports dedupe down to top-level names."""
    snippet = "import os.path, os, json as js\nfrom json import loads"
    assert extract_imports(snippet) == {"os", "json"}
def test_extract_imports_ignores_relative_import_without_module_name():
    """A bare 'from . import x' names no module and is ignored."""
    assert extract_imports("from . import local_mod") == set()
def test_extract_imports_includes_relative_import_with_module_name():
    """Dotted relative imports keep the first package component."""
    assert extract_imports("from ..pkg.sub import thing") == {"pkg"}
def test_extract_imports_returns_empty_set_for_syntax_error():
    """Unparseable code yields an empty import set rather than raising."""
    assert extract_imports("def bad(:\n pass") == set()
@pytest.mark.parametrize("code", ["", " \n\t", "# comment only"])
def test_extract_imports_handles_empty_like_inputs(code: str):
    """Empty-like inputs yield an empty import set."""
    assert extract_imports(code) == set()
def test_format_issues_for_llm_returns_no_issues_message_when_clean():
    """A validation with no issues formats as the fixed 'clean' sentence."""
    assert format_issues_for_llm(CodeValidation()) == "No issues found."
def test_format_issues_for_llm_formats_issues_with_and_without_line():
    """Issues with a line number show it; those without say 'unknown location'."""
    syntax_issue = ValidationIssue(
        severity="error",
        category="syntax",
        message="invalid syntax",
        line=3,
    )
    import_issue = ValidationIssue(
        severity="warning",
        category="import",
        message="Module 'x' may be missing",
        line=None,
    )
    formatted = format_issues_for_llm(
        CodeValidation(issues=[syntax_issue, import_issue])
    )
    assert "- [ERROR] (syntax) invalid syntax @ line 3" in formatted
    assert (
        "- [WARNING] (import) Module 'x' may be missing @ unknown location" in formatted
    )
def test_format_issues_for_llm_preserves_issue_order():
    """Formatted lines appear in the same order as the issue list."""
    first = ValidationIssue(severity="warning", category="import", message="first")
    second = ValidationIssue(
        severity="error", category="security", message="second", line=9
    )
    lines = format_issues_for_llm(CodeValidation(issues=[first, second])).splitlines()
    assert lines[0] == "- [WARNING] (import) first @ unknown location"
    assert lines[1] == "- [ERROR] (security) second @ line 9"
def test_code_validation_ok_true_when_no_errors_present():
    """Warnings alone do not flip ok to False."""
    only_warning = ValidationIssue(severity="warning", category="import", message="warn")
    assert CodeValidation(issues=[only_warning]).ok is True
def test_code_validation_ok_false_when_error_present():
    """Any error-severity issue makes ok False."""
    only_error = ValidationIssue(severity="error", category="syntax", message="bad")
    assert CodeValidation(issues=[only_error]).ok is False
def test_code_validation_errors_and_warnings_filter_correctly():
    """The errors/warnings properties partition issues by severity."""
    err = ValidationIssue(severity="error", category="security", message="danger")
    warn = ValidationIssue(
        severity="warning", category="import", message="maybe missing"
    )
    combined = CodeValidation(issues=[err, warn])
    assert combined.errors == [err]
    assert combined.warnings == [warn]
def test_code_validation_summary_for_no_issues():
    """An empty validation summarizes as a pass."""
    assert CodeValidation().summary() == "Code validation passed."
def test_code_validation_summary_for_errors_only():
    """Summary mentions only the error count when no warnings exist."""
    bad = ValidationIssue(severity="error", category="syntax", message="bad")
    assert CodeValidation(issues=[bad]).summary() == "Code validation: 1 error(s)"
def test_code_validation_summary_for_warnings_only():
    """Summary mentions only the warning count when no errors exist."""
    warn = ValidationIssue(severity="warning", category="import", message="warn")
    assert CodeValidation(issues=[warn]).summary() == "Code validation: 1 warning(s)"
def test_code_validation_summary_for_errors_and_warnings():
    """Summary lists error and warning counts together."""
    issues = [
        ValidationIssue(severity="error", category="syntax", message="bad"),
        ValidationIssue(severity="warning", category="import", message="warn"),
    ]
    summary = CodeValidation(issues=issues).summary()
    assert summary == "Code validation: 1 error(s), 1 warning(s)"
# ---------------------------------------------------------------------------
# check_filename_collisions (BUG-202)
# ---------------------------------------------------------------------------
def test_filename_collision_detects_config_py():
    """BUG-202: config.py shadows pip 'config' package."""
    warnings = check_filename_collisions({"config.py": "x = 1", "main.py": "print(1)"})
    assert len(warnings) == 1
    first = warnings[0]
    assert "shadows stdlib/pip" in first
    assert "config" in first
def test_filename_collision_detects_stdlib_shadows():
    """Filenames shadowing stdlib modules should be flagged."""
    flagged = check_filename_collisions({"json.py": "x = 1"})
    assert len(flagged) == 1
    assert "json" in flagged[0]
def test_filename_collision_allows_safe_names():
    """Normal experiment filenames should not trigger warnings."""
    safe_files = {
        "main.py": "print(1)",
        "models.py": "class M: pass",
        "training.py": "def train(): pass",
        "data_loader.py": "def load(): pass",
        "experiment_config.py": "LR = 0.01",
        "requirements.txt": "torch",
    }
    assert check_filename_collisions(safe_files) == []
def test_filename_collision_multiple_shadows():
    """Multiple shadowing files should each produce a warning."""
    files = {"config.py": "", "logging.py": "", "main.py": ""}
    assert len(check_filename_collisions(files)) == 2
================================================
FILE: tests/test_results_table_builder.py
================================================
"""Tests for results_table_builder — pre-built LaTeX tables."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.pipeline.verified_registry import VerifiedRegistry
from researchclaw.templates.results_table_builder import (
LatexTable,
build_condition_whitelist,
build_results_tables,
)
# Directory of recorded pipeline-run artifacts (repo_root/artifacts); tests
# that read it skip when the referenced run is not present.
ARTIFACTS = Path(__file__).resolve().parent.parent / "artifacts"
def _make_registry(
    conditions: dict[str, dict[int, float]],
    primary_metric: float | None = None,
) -> VerifiedRegistry:
    """Create a registry from simple condition → {seed: value} mapping."""
    best_run_metrics: dict[str, float] = {}
    condition_summaries: dict[str, dict] = {}
    for cond_name, seeds in conditions.items():
        # One best-run metric entry per (condition, seed) pair.
        for seed_idx, value in seeds.items():
            best_run_metrics[f"{cond_name}/{seed_idx}/metric"] = value
        mean_metric = sum(seeds.values()) / len(seeds) if seeds else 0
        condition_summaries[cond_name] = {"metrics": {"metric": mean_metric}}
    if primary_metric is not None:
        best_run_metrics["primary_metric"] = primary_metric
    summary = {
        "best_run": {"metrics": best_run_metrics},
        "condition_summaries": condition_summaries,
        "metrics_summary": {},
    }
    return VerifiedRegistry.from_experiment(summary)
class TestBuildResultsTables:
    """build_results_tables behavior on synthetic registries from _make_registry."""

    def test_basic_table(self):
        """Two multi-seed conditions produce a main table plus a per-seed table."""
        reg = _make_registry(
            {
                "Baseline": {0: 80.0, 1: 82.0, 2: 81.0},
                "Proposed": {0: 85.0, 1: 87.0, 2: 86.0},
            },
            primary_metric=86.0,
        )
        tables = build_results_tables(reg, metric_name="Accuracy (\\%)")
        assert len(tables) == 2  # main + per-seed
        main = tables[0]
        assert main.label == "tab:main_results"
        assert "AUTO-GENERATED" in main.latex_code
        assert "\\begin{table}" in main.latex_code
        assert "Baseline" in main.latex_code
        assert "Proposed" in main.latex_code
        assert main.n_conditions == 2

    def test_best_is_bolded(self):
        """The best condition under 'maximize' appears in \\textbf."""
        reg = _make_registry(
            {
                "Baseline": {0: 70.0, 1: 72.0},
                "Proposed": {0: 85.0, 1: 87.0},
            }
        )
        tables = build_results_tables(reg, metric_direction="maximize")
        main = tables[0]
        # Proposed should be bold (higher metric)
        assert "\\textbf" in main.latex_code

    def test_single_seed_marker(self):
        """A condition with only one seed gets a \\ddagger footnote marker."""
        reg = _make_registry(
            {
                "Baseline": {0: 80.0, 1: 82.0},
                "Proposed": {0: 90.0},  # Single seed
            }
        )
        tables = build_results_tables(reg)
        main = tables[0]
        assert "\\ddagger" in main.latex_code  # Single-seed footnote

    def test_no_conditions(self):
        """An empty registry produces no tables at all."""
        reg = VerifiedRegistry()
        tables = build_results_tables(reg)
        assert len(tables) == 0

    def test_all_single_seed_no_per_seed_table(self):
        """When every condition has one seed, the per-seed table is omitted."""
        reg = _make_registry(
            {
                "A": {0: 80.0},
                "B": {0: 70.0},
            }
        )
        tables = build_results_tables(reg)
        # Only 1 table (main), no per-seed table (all single seed)
        assert len(tables) == 1

    def test_per_seed_table_structure(self):
        """The second table is labeled tab:per_seed and lists per-seed values."""
        reg = _make_registry(
            {
                "DQN": {0: 156.1, 1: 105.5, 2: 356.7},
                "DQN+Abstraction": {0: 98.1, 1: 456.7, 2: 282.0},
            }
        )
        tables = build_results_tables(reg)
        assert len(tables) == 2
        seed_table = tables[1]
        assert seed_table.label == "tab:per_seed"
        assert "156.10" in seed_table.latex_code or "156.1" in seed_table.latex_code
        assert "Seed 0" in seed_table.latex_code

    def test_two_column_uses_table_star(self):
        """two_column=True switches to the table* environment."""
        reg = _make_registry({"A": {0: 80.0, 1: 82.0}})
        tables = build_results_tables(reg, two_column=True)
        assert "\\begin{table*}" in tables[0].latex_code

    def test_verified_values_populated(self):
        """Condition means (mean(80, 82) == 81) land in verified_values."""
        reg = _make_registry(
            {"A": {0: 80.0, 1: 82.0}, "B": {0: 70.0, 1: 72.0}}
        )
        tables = build_results_tables(reg)
        main = tables[0]
        # Tolerance fallback guards against float formatting of the mean.
        assert 81.0 in main.verified_values or any(
            abs(v - 81.0) < 0.01 for v in main.verified_values
        )

    def test_special_chars_escaped(self):
        """Underscores in condition names are LaTeX-escaped."""
        reg = _make_registry({"DQN+Raw_Count": {0: 80.0, 1: 82.0}})
        tables = build_results_tables(reg)
        assert "DQN+Raw\\_Count" in tables[0].latex_code

    def test_minimize_direction(self):
        """Under 'minimize', the lowest-metric condition is bolded."""
        reg = _make_registry(
            {
                "Baseline": {0: 20.0, 1: 22.0},
                "Proposed": {0: 10.0, 1: 12.0},
            }
        )
        tables = build_results_tables(reg, metric_direction="minimize")
        # Proposed (lower) should be bold
        lines = tables[0].latex_code.split("\n")
        proposed_line = [l for l in lines if "Proposed" in l][0]
        assert "\\textbf" in proposed_line
class TestConditionWhitelist:
    """build_condition_whitelist lists each condition with its seed count."""

    def test_basic(self):
        registry = _make_registry(
            {
                "DQN": {0: 206.1, 1: 105.5, 2: 356.7},
                "DQN+Abstraction": {0: 278.93},
            }
        )
        whitelist = build_condition_whitelist(registry)
        for expected in ("DQN", "DQN+Abstraction", "3 seed(s)", "1 seed(s)"):
            assert expected in whitelist

    def test_empty_registry(self):
        whitelist = build_condition_whitelist(VerifiedRegistry())
        assert "no conditions completed" in whitelist
class TestRealArtifacts:
    """Table building from recorded run artifacts; skips when data is absent."""

    def _load(self, run_id: str) -> VerifiedRegistry:
        """Load a registry from artifacts/rc-*-<run_id>, skipping when missing."""
        pattern = f"rc-*-{run_id}"
        matches = sorted(ARTIFACTS.glob(pattern))
        if not matches:
            pytest.skip(f"Artifact {run_id} not found")
        summary_path = matches[0] / "stage-14" / "experiment_summary.json"
        ref_path = matches[0] / "stage-13" / "refinement_log.json"
        if not summary_path.exists():
            pytest.skip(f"No experiment_summary for {run_id}")
        summary = json.loads(summary_path.read_text())
        # The refinement log is optional; pass None when it was not recorded.
        ref_log = None
        if ref_path.exists():
            ref_log = json.loads(ref_path.read_text())
        return VerifiedRegistry.from_experiment(summary, ref_log)

    def test_run_e57360_rl_tables(self):
        reg = self._load("e57360")
        tables = build_results_tables(reg, metric_name="Return")
        assert len(tables) >= 1
        main = tables[0]
        # Should NOT contain PPO (never ran)
        assert "PPO" not in main.latex_code
        # Should contain DQN
        assert "DQN" in main.latex_code

    def test_run_acbdfa_tables(self):
        reg = self._load("acbdfa")
        tables = build_results_tables(reg, metric_name="Top-1 Accuracy (\\%)")
        assert len(tables) >= 1
================================================
FILE: tests/test_robotics_adapter.py
================================================
"""Tests for robotics & control domain adapter.
Covers adapter dispatch, prompt block generation, and integration
with the existing domain detection and profile system.
"""
from __future__ import annotations
import pytest
from researchclaw.domains.detector import (
get_profile,
_keyword_detect,
_profile_cache,
)
from researchclaw.domains.prompt_adapter import (
MLPromptAdapter,
GenericPromptAdapter,
get_adapter,
)
# ---------------------------------------------------------------------------
# Profile sanity
# ---------------------------------------------------------------------------
class TestRoboticsProfile:
    """Sanity checks on the robotics_control domain profile."""

    def setup_method(self):
        # Profiles are cached module-wide; clear so each test loads fresh.
        _profile_cache.clear()

    def test_profile_exists(self):
        profile = get_profile("robotics_control")
        assert profile is not None
        assert profile.domain_id == "robotics_control"

    def test_profile_fields(self):
        profile = get_profile("robotics_control")
        assert profile is not None
        assert profile.experiment_paradigm == "comparison"
        for library in ("gymnasium", "stable-baselines3"):
            assert library in profile.core_libraries
        assert profile.gpu_required is True

    def test_profile_baselines(self):
        profile = get_profile("robotics_control")
        assert profile is not None
        names = profile.standard_baselines
        assert any("PPO" in name for name in names)
        assert any("SAC" in name for name in names)
# ---------------------------------------------------------------------------
# Keyword detection
# ---------------------------------------------------------------------------
class TestRoboticsKeywordDetection:
    """Keyword detection routes robotics phrases to robotics_control."""

    def test_robot_keyword(self):
        assert _keyword_detect("robot manipulation task") == "robotics_control"

    def test_mujoco(self):
        assert _keyword_detect("locomotion in MuJoCo") == "robotics_control"

    def test_pybullet(self):
        assert _keyword_detect("grasping policy with PyBullet") == "robotics_control"
# ---------------------------------------------------------------------------
# Adapter dispatch
# ---------------------------------------------------------------------------
class TestRoboticsAdapter:
    """Adapter dispatch and prompt-block generation for robotics_control.

    The profile fetch + skip boilerplate was repeated in every test; it is
    factored into _profile_or_skip so each test states only its assertions.
    """

    @staticmethod
    def _profile_or_skip():
        """Return the robotics_control profile, skipping the test if absent."""
        profile = get_profile("robotics_control")
        if profile is None:
            pytest.skip("robotics_control profile not found")
        return profile

    def test_gets_robotics_adapter(self):
        adapter = get_adapter(self._profile_or_skip())
        assert not isinstance(adapter, MLPromptAdapter)
        # Before this contribution it would fall back to GenericPromptAdapter
        from researchclaw.domains.adapters.robotics import (
            RoboticsPromptAdapter,
        )
        assert isinstance(adapter, RoboticsPromptAdapter)

    def test_code_generation_blocks_nonempty(self):
        adapter = get_adapter(self._profile_or_skip())
        blocks = adapter.get_code_generation_blocks({})
        assert blocks.code_generation_hints
        assert blocks.dataset_guidance
        assert blocks.output_format_guidance

    def test_experiment_design_mentions_baselines(self):
        adapter = get_adapter(self._profile_or_skip())
        blocks = adapter.get_experiment_design_blocks({})
        assert "PPO" in blocks.experiment_design_context
        assert "SAC" in blocks.experiment_design_context

    def test_result_analysis_mentions_return(self):
        adapter = get_adapter(self._profile_or_skip())
        blocks = adapter.get_result_analysis_blocks({})
        assert "return" in blocks.result_analysis_hints.lower()

    def test_blueprint_context(self):
        # This test needs the profile itself (for typical_file_structure),
        # not just the adapter.
        profile = self._profile_or_skip()
        adapter = get_adapter(profile)
        ctx = adapter.get_blueprint_context()
        if profile.typical_file_structure:
            assert "agent.py" in ctx or "train.py" in ctx
================================================
FILE: tests/test_servers.py
================================================
"""Tests for multi-server resource scheduling (C2): Registry, Monitor, Dispatcher, Executors."""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from researchclaw.servers.registry import ServerEntry, ServerRegistry
from researchclaw.servers.monitor import ServerMonitor, _parse_status_output
from researchclaw.servers.dispatcher import TaskDispatcher
from researchclaw.servers.ssh_executor import SSHExecutor
from researchclaw.servers.slurm_executor import SlurmExecutor
from researchclaw.servers.cloud_executor import CloudExecutor
# ── fixtures ──────────────────────────────────────────────────────
def _make_server(
    name: str = "s1",
    host: str = "gpu1.local",
    server_type: str = "ssh",
    vram_gb: int = 24,
    priority: int = 1,
    cost: float = 0.0,
    scheduler: str = "",
    cloud_provider: str = "",
) -> ServerEntry:
    """Build a ServerEntry with test-friendly defaults (RTX 4090 over SSH)."""
    fields = dict(
        name=name,
        host=host,
        server_type=server_type,
        gpu="RTX 4090",
        vram_gb=vram_gb,
        priority=priority,
        cost_per_hour=cost,
        scheduler=scheduler,
        cloud_provider=cloud_provider,
    )
    return ServerEntry(**fields)
@pytest.fixture
def registry() -> ServerRegistry:
    """Three-server registry: free local SSH, paid AWS cloud, and a Slurm HPC."""
    servers = [
        _make_server("local", "localhost", vram_gb=48, priority=1),
        _make_server("cloud1", "cloud.host", server_type="cloud", vram_gb=80, priority=3, cost=2.0, cloud_provider="aws"),
        _make_server("hpc", "hpc.host", server_type="slurm", vram_gb=40, priority=2, scheduler="slurm"),
    ]
    return ServerRegistry(servers)
# ══════════════════════════════════════════════════════════════════
# ServerEntry tests
# ══════════════════════════════════════════════════════════════════
class TestServerEntry:
    """Serialization round-trip and default handling for ServerEntry."""

    def test_to_dict_roundtrip(self) -> None:
        original = _make_server()
        restored = ServerEntry.from_dict(original.to_dict())
        assert restored.name == original.name
        assert restored.vram_gb == original.vram_gb

    def test_defaults(self) -> None:
        entry = ServerEntry.from_dict({"name": "x"})
        assert entry.server_type == "ssh"
        assert entry.priority == 1
# ══════════════════════════════════════════════════════════════════
# ServerRegistry tests
# ══════════════════════════════════════════════════════════════════
class TestServerRegistry:
    """Registry CRUD, priority ordering, and best-match selection."""

    def test_list_all_sorted_by_priority(self, registry: ServerRegistry) -> None:
        """list_all returns servers in ascending priority order."""
        servers = registry.list_all()
        priorities = [s.priority for s in servers]
        assert priorities == sorted(priorities)

    def test_count(self, registry: ServerRegistry) -> None:
        assert registry.count == 3

    def test_add_server(self) -> None:
        reg = ServerRegistry()
        reg.add(_make_server("new"))
        assert reg.count == 1
        assert reg.get("new").name == "new"

    def test_remove_server(self, registry: ServerRegistry) -> None:
        registry.remove("local")
        assert registry.count == 2

    def test_remove_unknown_raises(self, registry: ServerRegistry) -> None:
        """Removing a name not in the registry raises KeyError."""
        with pytest.raises(KeyError):
            registry.remove("ghost")

    def test_get_unknown_raises(self, registry: ServerRegistry) -> None:
        with pytest.raises(KeyError):
            registry.get("ghost")

    def test_get_available_excludes(self, registry: ServerRegistry) -> None:
        """Servers named in 'exclude' are filtered out of availability."""
        avail = registry.get_available(exclude={"local"})
        names = [s.name for s in avail]
        assert "local" not in names
        assert len(names) == 2

    def test_get_best_match_by_vram(self, registry: ServerRegistry) -> None:
        best = registry.get_best_match({"min_vram_gb": 40})
        assert best is not None
        assert best.vram_gb >= 40

    def test_get_best_match_by_type(self, registry: ServerRegistry) -> None:
        best = registry.get_best_match({"server_type": "slurm"})
        assert best is not None
        assert best.server_type == "slurm"

    def test_get_best_match_prefers_free(self, registry: ServerRegistry) -> None:
        """prefer_free biases selection toward zero-cost servers."""
        best = registry.get_best_match(prefer_free=True)
        assert best is not None
        assert best.cost_per_hour == 0.0

    def test_get_best_match_none_when_impossible(self, registry: ServerRegistry) -> None:
        """Unsatisfiable requirements yield None rather than raising."""
        best = registry.get_best_match({"min_vram_gb": 999})
        assert best is None

    def test_get_best_match_by_gpu(self, registry: ServerRegistry) -> None:
        # "RTX" matches the fixture GPU name "RTX 4090" on every server.
        best = registry.get_best_match({"gpu": "RTX"})
        assert best is not None

    def test_get_best_match_no_requirements(self, registry: ServerRegistry) -> None:
        """With no constraints, the priority-1 fixture server 'local' wins."""
        best = registry.get_best_match()
        assert best is not None
        assert best.name == "local"
# ══════════════════════════════════════════════════════════════════
# ServerMonitor tests
# ══════════════════════════════════════════════════════════════════
class TestServerMonitor:
    """Status-output parsing and (mocked) SSH reachability checks."""

    def test_parse_status_output(self) -> None:
        """Parses the '---'-separated sections: GPU stats CSV, memory, uptime."""
        raw = "75, 8000, 24576\n---\n total used free\nMem: 64000 32000 32000\n---\n 10:00:00 up 5 days"
        server = _make_server()
        status = _parse_status_output(raw, server)
        assert status["gpu"]["count"] == 1
        assert status["gpu"]["devices"][0]["utilization_pct"] == 75
        assert status["memory"]["total_mb"] == 64000
        assert "uptime" in status

    def test_parse_status_no_gpu(self) -> None:
        """An empty GPU section parses to a zero GPU count."""
        raw = "\n---\n total used free\nMem: 64000 32000 32000\n---\nup 1 day"
        server = _make_server()
        status = _parse_status_output(raw, server)
        assert status["gpu"]["count"] == 0

    def test_get_cached_none(self, registry: ServerRegistry) -> None:
        """No cached status exists before any check has run."""
        monitor = ServerMonitor(registry)
        assert monitor.get_cached("local") is None

    def test_get_gpu_usage_empty(self, registry: ServerRegistry) -> None:
        monitor = ServerMonitor(registry)
        assert monitor.get_gpu_usage(_make_server()) == {}

    def test_check_status_unreachable(self, registry: ServerRegistry) -> None:
        """An SSH failure marks the server unreachable instead of raising."""
        monitor = ServerMonitor(registry)
        with patch("researchclaw.servers.monitor._ssh_command", side_effect=RuntimeError("unreachable")):
            status = asyncio.run(monitor.check_status(_make_server()))
            assert status["reachable"] is False

    def test_check_all(self, registry: ServerRegistry) -> None:
        """check_all reports every registered server; all unreachable here."""
        monitor = ServerMonitor(registry)
        with patch("researchclaw.servers.monitor._ssh_command", side_effect=RuntimeError("unreachable")):
            results = asyncio.run(monitor.check_all())
            assert len(results) == 3
            for name, status in results.items():
                assert status["reachable"] is False
# ══════════════════════════════════════════════════════════════════
# SSHExecutor tests
# ══════════════════════════════════════════════════════════════════
class TestSSHExecutor:
    """SSHExecutor construction and timeout handling (subprocess mocked)."""

    def test_init(self) -> None:
        server = _make_server()
        exe = SSHExecutor(server)
        assert exe.host == "gpu1.local"

    def test_run_experiment_timeout(self) -> None:
        """A hung remote command surfaces as success=False with a Timeout error."""
        server = _make_server()
        exe = SSHExecutor(server)

        async def _run() -> dict:
            # Mock subprocess whose communicate() times out; kill/wait mocked
            # so the executor's cleanup path does not touch a real process.
            with patch("asyncio.create_subprocess_exec") as mock_exec:
                proc = AsyncMock()
                proc.communicate = AsyncMock(side_effect=asyncio.TimeoutError)
                proc.kill = AsyncMock()
                proc.wait = AsyncMock()
                mock_exec.return_value = proc
                return await exe.run_experiment("/tmp/test", "echo hello", timeout=1)

        result = asyncio.run(_run())
        assert result["success"] is False
        assert "Timeout" in result["error"]
# ══════════════════════════════════════════════════════════════════
# SlurmExecutor tests
# ══════════════════════════════════════════════════════════════════
class TestSlurmExecutor:
    """sbatch script generation and job submission (subprocess mocked)."""

    def test_init_wrong_type_raises(self) -> None:
        """Constructing with a non-slurm server raises ValueError."""
        server = _make_server(server_type="ssh")
        with pytest.raises(ValueError, match="not a slurm"):
            SlurmExecutor(server)

    def test_generate_sbatch_script(self) -> None:
        """Explicit resources map to #SBATCH gres/mem directives."""
        server = _make_server(server_type="slurm", scheduler="slurm")
        exe = SlurmExecutor(server)
        script = exe._generate_sbatch_script("python main.py", resources={"gpus": 2, "mem_gb": 32})
        assert "#SBATCH --gres=gpu:2" in script
        assert "#SBATCH --mem=32G" in script
        assert "python main.py" in script

    def test_sbatch_script_default_resources(self) -> None:
        """Defaults: one GPU and a one-hour time limit."""
        server = _make_server(server_type="slurm", scheduler="slurm")
        exe = SlurmExecutor(server)
        script = exe._generate_sbatch_script("echo hi")
        assert "#SBATCH --gres=gpu:1" in script
        assert "#SBATCH --time=01:00:00" in script

    def test_submit_job_parses_output(self) -> None:
        """The job id is parsed from sbatch's 'Submitted batch job N' output."""
        server = _make_server(server_type="slurm", scheduler="slurm")
        exe = SlurmExecutor(server)

        async def _run() -> str:
            with patch("asyncio.create_subprocess_exec") as mock_exec:
                proc = AsyncMock()
                proc.communicate = AsyncMock(return_value=(b"Submitted batch job 12345\n", b""))
                proc.returncode = 0
                mock_exec.return_value = proc
                return await exe.submit_job("echo hi", "/tmp/test")

        job_id = asyncio.run(_run())
        assert job_id == "12345"
# ══════════════════════════════════════════════════════════════════
# CloudExecutor tests
# ══════════════════════════════════════════════════════════════════
class TestCloudExecutor:
    """CloudExecutor type validation and stubbed instance launch."""

    def test_init_wrong_type_raises(self) -> None:
        non_cloud = _make_server(server_type="ssh")
        with pytest.raises(ValueError, match="not a cloud"):
            CloudExecutor(non_cloud)

    def test_launch_instance_stub(self) -> None:
        executor = CloudExecutor(_make_server(server_type="cloud", cloud_provider="aws"))
        launched = asyncio.run(executor.launch_instance())
        assert launched["status"] == "stub_launched"
        assert launched["provider"] == "aws"
# ══════════════════════════════════════════════════════════════════
# TaskDispatcher tests
# ══════════════════════════════════════════════════════════════════
class TestTaskDispatcher:
    """Dispatch bookkeeping: task ids, queuing, and unknown-task status."""

    def test_dispatch_returns_task_id(self, registry: ServerRegistry) -> None:
        """Dispatching yields a 12-character task identifier."""
        monitor = ServerMonitor(registry)
        disp = TaskDispatcher(registry, monitor)
        task_id = asyncio.run(disp.dispatch({"command": "echo hi", "local_dir": "/tmp"}))
        assert len(task_id) == 12

    def test_dispatch_no_server_queues(self) -> None:
        """With an empty registry the task is queued instead of executed."""
        reg = ServerRegistry()
        monitor = ServerMonitor(reg)
        disp = TaskDispatcher(reg, monitor)
        task_id = asyncio.run(disp.dispatch({"command": "echo hi"}))
        status = disp.get_task_status(task_id)
        assert status["status"] == "queued"

    def test_get_task_status_unknown(self, registry: ServerRegistry) -> None:
        """Unknown task ids report status 'unknown'."""
        monitor = ServerMonitor(registry)
        disp = TaskDispatcher(registry, monitor)
        status = disp.get_task_status("nonexistent")
        assert status["status"] == "unknown"
================================================
FILE: tests/test_skills_library.py
================================================
"""Tests for the dynamic skills library.
Covers:
- Skill schema (agentskills.io data model)
- YAML skill loading (legacy)
- SKILL.md loading (agentskills.io)
- Skill registry (register, query, external dirs)
- Keyword matching + description fallback
- Stage filtering (int + string)
- Prompt formatting
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from researchclaw.skills.schema import Skill, STAGE_NAME_TO_NUMBER
from researchclaw.skills.loader import (
load_skill_file,
load_skill_from_skillmd,
load_skillmd_from_directory,
load_skills_from_directory,
)
from researchclaw.skills.registry import SkillRegistry
from researchclaw.skills.matcher import (
match_skills,
format_skills_for_prompt,
_tokenize,
_resolve_stage,
)
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture
def sample_skill() -> Skill:
    """A fully-populated Skill using agentskills.io string-valued metadata."""
    return Skill(
        name="test-skill-1",
        description="A test skill for unit testing",
        body="## Test Skill\nDo the thing.",
        # Metadata values are strings (as parsed from SKILL.md frontmatter);
        # the schema is expected to coerce them to typed fields.
        metadata={
            "category": "tooling",
            "trigger-keywords": "training,pytorch,gpu",
            "applicable-stages": "10,12",
            "priority": "5",
            "version": "1.0",
            "code-template": "print('hello')",
            "references": "Test Paper 2024",
        },
    )
@pytest.fixture
def skill_yaml_dir(tmp_path: Path) -> Path:
    """Temp directory holding one legacy YAML-format skill file."""
    d = tmp_path / "skills"
    d.mkdir()
    skill_data = {
        "id": "yaml-skill-1",
        "name": "YAML Test Skill",
        "category": "experiment",
        "description": "Loaded from YAML",
        "trigger_keywords": ["review", "literature"],
        "applicable_stages": [3, 4, 5],
        "prompt_template": "Do literature review",
        "version": "1.0",
        "priority": 3,
    }
    import yaml  # local import: only this fixture needs PyYAML
    (d / "test_skill.yaml").write_text(yaml.dump(skill_data), encoding="utf-8")
    return d
@pytest.fixture
def skill_json_dir(tmp_path: Path) -> Path:
    """Temp directory holding one JSON-format skill file."""
    skills_dir = tmp_path / "json_skills"
    skills_dir.mkdir()
    payload = {
        "id": "json-skill-1",
        "name": "JSON Test Skill",
        "category": "writing",
        "description": "Loaded from JSON",
        "trigger_keywords": ["paper", "writing"],
        "applicable_stages": [17],
        "prompt_template": "Write well",
        "version": "1.0",
        "priority": 4,
    }
    (skills_dir / "test_skill.json").write_text(json.dumps(payload), encoding="utf-8")
    return skills_dir
@pytest.fixture
def skillmd_dir(tmp_path: Path) -> Path:
    """Create a directory with SKILL.md files for testing."""
    d = tmp_path / "skillmd_skills"
    d.mkdir()
    # Skill with full metadata (frontmatter carries a nested 'metadata' map)
    s1 = d / "test-skill-md"
    s1.mkdir()
    (s1 / "SKILL.md").write_text(
        "---\n"
        "name: test-skill-md\n"
        "description: A test skill from SKILL.md\n"
        "metadata:\n"
        " category: domain\n"
        " trigger-keywords: \"nlp,transformer,bert\"\n"
        " applicable-stages: \"9,10\"\n"
        " priority: \"2\"\n"
        "---\n\n"
        "## NLP Skill\nDo NLP things.\n",
        encoding="utf-8",
    )
    # Skill with minimal metadata (no trigger-keywords)
    s2 = d / "minimal-skill"
    s2.mkdir()
    (s2 / "SKILL.md").write_text(
        "---\n"
        "name: minimal-skill\n"
        "description: A minimal skill for testing description-based matching\n"
        "---\n\n"
        "## Minimal\nJust a body.\n",
        encoding="utf-8",
    )
    return d
@pytest.fixture
def external_skillmd_dir(tmp_path: Path) -> Path:
    """Simulates an external skill directory (like Collider-Agent)."""
    d = tmp_path / "external"
    d.mkdir()
    s = d / "hep-feynrules"
    s.mkdir()
    # Single domain skill pinned to stage 10, no trigger keywords.
    (s / "SKILL.md").write_text(
        "---\n"
        "name: hep-feynrules\n"
        "description: Generate FeynRules model files for BSM physics\n"
        "metadata:\n"
        " category: domain\n"
        " applicable-stages: \"10\"\n"
        "---\n\n"
        "## FeynRules Model Generation\n"
        "Build BSM model files for MadGraph.\n",
        encoding="utf-8",
    )
    return d
# ── Skill Schema ─────────────────────────────────────────────────────
class TestSkillSchema:
    """Skill data model: construction, dict round-trips, stage-name mapping."""

    def test_create_skill(self, sample_skill: Skill) -> None:
        assert sample_skill.name == "test-skill-1"
        assert sample_skill.id == "test-skill-1"  # backward compat
        assert sample_skill.category == "tooling"
        assert len(sample_skill.trigger_keywords) == 3

    def test_to_dict(self, sample_skill: Skill) -> None:
        """String metadata is surfaced as typed fields in to_dict output."""
        d = sample_skill.to_dict()
        assert d["id"] == "test-skill-1"
        assert d["applicable_stages"] == [10, 12]
        assert d["code_template"] == "print('hello')"

    def test_from_dict(self) -> None:
        data = {
            "id": "from-dict",
            "name": "From Dict",
            "category": "domain",
            "description": "Created from dict",
            "trigger_keywords": ["test"],
            "applicable_stages": [1],
            "prompt_template": "test prompt",
        }
        skill = Skill.from_dict(data)
        # NOTE(review): asserts the "id" value, not "name" ("From Dict") —
        # presumably from_dict derives name from id; confirm that is intended.
        assert skill.name == "from-dict"
        assert skill.priority == 5  # default

    def test_from_dict_defaults(self) -> None:
        """An empty dict yields a Skill with documented defaults."""
        skill = Skill.from_dict({})
        assert skill.name == ""
        assert skill.version == "1.0"
        assert skill.code_template is None

    def test_roundtrip(self, sample_skill: Skill) -> None:
        d = sample_skill.to_dict()
        restored = Skill.from_dict(d)
        assert restored.name == sample_skill.name
        assert restored.applicable_stages == sample_skill.applicable_stages

    def test_stage_name_to_number(self) -> None:
        assert STAGE_NAME_TO_NUMBER["code_generation"] == 10
        assert STAGE_NAME_TO_NUMBER["paper_draft"] == 17
        assert len(STAGE_NAME_TO_NUMBER) == 23

    def test_prompt_template_alias(self, sample_skill: Skill) -> None:
        """prompt_template is an alias for the skill body."""
        assert sample_skill.prompt_template == sample_skill.body
# ── Skill Loader ─────────────────────────────────────────────────────
class TestSkillLoader:
    """load_skill_file / load_skills_from_directory for YAML and JSON inputs."""

    def test_load_yaml(self, skill_yaml_dir: Path) -> None:
        skill = load_skill_file(skill_yaml_dir / "test_skill.yaml")
        assert skill is not None
        assert skill.name == "yaml-skill-1"
        assert skill.category == "experiment"

    def test_load_json(self, skill_json_dir: Path) -> None:
        skill = load_skill_file(skill_json_dir / "test_skill.json")
        assert skill is not None
        assert skill.name == "json-skill-1"

    def test_load_nonexistent(self, tmp_path: Path) -> None:
        """Missing files return None rather than raising."""
        skill = load_skill_file(tmp_path / "nope.yaml")
        assert skill is None

    def test_load_invalid_yaml(self, tmp_path: Path) -> None:
        """Malformed YAML content is reported as None."""
        bad = tmp_path / "bad.yaml"
        bad.write_text("not: [valid: yaml: {", encoding="utf-8")
        skill = load_skill_file(bad)
        assert skill is None

    def test_load_unsupported_format(self, tmp_path: Path) -> None:
        """An unsupported extension (.txt) yields None."""
        txt = tmp_path / "skill.txt"
        txt.write_text("id: test", encoding="utf-8")
        skill = load_skill_file(txt)
        assert skill is None

    def test_load_directory(self, skill_yaml_dir: Path) -> None:
        skills = load_skills_from_directory(skill_yaml_dir)
        assert len(skills) == 1

    def test_load_empty_directory(self, tmp_path: Path) -> None:
        empty = tmp_path / "empty"
        empty.mkdir()
        skills = load_skills_from_directory(empty)
        assert skills == []

    def test_load_missing_directory(self, tmp_path: Path) -> None:
        """A nonexistent directory loads as an empty list."""
        skills = load_skills_from_directory(tmp_path / "nonexistent")
        assert skills == []
class TestSkillMdLoader:
    """Loading skills from SKILL.md files (YAML frontmatter + markdown body)."""

    def test_load_skillmd(self, skillmd_dir: Path) -> None:
        """A fully-populated SKILL.md maps frontmatter fields onto the Skill."""
        skill = load_skill_from_skillmd(skillmd_dir / "test-skill-md" / "SKILL.md")
        assert skill is not None
        assert skill.name == "test-skill-md"
        assert skill.category == "domain"
        assert "nlp" in skill.trigger_keywords
        assert skill.applicable_stages == [9, 10]
        assert skill.priority == 2
        assert "NLP Skill" in skill.body
        assert skill.source_format == "skillmd"

    def test_load_skillmd_minimal(self, skillmd_dir: Path) -> None:
        """Optional frontmatter fields fall back to empty/default values."""
        skill = load_skill_from_skillmd(skillmd_dir / "minimal-skill" / "SKILL.md")
        assert skill is not None
        assert skill.name == "minimal-skill"
        assert skill.trigger_keywords == []
        assert skill.applicable_stages == []
        assert skill.priority == 5  # default

    def test_load_skillmd_missing(self, tmp_path: Path) -> None:
        """A missing SKILL.md yields None rather than raising."""
        skill = load_skill_from_skillmd(tmp_path / "nope" / "SKILL.md")
        assert skill is None

    def test_load_skillmd_no_frontmatter(self, tmp_path: Path) -> None:
        """A SKILL.md without YAML frontmatter is rejected (returns None)."""
        d = tmp_path / "bad-skill"
        d.mkdir()
        (d / "SKILL.md").write_text("No frontmatter here", encoding="utf-8")
        skill = load_skill_from_skillmd(d / "SKILL.md")
        assert skill is None

    def test_load_skillmd_directory(self, skillmd_dir: Path) -> None:
        """Directory scan discovers every SKILL.md-based skill."""
        skills = load_skillmd_from_directory(skillmd_dir)
        assert len(skills) == 2
        names = {s.name for s in skills}
        assert "test-skill-md" in names
        assert "minimal-skill" in names

    def test_skillmd_wins_over_yaml(self, tmp_path: Path) -> None:
        """When both SKILL.md and YAML exist for the same name, SKILL.md wins."""
        d = tmp_path / "mixed"
        d.mkdir()
        # YAML file
        import yaml
        (d / "test-skill-md.yaml").write_text(
            yaml.dump({
                "id": "test-skill-md",
                "name": "test-skill-md",
                "category": "tooling",
                "description": "From YAML",
                "trigger_keywords": ["x"],
                "applicable_stages": [1],
                "prompt_template": "yaml body",
            }),
            encoding="utf-8",
        )
        # SKILL.md file
        sd = d / "test-skill-md"
        sd.mkdir()
        (sd / "SKILL.md").write_text(
            "---\nname: test-skill-md\ndescription: From SKILL.md\n---\n\nskillmd body\n",
            encoding="utf-8",
        )
        skills = load_skills_from_directory(d)
        matched = [s for s in skills if s.name == "test-skill-md"]
        assert len(matched) == 1
        assert matched[0].source_format == "skillmd"
        assert "From SKILL.md" in matched[0].description
# ── Matcher ──────────────────────────────────────────────────────────
class TestMatcher:
    """Keyword- and stage-based matching of skills against a context string."""

    def test_tokenize(self) -> None:
        """_tokenize lowercases and splits the context into word tokens."""
        tokens = _tokenize("PyTorch Training GPU")
        assert "pytorch" in tokens
        assert "training" in tokens
        assert "gpu" in tokens

    def test_match_by_keyword(self, sample_skill: Skill) -> None:
        """A context containing trigger keywords matches the skill."""
        matched = match_skills(
            [sample_skill],
            context="training a pytorch model on gpu",
            stage=10,
        )
        assert len(matched) == 1
        assert matched[0].name == "test-skill-1"

    def test_match_filters_by_stage(self, sample_skill: Skill) -> None:
        """Skills are excluded when the stage is not in applicable_stages."""
        matched = match_skills(
            [sample_skill],
            context="training pytorch gpu",
            stage=1,  # not in applicable_stages
        )
        assert len(matched) == 0

    def test_match_empty_context(self, sample_skill: Skill) -> None:
        """An empty context matches nothing."""
        matched = match_skills([sample_skill], context="", stage=10)
        assert len(matched) == 0

    def test_match_no_keyword_overlap(self, sample_skill: Skill) -> None:
        """A context sharing no keywords with the skill matches nothing."""
        matched = match_skills(
            [sample_skill],
            context="linguistics morphology",
            stage=10,
        )
        assert len(matched) == 0

    def test_match_respects_top_k(self) -> None:
        """At most top_k skills are returned even when more would match."""
        skills = [
            Skill(
                name=f"skill-{i}",
                description="test",
                body="test",
                metadata={
                    "category": "tooling",
                    "trigger-keywords": "training",
                    "applicable-stages": "10",
                    "priority": str(i),
                },
            )
            for i in range(10)
        ]
        matched = match_skills(skills, context="training", stage=10, top_k=3)
        assert len(matched) == 3

    def test_match_priority_ordering(self) -> None:
        """A skill with priority '1' sorts before one with priority '9'."""
        high = Skill(
            name="high", description="t", body="t",
            metadata={
                "trigger-keywords": "training",
                "applicable-stages": "10",
                "priority": "1",
            },
        )
        low = Skill(
            name="low", description="t", body="t",
            metadata={
                "trigger-keywords": "training",
                "applicable-stages": "10",
                "priority": "9",
            },
        )
        matched = match_skills([low, high], context="training", stage=10)
        assert matched[0].name == "high"

    def test_match_string_stage(self, sample_skill: Skill) -> None:
        """String stage names should be resolved via STAGE_NAME_TO_NUMBER."""
        matched = match_skills(
            [sample_skill],
            context="training pytorch gpu",
            stage="code_generation",  # resolves to 10
        )
        assert len(matched) == 1
        assert matched[0].name == "test-skill-1"

    def test_match_string_stage_mismatch(self, sample_skill: Skill) -> None:
        """A resolved string stage outside applicable_stages matches nothing."""
        matched = match_skills(
            [sample_skill],
            context="training pytorch gpu",
            stage="paper_draft",  # resolves to 17, not in [10, 12]
        )
        assert len(matched) == 0

    def test_resolve_stage(self) -> None:
        """_resolve_stage passes ints through, maps names, returns -1 for unknowns."""
        assert _resolve_stage(10) == 10
        assert _resolve_stage("code_generation") == 10
        assert _resolve_stage("unknown_stage") == -1

    def test_match_description_fallback(self) -> None:
        """Skills without trigger_keywords should match via description."""
        external_skill = Skill(
            name="ext-skill",
            description="Generate FeynRules model files for BSM physics",
            body="Do feynrules things.",
            metadata={"applicable-stages": "10"},
        )
        matched = match_skills(
            [external_skill],
            context="feynrules model generation",
            stage=10,
            fallback_matching=True,
        )
        assert len(matched) == 1
        assert matched[0].name == "ext-skill"

    def test_match_description_fallback_disabled(self) -> None:
        """With fallback_matching off, keyword-less skills never match."""
        external_skill = Skill(
            name="ext-skill",
            description="Generate FeynRules model files for BSM physics",
            body="Do feynrules things.",
            metadata={"applicable-stages": "10"},
        )
        matched = match_skills(
            [external_skill],
            context="feynrules model generation",
            stage=10,
            fallback_matching=False,
        )
        assert len(matched) == 0
class TestFormatSkills:
    """Rendering matched skills into prompt text."""

    def test_format_single_skill(self, sample_skill: Skill) -> None:
        rendered = format_skills_for_prompt([sample_skill])
        assert "test-skill-1" in rendered
        assert "tooling" in rendered

    def test_format_empty(self) -> None:
        """No skills render to an empty string."""
        assert format_skills_for_prompt([]) == ""

    def test_format_includes_code_template(self, sample_skill: Skill) -> None:
        rendered = format_skills_for_prompt([sample_skill])
        assert "print('hello')" in rendered

    def test_format_includes_references(self, sample_skill: Skill) -> None:
        rendered = format_skills_for_prompt([sample_skill])
        assert "Test Paper 2024" in rendered

    def test_format_respects_max_chars(self) -> None:
        """Output stays near max_chars even with far more skill text available."""
        many_skills = []
        for i in range(10):
            many_skills.append(
                Skill(
                    name=f"s{i}", description="t", body="x" * 500,
                    metadata={
                        "category": "tooling",
                        "trigger-keywords": "t",
                    },
                )
            )
        rendered = format_skills_for_prompt(many_skills, max_chars=1000)
        assert len(rendered) <= 1500  # some slack for headers
# ── Registry ─────────────────────────────────────────────────────────
class TestSkillRegistry:
    """Registry construction, registration, lookup, and matching."""

    def test_registry_loads_builtins(self) -> None:
        registry = SkillRegistry()
        assert registry.count() >= 12  # builtin skills (SKILL.md format)

    def test_builtin_skillmd_count(self) -> None:
        """All builtin skills should load from SKILL.md."""
        registry = SkillRegistry()
        assert registry.count() == 16

    def test_register_custom(self, sample_skill: Skill) -> None:
        """register() adds one skill on top of the builtins."""
        registry = SkillRegistry()
        initial = registry.count()
        registry.register(sample_skill)
        assert registry.count() == initial + 1

    def test_get_skill(self, sample_skill: Skill) -> None:
        registry = SkillRegistry()
        registry.register(sample_skill)
        got = registry.get("test-skill-1")
        assert got is not None
        assert got.name == "test-skill-1"

    def test_get_nonexistent(self) -> None:
        """get() returns None for unknown names."""
        registry = SkillRegistry()
        assert registry.get("nonexistent") is None

    def test_unregister(self, sample_skill: Skill) -> None:
        """unregister() returns truthy and removes the skill."""
        registry = SkillRegistry()
        registry.register(sample_skill)
        assert registry.unregister("test-skill-1")
        assert registry.get("test-skill-1") is None

    def test_unregister_nonexistent(self) -> None:
        """unregister() returns falsy for unknown names."""
        registry = SkillRegistry()
        assert not registry.unregister("nope")

    def test_list_by_category(self) -> None:
        registry = SkillRegistry()
        tooling = registry.list_by_category("tooling")
        assert len(tooling) > 0
        assert all(s.category == "tooling" for s in tooling)

    def test_list_by_stage(self) -> None:
        registry = SkillRegistry()
        stage_10 = registry.list_by_stage(10)
        assert len(stage_10) > 0

    def test_match(self) -> None:
        registry = SkillRegistry()
        matched = registry.match("pytorch training classification cifar", stage=10)
        assert len(matched) > 0

    def test_match_string_stage(self) -> None:
        """match() accepts a stage name string as well as a number."""
        registry = SkillRegistry()
        matched = registry.match(
            "pytorch training classification",
            stage="code_generation",
        )
        assert len(matched) > 0

    def test_export_for_prompt(self) -> None:
        registry = SkillRegistry()
        matched = registry.match("pytorch training", stage=10, top_k=2)
        text = registry.export_for_prompt(matched)
        assert len(text) > 0

    def test_custom_dir_loading(self, skill_yaml_dir: Path) -> None:
        """Skills from custom_dirs are merged into the registry."""
        registry = SkillRegistry(custom_dirs=[str(skill_yaml_dir)])
        skill = registry.get("yaml-skill-1")
        assert skill is not None

    def test_registry_external_dirs(self, external_skillmd_dir: Path) -> None:
        """External SKILL.md dirs add skills on top of the 16 builtins."""
        registry = SkillRegistry(external_dirs=[str(external_skillmd_dir)])
        assert registry.count() == 17  # 16 builtin + 1 external
        skill = registry.get("hep-feynrules")
        assert skill is not None
        assert skill.category == "domain"

    def test_registry_external_match_fallback(
        self, external_skillmd_dir: Path
    ) -> None:
        """External skills without trigger_keywords should match via description."""
        registry = SkillRegistry(
            external_dirs=[str(external_skillmd_dir)],
            fallback_matching=True,
        )
        matched = registry.match("feynrules model generation", stage=10, top_k=10)
        names = [s.name for s in matched]
        assert "hep-feynrules" in names
================================================
FILE: tests/test_ssh_and_colab_sandbox.py
================================================
# pyright: reportPrivateUsage=false, reportUnknownParameterType=false
"""Tests for ssh_remote and colab_drive experiment backends."""
from __future__ import annotations
import json
import textwrap
import time
from pathlib import Path
from unittest import mock
import pytest
from researchclaw.config import (
ColabDriveConfig,
ExperimentConfig,
SandboxConfig,
SshRemoteConfig,
DockerSandboxConfig,
CodeAgentConfig,
BenchmarkAgentConfig,
FigureAgentConfig,
)
from researchclaw.experiment.ssh_sandbox import (
SshRemoteSandbox,
_build_ssh_base,
_ssh_target,
)
from researchclaw.experiment.colab_sandbox import (
ColabDriveSandbox,
COLAB_WORKER_TEMPLATE,
)
from researchclaw.experiment.factory import create_sandbox
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_experiment_config(**overrides) -> ExperimentConfig:
    """Build an ExperimentConfig with default sub-configs, overridable per test."""
    kwargs = {
        "sandbox": SandboxConfig(),
        "docker": DockerSandboxConfig(),
        "ssh_remote": SshRemoteConfig(),
        "colab_drive": ColabDriveConfig(),
        "code_agent": CodeAgentConfig(),
        "benchmark_agent": BenchmarkAgentConfig(),
        "figure_agent": FigureAgentConfig(),
        **overrides,
    }
    return ExperimentConfig(**kwargs)
# ===========================================================================
# SSH Remote: unit tests
# ===========================================================================
class TestSshTarget:
    """_ssh_target formats the destination as user@host, user optional."""

    def test_with_user(self):
        cfg = SshRemoteConfig(host="gpu.lab.edu", user="alice")
        target = _ssh_target(cfg)
        assert target == "alice@gpu.lab.edu"

    def test_without_user(self):
        cfg = SshRemoteConfig(host="gpu.lab.edu")
        target = _ssh_target(cfg)
        assert target == "gpu.lab.edu"
class TestBuildSshBase:
    """_build_ssh_base assembles the base ssh argument vector."""

    def test_default_port(self):
        """No -p flag is emitted when the port is left at its default."""
        argv = _build_ssh_base(SshRemoteConfig(host="server", user="bob"))
        assert "ssh" in argv
        assert "bob@server" in argv
        assert "-p" not in argv

    def test_custom_port(self):
        """A custom port appears as the argument following -p."""
        argv = _build_ssh_base(SshRemoteConfig(host="server", user="bob", port=2222))
        flag_pos = argv.index("-p")
        assert argv[flag_pos + 1] == "2222"

    def test_key_path(self):
        """Configuring key_path adds an -i identity flag."""
        argv = _build_ssh_base(SshRemoteConfig(host="server", key_path="~/.ssh/my_key"))
        assert "-i" in argv
class TestSshRemoteSandboxCommands:
    """Remote command-string construction (bare and Docker variants)."""

    def test_bare_exec_cmd(self, tmp_path: Path):
        """Bare exec sets GPU pinning, HOME, python invocation, and net isolation."""
        cfg = SshRemoteConfig(
            host="server", user="test", gpu_ids=(0, 1),
            remote_python="python3",
        )
        sb = SshRemoteSandbox(cfg, tmp_path)
        cmd = sb._build_bare_exec_cmd("/tmp/rc-test", entry_point="main.py")
        assert "CUDA_VISIBLE_DEVICES=0,1" in cmd
        assert "HOME=/tmp/rc-test" in cmd
        assert "python3 -u main.py" in cmd
        assert "unshare --net" in cmd

    def test_bare_exec_no_gpu(self, tmp_path: Path):
        """Without gpu_ids, no CUDA_VISIBLE_DEVICES is injected."""
        cfg = SshRemoteConfig(host="server", user="test")
        sb = SshRemoteSandbox(cfg, tmp_path)
        cmd = sb._build_bare_exec_cmd("/tmp/rc-test", entry_point="main.py")
        assert "CUDA_VISIBLE_DEVICES" not in cmd

    def test_docker_exec_cmd(self, tmp_path: Path):
        """Docker exec wires image, mount, limits, network policy, and GPU device."""
        cfg = SshRemoteConfig(
            host="server", user="test",
            use_docker=True,
            docker_image="myimage:latest",
            docker_network_policy="none",
            docker_memory_limit_mb=4096,
            docker_shm_size_mb=1024,
            gpu_ids=(0,),
        )
        sb = SshRemoteSandbox(cfg, tmp_path)
        cmd = sb._build_docker_exec_cmd("/tmp/rc-test", entry_point="main.py")
        assert "docker run --rm" in cmd
        assert "-v /tmp/rc-test:/workspace" in cmd
        assert "--network none" in cmd
        assert "--memory=4096m" in cmd
        assert "--shm-size=1024m" in cmd
        assert "device=0" in cmd
        assert "myimage:latest" in cmd
        assert cmd.endswith("main.py")

    def test_docker_exec_full_network(self, tmp_path: Path):
        """Policy 'full' omits the --network flag entirely."""
        cfg = SshRemoteConfig(
            host="server", use_docker=True,
            docker_network_policy="full",
        )
        sb = SshRemoteSandbox(cfg, tmp_path)
        cmd = sb._build_docker_exec_cmd("/tmp/rc-test", entry_point="main.py")
        assert "--network" not in cmd
# ── Entry point path traversal validation ─────────────────────────────
class TestSshEntryPointValidation:
    """run_project() must sanitise entry_point before any remote execution."""

    def test_run_project_rejects_path_traversal(self, tmp_path: Path):
        """run_project() must reject entry_point with '..' components."""
        project = tmp_path / "proj"
        project.mkdir()
        (project / "main.py").write_text("print('hi')")
        cfg = SshRemoteConfig(host="server", user="test")
        work = tmp_path / "work"
        sandbox = SshRemoteSandbox(cfg, work)
        # Create escape target so .exists() alone wouldn't catch it
        work.mkdir(parents=True, exist_ok=True)
        (work / "escape.py").write_text("print('escaped!')")
        # Mock _execute to ensure it's never reached
        sandbox._execute = mock.MagicMock()  # type: ignore[assignment]
        result = sandbox.run_project(project, entry_point="../escape.py")
        assert result.returncode == -1
        assert ".." in result.stderr
        sandbox._execute.assert_not_called()

    def test_run_project_rejects_absolute_path(self, tmp_path: Path):
        """run_project() must reject absolute entry_point paths."""
        project = tmp_path / "proj"
        project.mkdir()
        (project / "main.py").write_text("print('hi')")
        cfg = SshRemoteConfig(host="server", user="test")
        sandbox = SshRemoteSandbox(cfg, tmp_path / "work")
        sandbox._execute = mock.MagicMock()  # type: ignore[assignment]
        result = sandbox.run_project(project, entry_point="/etc/passwd")
        assert result.returncode == -1
        # Accept either wording of the rejection message.
        assert "relative" in result.stderr.lower() or "absolute" in result.stderr.lower()
        sandbox._execute.assert_not_called()
class TestSshConnectivityCheck:
    """check_ssh_available rejects bad hosts without raising."""

    def test_empty_host(self):
        """An empty host fails fast with an explanatory message."""
        ok, msg = SshRemoteSandbox.check_ssh_available(SshRemoteConfig(host=""))
        assert not ok
        assert "empty" in msg

    def test_unreachable_host(self):
        """A host that cannot resolve reports failure."""
        unreachable = SshRemoteConfig(host="nonexistent-host-12345.invalid")
        ok, _msg = SshRemoteSandbox.check_ssh_available(unreachable)
        assert not ok
class TestSshSandboxRun:
    """Test run() with mocked SSH commands."""

    def test_run_success(self, tmp_path: Path):
        """A successful run parses 'key: value' metrics from exec stdout."""
        cfg = SshRemoteConfig(host="fake", user="test")
        sb = SshRemoteSandbox(cfg, tmp_path)
        # One fake result per expected SSH invocation, in call order.
        fake_results = [
            mock.Mock(returncode=0, stdout="", stderr=""),  # mkdir
            mock.Mock(returncode=0, stdout="accuracy: 0.95\nloss: 0.05", stderr=""),  # exec
            mock.Mock(returncode=0, stdout="", stderr=""),  # cleanup
        ]
        call_count = [0]

        def fake_ssh_run(command, *, timeout_sec=60):
            from researchclaw.experiment.ssh_sandbox import _SshResult
            # Clamp to the last fake result if called more times than expected.
            idx = min(call_count[0], len(fake_results) - 1)
            r = fake_results[idx]
            call_count[0] += 1
            return _SshResult(
                returncode=r.returncode,
                stdout=r.stdout,
                stderr=r.stderr,
            )

        def fake_scp(local_dir, remote_dir):
            return True

        with mock.patch.object(sb, '_ssh_run', side_effect=fake_ssh_run):
            with mock.patch.object(sb, '_scp_upload', side_effect=fake_scp):
                result = sb.run("print('hello')", timeout_sec=60)
        assert result.returncode == 0
        assert result.metrics.get("accuracy") == 0.95
        assert result.metrics.get("loss") == 0.05

    def test_run_upload_failure(self, tmp_path: Path):
        """A failed scp upload aborts the run with returncode -1."""
        cfg = SshRemoteConfig(host="fake", user="test")
        sb = SshRemoteSandbox(cfg, tmp_path)
        from researchclaw.experiment.ssh_sandbox import _SshResult
        with mock.patch.object(sb, '_ssh_run', return_value=_SshResult(0, "", "")):
            with mock.patch.object(sb, '_scp_upload', return_value=False):
                result = sb.run("print('hello')")
        assert result.returncode == -1
        assert "Failed to upload" in result.stderr
# ===========================================================================
# Colab Drive: unit tests
# ===========================================================================
class TestColabDriveCheck:
    """check_drive_available validates the configured Drive root path."""

    def test_empty_root(self):
        ok, msg = ColabDriveSandbox.check_drive_available(
            ColabDriveConfig(drive_root="")
        )
        assert not ok
        assert "empty" in msg

    def test_nonexistent_root(self):
        ok, msg = ColabDriveSandbox.check_drive_available(
            ColabDriveConfig(drive_root="/nonexistent/path/12345")
        )
        assert not ok
        assert "not found" in msg

    def test_existing_root(self, tmp_path: Path):
        ok, msg = ColabDriveSandbox.check_drive_available(
            ColabDriveConfig(drive_root=str(tmp_path))
        )
        assert ok
class TestColabDriveSandbox:
    """End-to-end behaviour of the Drive-polling Colab sandbox."""

    def test_submit_and_collect(self, tmp_path: Path):
        """Simulate the full flow: submit task → worker picks up → collect result."""
        drive_root = tmp_path / "drive"
        drive_root.mkdir()
        cfg = ColabDriveConfig(
            drive_root=str(drive_root),
            poll_interval_sec=1,
            timeout_sec=10,
        )
        sb = ColabDriveSandbox(cfg, tmp_path / "workdir")
        # Simulate worker in a thread: move pending → done with result
        import threading

        def fake_worker():
            pending = drive_root / "pending"
            done = drive_root / "done"
            for _ in range(20):  # poll for up to ~10 seconds (20 x 0.5 s)
                if pending.exists():
                    for task_dir in pending.iterdir():
                        if task_dir.is_dir():
                            done.mkdir(parents=True, exist_ok=True)
                            done_dir = done / task_dir.name
                            task_dir.rename(done_dir)
                            (done_dir / "result.json").write_text(json.dumps({
                                "returncode": 0,
                                "stdout": "primary_metric: 42.0\naccuracy: 0.99",
                                "stderr": "",
                            }))
                            return
                time.sleep(0.5)

        worker = threading.Thread(target=fake_worker, daemon=True)
        worker.start()
        result = sb.run("print('experiment')", timeout_sec=15)
        worker.join(timeout=5)
        assert result.returncode == 0
        assert result.metrics.get("primary_metric") == 42.0
        assert result.metrics.get("accuracy") == 0.99

    def test_timeout(self, tmp_path: Path):
        """If worker never picks up, should timeout."""
        drive_root = tmp_path / "drive"
        drive_root.mkdir()
        cfg = ColabDriveConfig(
            drive_root=str(drive_root),
            poll_interval_sec=1,
            timeout_sec=3,
        )
        sb = ColabDriveSandbox(cfg, tmp_path / "workdir")
        result = sb.run("print('hello')", timeout_sec=3)
        assert result.timed_out
        assert result.returncode == -1
        assert "did not complete" in result.stderr

    def test_setup_script_written(self, tmp_path: Path):
        """_write_setup_script materialises the configured setup commands."""
        drive_root = tmp_path / "drive"
        drive_root.mkdir()
        cfg = ColabDriveConfig(
            drive_root=str(drive_root),
            poll_interval_sec=1,
            timeout_sec=3,
            setup_script="pip install torch -q",
        )
        sb = ColabDriveSandbox(cfg, tmp_path / "workdir")
        # Just submit, don't wait for result
        staging = tmp_path / "workdir" / "_colab_1"
        staging.mkdir(parents=True, exist_ok=True)
        (staging / "main.py").write_text("print('hi')")
        sb._write_setup_script(staging)
        setup_sh = staging / "setup.sh"
        assert setup_sh.exists()
        content = setup_sh.read_text()
        assert "pip install torch -q" in content
class TestColabWorkerTemplate:
    """Sanity checks on the embedded Colab worker notebook template."""

    def test_template_not_empty(self):
        assert len(COLAB_WORKER_TEMPLATE) > 100

    def test_template_has_key_elements(self):
        """The template references the queue dirs, result file, and Drive mount."""
        for marker in ("pending", "done", "result.json", "drive.mount"):
            assert marker in COLAB_WORKER_TEMPLATE
# ===========================================================================
# Factory integration tests
# ===========================================================================
class TestFactoryIntegration:
    """create_sandbox validates configuration before constructing a backend."""

    def test_ssh_remote_requires_host(self, tmp_path: Path):
        cfg = _make_experiment_config(
            mode="ssh_remote",
            ssh_remote=SshRemoteConfig(host=""),
        )
        with pytest.raises(RuntimeError, match="host"):
            create_sandbox(cfg, tmp_path)

    def test_ssh_remote_checks_connectivity(self, tmp_path: Path):
        cfg = _make_experiment_config(
            mode="ssh_remote",
            ssh_remote=SshRemoteConfig(host="nonexistent.invalid"),
        )
        with pytest.raises(RuntimeError, match="SSH connectivity"):
            create_sandbox(cfg, tmp_path)

    def test_colab_drive_requires_root(self, tmp_path: Path):
        cfg = _make_experiment_config(
            mode="colab_drive",
            colab_drive=ColabDriveConfig(drive_root=""),
        )
        with pytest.raises(RuntimeError, match="empty"):
            create_sandbox(cfg, tmp_path)

    def test_colab_drive_checks_path(self, tmp_path: Path):
        cfg = _make_experiment_config(
            mode="colab_drive",
            colab_drive=ColabDriveConfig(drive_root="/nonexistent/12345"),
        )
        with pytest.raises(RuntimeError, match="not found"):
            create_sandbox(cfg, tmp_path)

    def test_colab_drive_creates_sandbox(self, tmp_path: Path):
        """A valid Drive root produces a ColabDriveSandbox instance."""
        drive_root = tmp_path / "drive"
        drive_root.mkdir()
        cfg = _make_experiment_config(
            mode="colab_drive",
            colab_drive=ColabDriveConfig(drive_root=str(drive_root)),
        )
        sandbox = create_sandbox(cfg, tmp_path / "workdir")
        assert isinstance(sandbox, ColabDriveSandbox)
# ===========================================================================
# ACP timeout fix test
# ===========================================================================
class TestAcpTimeoutFix:
    """ACPClient.from_rc_config must pick up the configured ACP timeout."""

    def test_timeout_passed_from_config(self):
        """An AcpConfig timeout_sec flows through to ACPClient.config."""
        # Fixed: drop unused imports (RCConfig, ACPConfig) from the local imports.
        from researchclaw.config import AcpConfig, LlmConfig
        from researchclaw.llm.acp_client import ACPClient

        acp_cfg = AcpConfig(agent="codex", timeout_sec=1500)
        llm_cfg = LlmConfig(provider="acp", acp=acp_cfg)
        # Simulate RCConfig with just the fields ACPClient.from_rc_config uses
        fake_rc = mock.Mock()
        fake_rc.llm = llm_cfg
        client = ACPClient.from_rc_config(fake_rc)
        assert client.config.timeout_sec == 1500

    def test_timeout_default(self):
        """A fully mocked attribute chain with timeout_sec=600 is honoured."""
        from researchclaw.llm.acp_client import ACPClient

        fake_rc = mock.Mock()
        fake_rc.llm.acp.agent = "claude"
        fake_rc.llm.acp.cwd = "."
        fake_rc.llm.acp.acpx_command = ""
        fake_rc.llm.acp.session_name = "test"
        fake_rc.llm.acp.timeout_sec = 600
        client = ACPClient.from_rc_config(fake_rc)
        assert client.config.timeout_sec == 600
# ===========================================================================
# ACP session reconnect tests (Issue #52)
# ===========================================================================
class TestAcpSessionReconnect:
    """_send_prompt retry behaviour when the ACP session dies (Issue #52).

    Fixed: removed redundant method-local `import pytest` statements — pytest
    is already imported at module level in this file.
    """

    def test_reconnect_on_session_died(self):
        """_send_prompt retries when session dies with 'agent needs reconnect'."""
        from researchclaw.llm.acp_client import ACPClient, ACPConfig

        client = ACPClient(ACPConfig(agent="claude"))
        client._acpx = "/usr/bin/true"
        client._session_ready = True
        call_count = 0

        def fake_cli(acpx: str, prompt: str) -> str:
            # First call simulates a dead session; the retry succeeds.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise RuntimeError("ACP prompt failed (exit 1): agent needs reconnect")
            return "success response"

        client._send_prompt_cli = fake_cli  # type: ignore[assignment]
        client._ensure_session = lambda: None  # type: ignore[assignment]
        client._force_reconnect = lambda: None  # type: ignore[assignment]
        result = client._send_prompt("test prompt")
        assert result == "success response"
        assert call_count == 2

    def test_reconnect_exhausted_raises(self):
        """_send_prompt raises after exhausting reconnect attempts."""
        from researchclaw.llm.acp_client import ACPClient, ACPConfig

        client = ACPClient(ACPConfig(agent="claude"))
        client._acpx = "/usr/bin/true"
        client._session_ready = True

        def always_fail(acpx: str, prompt: str) -> str:
            raise RuntimeError("ACP prompt failed (exit 1): session not found")

        client._send_prompt_cli = always_fail  # type: ignore[assignment]
        client._ensure_session = lambda: None  # type: ignore[assignment]
        client._force_reconnect = lambda: None  # type: ignore[assignment]
        with pytest.raises(RuntimeError, match="session not found"):
            client._send_prompt("test prompt")

    def test_non_reconnectable_error_raises_immediately(self):
        """_send_prompt does not retry on non-session errors."""
        from researchclaw.llm.acp_client import ACPClient, ACPConfig

        client = ACPClient(ACPConfig(agent="claude"))
        client._acpx = "/usr/bin/true"
        client._session_ready = True
        call_count = 0

        def fail_with_other_error(acpx: str, prompt: str) -> str:
            nonlocal call_count
            call_count += 1
            raise RuntimeError("ACP prompt failed (exit 1): permission denied")

        client._send_prompt_cli = fail_with_other_error  # type: ignore[assignment]
        client._ensure_session = lambda: None  # type: ignore[assignment]
        with pytest.raises(RuntimeError, match="permission denied"):
            client._send_prompt("test prompt")
        assert call_count == 1  # no retry
================================================
FILE: tests/test_trends.py
================================================
"""Tests for researchclaw.trends — Research Trend Tracker (Agent D1).
25+ tests covering feeds, trend_analyzer, opportunity_finder,
daily_digest, auto_topic, and literature/trends.
"""
from __future__ import annotations
import asyncio
from datetime import date
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.trends.feeds import FeedManager
from researchclaw.trends.trend_analyzer import TrendAnalyzer, _STOPWORDS
from researchclaw.trends.opportunity_finder import OpportunityFinder
from researchclaw.trends.daily_digest import DailyDigest
from researchclaw.trends.auto_topic import AutoTopicGenerator
from researchclaw.literature.trends import LiteratureTrendAnalyzer
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_papers(n: int = 10) -> list[dict[str, Any]]:
"""Generate synthetic papers for testing."""
papers = []
for i in range(n):
papers.append({
"title": f"Transformer attention mechanism for graph neural networks part {i}",
"authors": [
{"name": "Alice Smith"},
{"name": "Bob Jones"},
] if i % 2 == 0 else ["Alice Smith", "Charlie Brown"],
"abstract": (
"We propose a transformer-based attention approach for "
"graph neural networks using contrastive learning on ImageNet "
"and CIFAR datasets. Our diffusion model achieves SOTA results."
),
"url": f"https://arxiv.org/abs/2026.{i:05d}",
"source": "arxiv" if i % 2 == 0 else "semantic_scholar",
"published_date": "2026-03-01",
})
return papers
class MockLLM:
    """Stub LLM returning a fixed, parseable two-opportunity listing."""

    _RESPONSE = (
        "TOPIC: Graph transformers for drug discovery | "
        "WHY: Rising trend | FEASIBILITY: high\n"
        "TOPIC: Diffusion models for 3D generation | "
        "WHY: New paradigm | FEASIBILITY: medium\n"
    )

    async def chat_async(self, prompt: str) -> str:
        """Ignore the prompt and return the canned response."""
        return self._RESPONSE
class FailingLLM:
    """Stub LLM whose chat_async always raises, to exercise fallback paths."""

    async def chat_async(self, prompt: str) -> str:
        """Always fail, regardless of the prompt."""
        message = "API error"
        raise RuntimeError(message)
# ===================================================================
# FeedManager tests
# ===================================================================
class TestFeedManager:
    """Source filtering, deduplication, and failure handling in FeedManager."""

    def test_init_filters_supported_sources(self):
        """Unknown source names are dropped at construction time."""
        fm = FeedManager(sources=("arxiv", "invalid_source", "semantic_scholar"))
        assert fm.sources == ("arxiv", "semantic_scholar")

    def test_supported_sources(self):
        assert "arxiv" in FeedManager.SUPPORTED_SOURCES
        assert "semantic_scholar" in FeedManager.SUPPORTED_SOURCES
        assert "openalex" in FeedManager.SUPPORTED_SOURCES

    def test_fetch_deduplicates_by_title(self):
        """Papers with identical titles collapse to one result."""
        fm = FeedManager(sources=("arxiv",))
        # Mock _fetch_arxiv to return duplicates
        papers = [
            {"title": "Same Title", "source": "arxiv"},
            {"title": "Same Title", "source": "arxiv"},
            {"title": "Different Title", "source": "arxiv"},
        ]
        with patch.object(fm, "_fetch_arxiv", return_value=papers):
            result = fm.fetch_recent_papers(["ml"], max_papers=10)
        assert len(result) == 2

    def test_fetch_respects_max_papers(self):
        """The result list is capped at max_papers."""
        fm = FeedManager(sources=("arxiv",))
        papers = [{"title": f"Paper {i}", "source": "arxiv"} for i in range(20)]
        with patch.object(fm, "_fetch_arxiv", return_value=papers):
            result = fm.fetch_recent_papers(["ml"], max_papers=5)
        assert len(result) == 5

    def test_fetch_handles_source_failure(self):
        """A raising source is swallowed and yields an empty result."""
        fm = FeedManager(sources=("arxiv",))
        with patch.object(fm, "_fetch_arxiv", side_effect=RuntimeError("fail")):
            result = fm.fetch_recent_papers(["ml"])
        assert result == []

    def test_fetch_empty_title_excluded(self):
        """Papers with empty or whitespace-only titles are filtered out."""
        fm = FeedManager(sources=("arxiv",))
        papers = [
            {"title": "", "source": "arxiv"},
            {"title": " ", "source": "arxiv"},
            {"title": "Valid Paper", "source": "arxiv"},
        ]
        with patch.object(fm, "_fetch_arxiv", return_value=papers):
            result = fm.fetch_recent_papers(["ml"])
        assert len(result) == 1
# ===================================================================
# TrendAnalyzer tests
# ===================================================================
class TestTrendAnalyzer:
    """Keyword/author/dataset/method extraction and report generation."""

    def test_analyze_empty(self):
        """Analysing no papers yields zero counts and no keywords."""
        analyzer = TrendAnalyzer()
        result = analyzer.analyze([])
        assert result["paper_count"] == 0
        assert result["rising_keywords"] == []

    def test_analyze_extracts_keywords(self):
        analyzer = TrendAnalyzer()
        papers = _make_papers(10)
        result = analyzer.analyze(papers)
        assert result["paper_count"] == 10
        assert len(result["rising_keywords"]) > 0

    def test_keywords_exclude_stopwords(self):
        """No word in a rising keyword may be a stopword."""
        analyzer = TrendAnalyzer()
        papers = _make_papers(10)
        result = analyzer.analyze(papers)
        for kw in result["rising_keywords"]:
            for word in kw["keyword"].split():
                assert word not in _STOPWORDS

    def test_extract_authors_dict_format(self):
        """Authors given as {'name': ...} dicts are recognised."""
        analyzer = TrendAnalyzer()
        papers = [
            {"authors": [{"name": "Alice"}, {"name": "Bob"}]} for _ in range(5)
        ]
        authors = analyzer._extract_authors(papers)
        assert any(a["author"] == "Alice" for a in authors)

    def test_extract_authors_string_format(self):
        """Authors given as plain strings are recognised too."""
        analyzer = TrendAnalyzer()
        papers = [{"authors": ["Alice", "Bob"]} for _ in range(5)]
        authors = analyzer._extract_authors(papers)
        assert any(a["author"] == "Alice" for a in authors)

    def test_extract_datasets(self):
        """Dataset names are picked up from both titles and abstracts."""
        analyzer = TrendAnalyzer()
        papers = [
            {"title": "Training on ImageNet and CIFAR", "abstract": ""},
            {"title": "MNIST results", "abstract": "evaluated on GLUE benchmark"},
        ]
        datasets = analyzer._extract_datasets(papers)
        ds_names = {d["dataset"] for d in datasets}
        assert "ImageNet" in ds_names
        assert "CIFAR" in ds_names

    def test_extract_methods(self):
        """Method mentions are picked up from titles and abstracts."""
        analyzer = TrendAnalyzer()
        papers = [
            {"title": "Transformer attention", "abstract": "using diffusion models"},
            {"title": "GAN for images", "abstract": "contrastive learning approach"},
        ]
        methods = analyzer._extract_methods(papers)
        method_names = {m["method"] for m in methods}
        assert "transformer" in method_names or "attention" in method_names

    def test_tokenize(self):
        """_tokenize lowercases and keeps apostrophes/hyphens within tokens."""
        tokens = TrendAnalyzer._tokenize("Hello World! It's a test-case.")
        assert "hello" in tokens
        assert "world" in tokens
        assert "it's" in tokens
        assert "test-case" in tokens

    def test_source_distribution(self):
        """_source_distribution counts papers per source."""
        papers = [
            {"source": "arxiv"},
            {"source": "arxiv"},
            {"source": "semantic_scholar"},
        ]
        dist = TrendAnalyzer._source_distribution(papers)
        assert dist["arxiv"] == 2
        assert dist["semantic_scholar"] == 1

    def test_generate_trend_report(self):
        """The report includes a title line and the paper count."""
        analyzer = TrendAnalyzer()
        analysis = analyzer.analyze(_make_papers(10))
        report = analyzer.generate_trend_report(analysis)
        assert "Research Trend Analysis" in report
        assert "10 papers" in report

    def test_min_keyword_length(self):
        analyzer = TrendAnalyzer(min_keyword_length=5)
        papers = [{"title": "AI is a big deal", "abstract": ""}] * 5
        keywords = analyzer._extract_keywords(papers)
        # Short words like "deal" (4 chars) should be excluded by min_keyword_length=5
        # but "big" is only 3 chars so excluded too
        for kw in keywords:
            for word in kw["keyword"].split():
                assert len(word) >= 5 or word in _STOPWORDS
# ===================================================================
# OpportunityFinder tests
# ===================================================================
class TestOpportunityFinder:
    """Tests for OpportunityFinder's heuristic and LLM code paths."""

    def test_heuristic_no_llm(self):
        """Without an LLM client, opportunities come from the heuristic path."""
        trends = {
            "rising_keywords": [
                {"keyword": "graph neural", "count": 10},
                {"keyword": "attention", "count": 8},
                {"keyword": "diffusion", "count": 6},
            ],
            "method_trends": [
                {"method": "transformer", "mention_count": 12},
                {"method": "contrastive learning", "mention_count": 7},
            ],
        }
        opportunities = asyncio.run(
            OpportunityFinder().find_opportunities(trends, ["ml"])
        )
        assert len(opportunities) > 0
        for opp in opportunities:
            assert "topic" in opp
            assert opp["source"] == "heuristic"

    def test_heuristic_empty_trends(self):
        """No trend signals at all yields no opportunities."""
        empty = {"rising_keywords": [], "method_trends": []}
        assert asyncio.run(OpportunityFinder().find_opportunities(empty, ["ml"])) == []

    def test_llm_path(self):
        """With a working LLM client, results are tagged source='llm'."""
        trends = {
            "rising_keywords": [{"keyword": "graph", "count": 10}],
            "method_trends": [{"method": "transformer", "mention_count": 5}],
        }
        finder = OpportunityFinder(llm_client=MockLLM())
        opportunities = asyncio.run(finder.find_opportunities(trends, ["ml"]))
        assert len(opportunities) >= 1
        assert opportunities[0]["source"] == "llm"

    def test_llm_fallback_on_failure(self):
        """A failing LLM client falls back to the heuristic path."""
        trends = {
            "rising_keywords": [{"keyword": "test", "count": 5}],
            "method_trends": [{"method": "GAN", "mention_count": 3}],
        }
        finder = OpportunityFinder(llm_client=FailingLLM())
        opportunities = asyncio.run(finder.find_opportunities(trends, ["ml"]))
        assert all(opp["source"] == "heuristic" for opp in opportunities)

    def test_parse_opportunities(self):
        """Well-formed TOPIC|WHY|FEASIBILITY lines parse; noise lines are dropped."""
        response = (
            "TOPIC: Adaptive transformers | WHY: Trending | FEASIBILITY: high\n"
            "TOPIC: Diffusion for audio | WHY: New area | FEASIBILITY: medium\n"
            "Some noise line\n"
        )
        parsed = OpportunityFinder._parse_opportunities(response)
        assert len(parsed) == 2
        assert parsed[0]["topic"] == "Adaptive transformers"
        assert parsed[0]["feasibility"] == "high"

    def test_heuristic_max_five(self):
        """The heuristic path caps its output at five opportunities."""
        trends = {
            "rising_keywords": [
                {"keyword": f"kw{i}", "count": 10} for i in range(10)
            ],
            "method_trends": [
                {"method": f"method{i}", "mention_count": 5} for i in range(10)
            ],
        }
        opportunities = asyncio.run(
            OpportunityFinder().find_opportunities(trends, ["ml"])
        )
        assert len(opportunities) <= 5
# ===================================================================
# DailyDigest tests
# ===================================================================
class TestDailyDigest:
    """Tests for DailyDigest generation, saving, and summary parsing."""

    def test_generate_basic_no_papers(self):
        """An empty feed produces the 'no papers' message."""
        digest = DailyDigest(FeedManager(sources=()))
        assert "No new papers found" in asyncio.run(digest.generate(["ml"]))

    def test_generate_basic_with_papers(self):
        """The digest header and paper count appear in the output."""
        manager = FeedManager(sources=("arxiv",))
        with patch.object(manager, "fetch_recent_papers", return_value=_make_papers(3)):
            output = asyncio.run(DailyDigest(manager).generate(["ml"]))
        assert "Daily Paper Digest" in output
        assert "Papers found: 3" in output

    def test_generate_basic_truncates_abstract(self):
        """Over-long abstracts are truncated with an ellipsis."""
        manager = FeedManager(sources=("arxiv",))
        long_paper = [{"title": "Test", "abstract": "x" * 500, "authors": [], "url": ""}]
        with patch.object(manager, "fetch_recent_papers", return_value=long_paper):
            output = asyncio.run(DailyDigest(manager).generate(["ml"]))
        assert "..." in output

    def test_parse_summary_valid(self):
        """A well-formed SUMMARY|RELEVANCE response parses into its parts."""
        summary, relevance = DailyDigest._parse_summary(
            "SUMMARY: Great paper on attention | RELEVANCE: 4"
        )
        assert summary == "Great paper on attention"
        assert relevance == 4

    def test_parse_summary_no_format(self):
        """Unstructured responses pass through verbatim with default relevance."""
        raw = "Just a plain text response."
        summary, relevance = DailyDigest._parse_summary(raw)
        assert summary == raw
        assert relevance == 3  # default

    def test_parse_summary_clamped(self):
        """Out-of-range relevance scores are clamped to 5."""
        _, relevance = DailyDigest._parse_summary("SUMMARY: x | RELEVANCE: 99")
        assert relevance == 5

    def test_generate_and_save(self, tmp_path: Path):
        """generate_and_save writes a digest file starting with the header."""
        manager = FeedManager(sources=("arxiv",))
        with patch.object(manager, "fetch_recent_papers", return_value=_make_papers(2)):
            saved = asyncio.run(DailyDigest(manager).generate_and_save(tmp_path, ["ml"]))
        assert saved.exists()
        assert saved.read_text(encoding="utf-8").startswith("## Daily Paper Digest")

    def test_author_formatting_dict(self):
        """Four or more dict-style authors collapse to 'et al.'."""
        manager = FeedManager(sources=("arxiv",))
        paper = [{
            "title": "T",
            "abstract": "",
            "url": "",
            "authors": [{"name": "A"}, {"name": "B"}, {"name": "C"}, {"name": "D"}],
        }]
        with patch.object(manager, "fetch_recent_papers", return_value=paper):
            output = asyncio.run(DailyDigest(manager).generate(["ml"]))
        assert "et al." in output
# ===================================================================
# AutoTopicGenerator tests
# ===================================================================
class TestAutoTopicGenerator:
    """Tests for AutoTopicGenerator candidate generation, selection and scoring."""

    @staticmethod
    def _new_generator():
        """Build a generator wired to fresh analyzer/finder instances."""
        return AutoTopicGenerator(TrendAnalyzer(), OpportunityFinder())

    def test_generate_candidates(self):
        """At most `count` candidates are produced, each carrying scores."""
        candidates = asyncio.run(
            self._new_generator().generate_candidates(["ml"], _make_papers(10), count=3)
        )
        assert len(candidates) <= 3
        if candidates:
            first = candidates[0]
            assert "topic" in first
            assert "overall_score" in first

    def test_generate_candidates_empty(self):
        """Empty corpus yields no keywords/methods, hence possibly no candidates."""
        candidates = asyncio.run(
            self._new_generator().generate_candidates(["ml"], [], count=3)
        )
        assert isinstance(candidates, list)

    def test_auto_select_default_fallback(self):
        """With nothing to work from, auto_select falls back to the default topic."""
        chosen = asyncio.run(self._new_generator().auto_select(["ml"], []))
        assert "topic" in chosen
        assert chosen["source"] == "default"

    def test_score_candidate_feasibility(self):
        """Feasibility maps high→0.9 and low→0.3, and drives the overall score."""
        trend = {"rising_keywords": [], "paper_count": 50}
        high = AutoTopicGenerator._score_candidate(
            {"topic": "unique topic xyz", "feasibility": "high"}, trend
        )
        low = AutoTopicGenerator._score_candidate(
            {"topic": "unique topic xyz", "feasibility": "low"}, trend
        )
        assert high["feasibility"] == 0.9
        assert low["feasibility"] == 0.3
        assert high["overall"] > low["overall"]

    def test_score_candidate_novelty_decay(self):
        """Overlap with rising keywords reduces novelty below 1.0."""
        trend = {
            "rising_keywords": [
                {"keyword": "graph neural", "count": 10},
                {"keyword": "neural network", "count": 8},
            ],
            "paper_count": 50,
        }
        scored = AutoTopicGenerator._score_candidate(
            {"topic": "graph neural", "feasibility": "medium"}, trend
        )
        assert scored["novelty"] < 1.0  # overlap penalizes novelty

    def test_score_candidate_weights(self):
        """Overall = 0.4*novelty + 0.3*feasibility + 0.3*impact."""
        scored = AutoTopicGenerator._score_candidate(
            {"topic": "totally unique xyz", "feasibility": "high"},
            {"rising_keywords": [], "paper_count": 50},
        )
        expected = round(
            0.4 * scored["novelty"]
            + 0.3 * scored["feasibility"]
            + 0.3 * scored["impact"],
            3,
        )
        assert scored["overall"] == expected

    def test_format_candidates_empty(self):
        """An empty candidate list renders a 'No candidate' message."""
        assert "No candidate" in self._new_generator().format_candidates([])

    def test_format_candidates_with_data(self):
        """Topic and overall score appear in the formatted output."""
        candidate = {
            "topic": "Test topic",
            "overall_score": 0.75,
            "novelty_score": 0.8,
            "feasibility_score": 0.7,
            "impact_score": 0.6,
            "rationale": "Good idea",
        }
        rendered = self._new_generator().format_candidates([candidate])
        assert "Test topic" in rendered
        assert "0.75" in rendered
# ===================================================================
# LiteratureTrendAnalyzer tests
# ===================================================================
class TestLiteratureTrendAnalyzer:
    """Tests for LiteratureTrendAnalyzer's no-client fallbacks and filtering."""

    def test_no_client_returns_empty(self):
        assert LiteratureTrendAnalyzer().get_daily_papers(["ml"]) == []

    def test_analyze_keyword_trends_no_client(self):
        trends = LiteratureTrendAnalyzer().analyze_keyword_trends(["ml"])
        assert trends["total_papers"] == 0

    def test_find_emerging_topics_no_client(self):
        assert LiteratureTrendAnalyzer().find_emerging_topics(["ml"]) == []

    def test_find_emerging_topics_filters_bigrams(self):
        """Only bigrams with count >= 3 are considered emerging."""
        analyzer = LiteratureTrendAnalyzer(search_client="fake")
        with patch.object(analyzer, "get_daily_papers", return_value=_make_papers(20)):
            emerging = analyzer.find_emerging_topics(["ml"])
        for topic in emerging:
            assert topic["type"] == "bigram"
            assert topic["frequency"] >= 3
================================================
FILE: tests/test_universal_codegen_integration.py
================================================
"""Integration tests for universal cross-domain code generation.
Tests the full pipeline from domain detection → adapter selection →
prompt block generation → blueprint context building, across multiple
research domains. These tests do NOT require an LLM or network —
they verify the infrastructure wiring.
"""
from __future__ import annotations
import json
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from researchclaw.domains.detector import (
DomainProfile,
detect_domain,
get_profile,
is_ml_domain,
load_all_profiles,
)
from researchclaw.domains.prompt_adapter import get_adapter, PromptBlocks
from researchclaw.domains.experiment_schema import (
Condition,
ConditionRole,
EvaluationSpec,
MetricSpec,
UniversalExperimentPlan,
from_legacy_exp_plan,
)
from researchclaw.experiment.metrics import UniversalMetricParser
from researchclaw.experiment.evaluators.convergence import analyze_convergence
from researchclaw.agents.code_searcher.agent import CodeSearchAgent, CodeSearchResult
from researchclaw.agents.code_searcher.pattern_extractor import CodePatterns
# ---------------------------------------------------------------------------
# Cross-domain domain detection integration
# ---------------------------------------------------------------------------
class TestCrossDomainDetection:
    """Test domain detection across all supported domains."""

    @staticmethod
    def _detect(description, expected_id):
        """Detect the domain for *description* and assert its id matches."""
        profile = detect_domain(description)
        assert profile.domain_id == expected_id
        return profile

    def test_all_profiles_loadable(self):
        assert len(load_all_profiles()) >= 18  # at least 18 domain profiles

    def test_ml_vision_full_pipeline(self):
        """ML Vision: detect → adapter → blocks → legacy compatibility."""
        profile = self._detect(
            "image classification on CIFAR-10 with ResNet", "ml_vision"
        )
        assert is_ml_domain(profile)
        blocks = get_adapter(profile).get_code_generation_blocks({})
        # ML adapter returns empty blocks (existing behavior)
        assert blocks.compute_budget == ""

    def test_physics_pde_full_pipeline(self):
        """Physics PDE: detect → adapter → blocks with convergence guidance."""
        profile = self._detect(
            "finite element method for Poisson equation", "physics_pde"
        )
        assert not is_ml_domain(profile)
        adapter = get_adapter(profile)
        # Physics adapter should provide non-empty guidance
        assert adapter.get_code_generation_blocks({}).code_generation_hints
        # Blueprint context should mention convergence
        assert adapter.get_blueprint_context()

    def test_economics_full_pipeline(self):
        """Economics: detect → adapter → progressive spec guidance."""
        profile = self._detect(
            "panel data regression with instrumental variables", "economics_empirical"
        )
        blocks = get_adapter(profile).get_experiment_design_blocks({})
        assert "progressive" in blocks.experiment_design_context.lower()

    def test_chemistry_full_pipeline(self):
        """Chemistry: detect → adapter → PySCF guidance."""
        profile = self._detect(
            "DFT calculation with PySCF for molecular energies", "chemistry_qm"
        )
        assert get_adapter(profile).get_code_generation_blocks({}).code_generation_hints

    def test_biology_full_pipeline(self):
        """Biology: detect → adapter → scanpy guidance."""
        profile = self._detect(
            "single-cell RNA-seq clustering with scanpy", "biology_singlecell"
        )
        assert get_adapter(profile).get_code_generation_blocks({}).code_generation_hints

    def test_math_full_pipeline(self):
        """Math: detect → adapter → convergence guidance."""
        profile = self._detect(
            "Runge-Kutta ODE solver convergence analysis", "mathematics_numerical"
        )
        assert get_adapter(profile).get_code_generation_blocks({}).code_generation_hints
# ---------------------------------------------------------------------------
# Universal Experiment Schema integration
# ---------------------------------------------------------------------------
class TestExperimentSchemaIntegration:
    """Integration tests for the universal experiment plan schema."""

    def test_physics_convergence_plan(self):
        """Create a physics convergence study plan."""
        conditions = [
            Condition(name="FDM_2nd", role="reference",
                      description="2nd order finite difference"),
            Condition(name="FEM_P1", role="proposed",
                      description="P1 finite element method"),
            Condition(name="FEM_P2", role="variant",
                      varies_from="FEM_P1",
                      description="P2 finite element method"),
        ]
        evaluation = EvaluationSpec(
            primary_metric=MetricSpec(
                name="l2_error",
                direction="minimize",
                unit="relative",
            ),
            protocol="Run at 5 grid sizes, measure L2 error",
            statistical_test="convergence_order_fit",
            num_seeds=1,
        )
        plan = UniversalExperimentPlan(
            experiment_type="convergence",
            domain_id="physics_pde",
            problem_description="Solve Poisson equation with FEM and FDM",
            conditions=conditions,
            evaluation=evaluation,
            main_figure_type="convergence_plot",
        )
        # Role-based condition views
        assert len(plan.references) == 1
        assert len(plan.proposed) == 1
        assert len(plan.variants) == 1
        # Test legacy format conversion
        legacy = plan.to_legacy_format()
        assert len(legacy["baselines"]) == 1
        assert legacy["baselines"][0]["name"] == "FDM_2nd"
        assert "l2_error" in legacy["metrics"]
        # Test YAML serialization
        serialized = plan.to_yaml()
        assert "convergence" in serialized
        assert "FDM_2nd" in serialized

    def test_economics_progressive_plan(self):
        """Create an economics progressive specification plan."""
        plan = UniversalExperimentPlan(
            experiment_type="progressive_spec",
            domain_id="economics_empirical",
            conditions=[
                Condition(name="OLS", role="reference",
                          description="Simple OLS"),
                Condition(name="OLS_controls", role="proposed",
                          description="OLS with control variables"),
                Condition(name="FE", role="variant",
                          varies_from="OLS_controls",
                          description="Fixed effects"),
                Condition(name="IV_2SLS", role="variant",
                          varies_from="OLS_controls",
                          description="Instrumental variables"),
            ],
            evaluation=EvaluationSpec(
                primary_metric=MetricSpec(name="coefficient", direction="maximize"),
                statistical_test="hausman_test",
            ),
            main_table_type="regression_table",
        )
        assert len(plan.conditions) == 4
        assert len(plan.to_legacy_format()["ablations"]) == 2  # FE and IV are variants
# ---------------------------------------------------------------------------
# Metric Parser + Convergence Evaluator integration
# ---------------------------------------------------------------------------
class TestMetricConvergenceIntegration:
    """End-to-end tests for metric parsing plus convergence analysis."""

    def test_json_convergence_end_to_end(self, tmp_path):
        """Parse JSON convergence results → analyze convergence → report."""
        payload = {
            "experiment_type": "convergence",
            "convergence": {
                "euler": [
                    {"h": 0.1, "error": 0.1},
                    {"h": 0.05, "error": 0.05},
                    {"h": 0.025, "error": 0.025},
                    {"h": 0.0125, "error": 0.0125},
                ],
                "rk4": [
                    {"h": 0.1, "error": 1e-4},
                    {"h": 0.05, "error": 6.25e-6},
                    {"h": 0.025, "error": 3.9e-7},
                    {"h": 0.0125, "error": 2.44e-8},
                ],
            },
            "metadata": {"domain": "mathematics_numerical"},
        }
        (tmp_path / "results.json").write_text(json.dumps(payload))
        # Parse
        parsed = UniversalMetricParser().parse(tmp_path)
        assert parsed.source == "json"
        assert "euler" in parsed.convergence
        # Fit observed orders against their theoretical values
        report = analyze_convergence(
            parsed.convergence,
            expected_orders={"euler": 1.0, "rk4": 4.0},
        )
        assert len(report.methods) == 2
        euler = next(r for r in report.methods if r.method == "euler")
        rk4 = next(r for r in report.methods if r.method == "rk4")
        assert abs(euler.convergence_order - 1.0) < 0.2
        assert abs(rk4.convergence_order - 4.0) < 0.5
        assert rk4.convergence_order > euler.convergence_order
        assert report.best_method == "rk4"

    def test_flat_metrics_backward_compatible(self, tmp_path):
        """Ensure new metric parser produces backward-compatible output."""
        # Old-style stdout with plain "name: value" lines
        parsed = UniversalMetricParser().parse(
            tmp_path,
            stdout="accuracy: 0.95\nloss: 0.32\ncondition=proposed accuracy: 0.95\n",
        )
        flattened = parsed.to_flat_metrics()
        assert "accuracy" in flattened
        assert "loss" in flattened
        assert flattened["accuracy"] == 0.95
# ---------------------------------------------------------------------------
# Code Search + Domain Profile integration
# ---------------------------------------------------------------------------
class TestCodeSearchIntegration:
    """Integration of code search results with domain-aware prompting."""

    def test_code_search_result_in_blueprint(self):
        """Code search results should be formattable as prompt context."""
        patterns = CodePatterns(
            api_patterns=[
                "from pyscf import gto, scf\nmol = gto.M(atom='H 0 0 0; H 0 0 0.74', basis='sto-3g')",
            ],
            file_structure={"main.py": "Entry point", "molecule.py": "Molecule definitions"},
            evaluation_patterns=["mae = np.mean(np.abs(predicted - reference))"],
        )
        repos = [MagicMock(full_name="user/pyscf-example", stars=200)]
        rendered = CodeSearchResult(
            patterns=patterns, repos_found=repos
        ).to_prompt_context()
        assert "pyscf" in rendered
        assert "molecule.py" in rendered

    def test_domain_adapter_blueprint_context(self):
        """Domain adapter should produce useful blueprint context."""
        profile = get_profile("physics_simulation")
        if profile is None:
            pytest.skip("physics_simulation profile not found")
        context = get_adapter(profile).get_blueprint_context()
        # Should mention file structure
        assert "main.py" in context or "integrator" in context.lower()
        # Should mention libraries
        assert "numpy" in context.lower() or "scipy" in context.lower() or context != ""
# ---------------------------------------------------------------------------
# CodeAgent domain injection test
# ---------------------------------------------------------------------------
class TestCodeAgentDomainInjection:
    """Verify CodeAgent accepts and wires domain profiles and code search results."""

    @staticmethod
    def _build_agent(profile, search_result):
        """Construct a CodeAgent with mocked LLM/prompts and the given domain inputs."""
        from researchclaw.pipeline.code_agent import CodeAgent, CodeAgentConfig

        return CodeAgent(
            llm=MagicMock(),
            prompts=MagicMock(),
            config=CodeAgentConfig(enabled=True),
            stage_dir=Path("/tmp/test"),
            domain_profile=profile,
            code_search_result=search_result,
        )

    def test_code_agent_accepts_domain_profile(self):
        """CodeAgent should accept domain_profile and code_search_result."""
        profile = DomainProfile(
            domain_id="physics_pde",
            display_name="PDE Solvers",
            core_libraries=["numpy", "scipy"],
        )
        search_result = CodeSearchResult(
            patterns=CodePatterns(
                api_patterns=["import scipy.sparse"],
            ),
        )
        agent = self._build_agent(profile, search_result)
        # Verify the domain context builder works
        ctx = agent._build_domain_context()
        assert "scipy" in ctx.lower() or ctx != ""

    def test_code_agent_ml_domain_no_extra_context(self):
        """ML domain should add minimal extra context (preserve existing behavior)."""
        profile = get_profile("ml_vision") or DomainProfile(
            domain_id="ml_vision",
            display_name="Computer Vision",
        )
        agent = self._build_agent(profile, None)  # no code search for ML
        # ML adapter returns empty blocks → minimal context.  Some context from
        # file structure is acceptable, but nothing from code search may appear
        # since no code_search_result was provided.
        assert "Reference Code from GitHub" not in agent._build_domain_context()
# ---------------------------------------------------------------------------
# Docker profile mapping test
# ---------------------------------------------------------------------------
class TestDockerProfileMapping:
    """Verify domain → docker profile mappings in docker_profiles.yaml.

    Fixes: the YAML-loading boilerplate was duplicated in both tests (now a
    shared helper), and a dead `unmapped` computation that was built but never
    asserted on has been removed.
    """

    @staticmethod
    def _load_docker_config():
        """Load docker_profiles.yaml, skipping the calling test if it is absent.

        Returns the parsed YAML mapping (dict).
        """
        import yaml

        profiles_path = (
            Path(__file__).parent.parent / "researchclaw" / "data" / "docker_profiles.yaml"
        )
        if not profiles_path.exists():
            pytest.skip("docker_profiles.yaml not found")
        with profiles_path.open() as f:
            return yaml.safe_load(f)

    def test_domain_to_docker_mapping(self):
        """All domains should map to a valid docker profile."""
        docker_config = self._load_docker_config()
        domain_map = docker_config.get("domain_map", {})
        profiles = docker_config.get("profiles", {})
        # Every mapped domain should point to a valid profile
        for domain_id, profile_name in domain_map.items():
            assert profile_name in profiles, (
                f"Domain {domain_id} maps to unknown profile: {profile_name}"
            )

    def test_all_loaded_domains_have_docker_mapping(self):
        """Core domain profiles must have a docker mapping.

        New domains without docker images yet are tolerated, but the core
        set below must always be mapped.
        """
        domain_map = self._load_docker_config().get("domain_map", {})
        core_domains = [
            "ml_vision", "ml_nlp", "ml_rl", "physics_simulation",
            "physics_pde", "chemistry_qm", "economics_empirical",
            "mathematics_numerical",
        ]
        for d in core_domains:
            assert d in domain_map, f"Core domain {d} missing from docker mapping"
================================================
FILE: tests/test_v6_improvements.py
================================================
"""Tests for V6 improvements (IMP-13 through IMP-16).
Run with:
.venv/bin/python3 -m pytest tests/test_v6_improvements.py -v
or:
.venv/bin/python3 tests/test_v6_improvements.py
"""
from __future__ import annotations
import re
import sys
import statistics
import random
import textwrap
from pathlib import Path
# ============================================================
# IMP-13: Test _extract_paper_title import & behaviour
# ============================================================
class TestIMP13_ExtractPaperTitle:
    """IMP-13: runner.py imports _extract_paper_title from executor.
    Verify the import works and the function produces correct results."""

    def test_import_works(self):
        """The import `from researchclaw.pipeline.executor import _extract_paper_title`
        must succeed — runner.py line 394 depends on it."""
        from researchclaw.pipeline.executor import _extract_paper_title

        assert callable(_extract_paper_title), "_extract_paper_title should be callable"
        print("[IMP-13] PASS: import _extract_paper_title works")

    def test_extracts_h1_title(self):
        from researchclaw.pipeline.executor import _extract_paper_title

        markdown = textwrap.dedent("""\
            # A Novel Approach to Deep Reinforcement Learning
            ## Abstract
            This paper presents...
            """)
        title = _extract_paper_title(markdown)
        assert title == "A Novel Approach to Deep Reinforcement Learning", \
            f"Expected H1 title, got: {title!r}"
        print(f"[IMP-13] PASS: extracted title = {title!r}")

    def test_skips_abstract_heading(self):
        """Title before Abstract should be found; Abstract heading itself skipped."""
        from researchclaw.pipeline.executor import _extract_paper_title

        markdown = textwrap.dedent("""\
            # A Real Title of at Least Four Words
            ## Abstract
            Some text...
            """)
        title = _extract_paper_title(markdown)
        # "Abstract" should be skipped; the real title (before Abstract) is found
        assert title == "A Real Title of at Least Four Words", \
            f"Expected real title, got: {title!r}"
        print(f"[IMP-13] PASS: skipped Abstract, got title = {title!r}")

    def test_title_after_abstract_not_found(self):
        """If the only real title is AFTER Abstract, it should not be found
        (function searches only before Abstract heading)."""
        from researchclaw.pipeline.executor import _extract_paper_title

        markdown = textwrap.dedent("""\
            # Abstract
            # A Title That Appears After Abstract
            Some text...
            """)
        title = _extract_paper_title(markdown)
        # Title after Abstract is not in the search region, so fallback
        assert title == "Untitled Paper", \
            f"Expected 'Untitled Paper' since title is after Abstract, got: {title!r}"
        print(f"[IMP-13] PASS: title after Abstract not found, fallback = {title!r}")

    def test_fallback_untitled(self):
        from researchclaw.pipeline.executor import _extract_paper_title

        title = _extract_paper_title("Just some text without any headings.")
        assert title == "Untitled Paper", f"Expected 'Untitled Paper', got: {title!r}"
        print(f"[IMP-13] PASS: fallback = {title!r}")

    def test_bold_title(self):
        from researchclaw.pipeline.executor import _extract_paper_title

        markdown = textwrap.dedent("""\
            **A Bold Title for This Paper**
            ## Abstract
            Text here...
            """)
        title = _extract_paper_title(markdown)
        assert "Bold Title" in title, f"Expected bold title, got: {title!r}"
        print(f"[IMP-13] PASS: bold title = {title!r}")
# ============================================================
# IMP-14: Test orphaned cite-key stripping logic
# ============================================================
class TestIMP14_StripOrphanedCites:
    """IMP-14: After packaging, any \\cite{key} where key is not in
    references.bib should be stripped from paper.tex."""

    @staticmethod
    def _run_cite_stripping(tex_text: str, bib_text: str) -> str:
        """Reproduce the IMP-14 logic from runner.py lines 505-532."""
        # Collect every cited key, then every key actually defined in the bib.
        cited: set[str] = set()
        for cite_match in re.finditer(r"\\cite\{([^}]+)\}", tex_text):
            cited.update(key.strip() for key in cite_match.group(1).split(","))
        defined = set(re.findall(r"@\w+\{([^,]+),", bib_text))
        orphaned = cited - defined
        if orphaned:
            def _filter_cite(m: re.Match[str]) -> str:
                keys = [k.strip() for k in m.group(1).split(",")]
                surviving = [k for k in keys if k not in orphaned]
                if not surviving:
                    return ""
                return "\\cite{" + ", ".join(surviving) + "}"

            tex_text = re.sub(r"\\cite\{([^}]+)\}", _filter_cite, tex_text)
            # Tidy whitespace left behind by fully-removed \cite commands.
            tex_text = re.sub(r" +", " ", tex_text)
            tex_text = re.sub(r" ([.,;:)])", r"\1", tex_text)
        return tex_text

    def test_mixed_real_and_missing_keys(self):
        """\\cite{real_key, missing_key} should become \\cite{real_key}."""
        bib = textwrap.dedent("""\
            @article{real_key,
            author = {Doe},
            title = {Real Paper},
            year = {2024},
            }
            """)
        stripped = self._run_cite_stripping(
            r"Some text \cite{real_key, missing_key} and more.", bib
        )
        assert r"\cite{real_key}" in stripped, f"Expected \\cite{{real_key}}, got: {stripped!r}"
        assert "missing_key" not in stripped, f"missing_key should be gone: {stripped!r}"
        print(f"[IMP-14] PASS: mixed keys → {stripped!r}")

    def test_all_keys_missing(self):
        """\\cite{missing1, missing2} should be entirely removed."""
        stripped = self._run_cite_stripping(
            r"Some text \cite{missing1, missing2} more.", ""  # empty bib
        )
        assert r"\cite" not in stripped, f"Expected no \\cite, got: {stripped!r}"
        print(f"[IMP-14] PASS: all missing → {stripped!r}")

    def test_all_keys_valid(self):
        """When all keys are valid, tex should remain unchanged (except whitespace)."""
        bib = textwrap.dedent("""\
            @article{key1,
            author = {A},
            title = {T},
            year = {2024},
            }
            @article{key2,
            author = {B},
            title = {T2},
            year = {2024},
            }
            """)
        stripped = self._run_cite_stripping(r"Text \cite{key1, key2} end.", bib)
        assert r"\cite{key1, key2}" in stripped, f"Expected unchanged, got: {stripped!r}"
        print(f"[IMP-14] PASS: all valid → {stripped!r}")

    def test_multiple_cite_commands(self):
        """Multiple \\cite commands, each with different missing keys."""
        bib = textwrap.dedent("""\
            @article{a,
            author = {X},
            title = {Y},
            year = {2024},
            }
            @article{c,
            author = {X},
            title = {Y},
            year = {2024},
            }
            """)
        stripped = self._run_cite_stripping(
            r"First \cite{a, b} second \cite{b, c} third \cite{d}.", bib
        )
        # a is valid, b is missing, c is valid, d is missing
        assert r"\cite{a}" in stripped, f"Expected \\cite{{a}}, got: {stripped!r}"
        assert r"\cite{c}" in stripped, f"Expected \\cite{{c}}, got: {stripped!r}"
        # b should not appear as a cite key
        assert r"\cite{b}" not in stripped, f"\\cite{{b}} should be gone: {stripped!r}"
        assert r", b}" not in stripped and r"{b," not in stripped, \
            f"b key should be stripped: {stripped!r}"
        # \cite{d} should be entirely removed (d was the only key)
        assert r"\cite{d}" not in stripped, f"\\cite{{d}} should be gone: {stripped!r}"
        print(f"[IMP-14] PASS: multiple cites → {stripped!r}")

    def test_whitespace_cleanup(self):
        """After removing a full \\cite{}, leftover double-spaces and ' .' are cleaned."""
        stripped = self._run_cite_stripping(r"Text \cite{missing} end.", "")
        # Should not have double spaces or " ."
        assert "  " not in stripped, f"Double space in result: {stripped!r}"
        assert " ." not in stripped, f"Space-dot in result: {stripped!r}"
        print(f"[IMP-14] PASS: whitespace cleanup → {stripped!r}")
# ============================================================
# IMP-15: Test BibTeX deduplication
# ============================================================
class TestIMP15_BibDedup:
    """IMP-15: Deduplicate .bib entries sharing the same cite key."""

    @staticmethod
    def _run_dedup(bib_text: str) -> str:
        """Reproduce IMP-15 logic from runner.py lines 486-503."""
        seen_keys: set[str] = set()
        unique_entries: list[str] = []
        # Walk whole entries (header through closing brace), keeping the
        # first occurrence of each cite key.
        for entry in re.finditer(r"(@\w+\{([^,]+),.*?\n\})", bib_text, re.DOTALL):
            key = entry.group(2).strip()
            if key not in seen_keys:
                seen_keys.add(key)
                unique_entries.append(entry.group(1))
        total_entries = len(list(re.finditer(r"@\w+\{", bib_text)))
        # Only rebuild the text when at least one duplicate was dropped.
        if len(unique_entries) < total_entries:
            bib_text = "\n\n".join(unique_entries) + "\n"
        return bib_text

    def test_duplicate_entries_removed(self):
        bib = textwrap.dedent("""\
            @article{smith2024,
            author = {Smith},
            title = {Paper 1},
            year = {2024},
            }
            @article{smith2024,
            author = {Smith},
            title = {Paper 1 duplicate},
            year = {2024},
            }
            @article{jones2023,
            author = {Jones},
            title = {Paper 2},
            year = {2023},
            }
            """)
        deduped = self._run_dedup(bib)
        # Count how many @article{smith2024, appear
        count_smith = len(re.findall(r"@article\{smith2024,", deduped))
        count_jones = len(re.findall(r"@article\{jones2023,", deduped))
        assert count_smith == 1, f"Expected 1 smith2024 entry, got {count_smith}"
        assert count_jones == 1, f"Expected 1 jones2023 entry, got {count_jones}"
        # First version should be kept
        assert "Paper 1" in deduped
        print(f"[IMP-15] PASS: 2 smith2024 → 1, jones2023 kept. Total entries correct.")

    def test_no_duplicates_unchanged(self):
        bib = textwrap.dedent("""\
            @article{alpha2024,
            author = {Alpha},
            title = {A},
            year = {2024},
            }
            @inproceedings{beta2023,
            author = {Beta},
            title = {B},
            year = {2023},
            }
            """)
        deduped = self._run_dedup(bib)
        # Should remain unchanged (both entries present)
        assert "alpha2024" in deduped
        assert "beta2023" in deduped
        total = len(re.findall(r"@\w+\{", deduped))
        assert total == 2, f"Expected 2 entries, got {total}"
        print(f"[IMP-15] PASS: no duplicates → unchanged")

    def test_triple_duplicate(self):
        bib = textwrap.dedent("""\
            @article{x2024,
            author = {X},
            title = {First},
            year = {2024},
            }
            @article{x2024,
            author = {X},
            title = {Second},
            year = {2024},
            }
            @article{x2024,
            author = {X},
            title = {Third},
            year = {2024},
            }
            """)
        deduped = self._run_dedup(bib)
        survivors = len(re.findall(r"@article\{x2024,", deduped))
        assert survivors == 1, f"Expected 1 x2024 entry, got {survivors}"
        # First version kept
        assert "First" in deduped
        assert "Second" not in deduped
        assert "Third" not in deduped
        print(f"[IMP-15] PASS: triple duplicate → 1 entry")

    def test_empty_bib(self):
        """Edge case: empty bib text should not crash."""
        deduped = self._run_dedup("")
        assert deduped == "", f"Expected empty, got: {deduped!r}"
        print(f"[IMP-15] PASS: empty bib → no crash")
# ============================================================
# IMP-16: Test bootstrap CI fallback
# ============================================================
class TestIMP16_BootstrapCIFallback:
    """IMP-16: when the bootstrap 95% CI fails to bracket the sample mean,
    fall back to the normal approximation (mean +/- 1.96*SE)."""

    @staticmethod
    def _compute_ci_with_fallback(vals: list[float]) -> tuple[float, float, bool]:
        """Mirror of the IMP-16 logic in executor.py (lines 3367-3397).

        Returns (ci_low, ci_high, used_fallback).
        """
        _mean = statistics.mean(vals)
        _std = statistics.stdev(vals)
        # Percentile bootstrap (fixed seed for reproducibility).
        rng = random.Random(42)
        n = len(vals)
        boot_means = sorted(
            statistics.mean([rng.choice(vals) for _ in range(n)])
            for _ in range(1000)
        )
        ci_low = round(boot_means[int(0.025 * len(boot_means))], 6)
        ci_high = round(boot_means[int(0.975 * len(boot_means))], 6)
        # IMP-16 sanity check: the interval must bracket the mean;
        # otherwise switch to the normal approximation.
        used_fallback = not (ci_low <= _mean <= ci_high)
        if used_fallback:
            se = _std / (n ** 0.5)
            ci_low = round(_mean - 1.96 * se, 6)
            ci_high = round(_mean + 1.96 * se, 6)
        return ci_low, ci_high, used_fallback

    def test_normal_case_no_fallback(self):
        """Well-behaved data: the bootstrap CI brackets the mean directly."""
        vals = [0.8, 0.82, 0.79, 0.81, 0.83]
        ci_low, ci_high, used_fallback = self._compute_ci_with_fallback(vals)
        mean = statistics.mean(vals)
        assert ci_low <= mean <= ci_high, \
            f"CI [{ci_low}, {ci_high}] should contain mean {mean}"
        assert not used_fallback, "Should NOT have used fallback for normal data"
        print(f"[IMP-16] PASS: normal data → CI=[{ci_low}, {ci_high}], mean={mean:.4f}, no fallback")

    def test_fallback_triggers_for_pathological_data(self):
        """Exercise the fallback formula directly.

        Generating data whose bootstrap CI misses the mean is inherently
        fragile, so this simulates a bad CI and checks that the
        normal-approximation fallback brackets the mean.
        """
        vals = [1.0, 2.0, 3.0, 4.0, 5.0]
        mean = statistics.mean(vals)
        se = statistics.stdev(vals) / (len(vals) ** 0.5)
        # Simulate a CI lying entirely above the mean.
        bad_ci_low = mean + 0.1
        bad_ci_high = mean + 1.0
        assert bad_ci_low > mean, "Bad CI should not contain mean"
        fallback_low = round(mean - 1.96 * se, 6)
        fallback_high = round(mean + 1.96 * se, 6)
        assert fallback_low <= mean <= fallback_high, \
            f"Fallback CI [{fallback_low}, {fallback_high}] must contain mean {mean}"
        print(f"[IMP-16] PASS: fallback CI=[{fallback_low}, {fallback_high}], mean={mean:.4f}")

    def test_fallback_ci_always_contains_mean(self):
        """The normal-approximation fallback MUST always bracket the mean."""
        test_cases = [
            [10, 20, 30],
            [0.001, 0.002, 0.003, 0.004],
            [100, 200, 300, 400, 500],
            [-5, -3, -1, 1, 3, 5],
        ]
        for vals in test_cases:
            mean = statistics.mean(vals)
            se = statistics.stdev(vals) / (len(vals) ** 0.5)
            ci_low = round(mean - 1.96 * se, 6)
            ci_high = round(mean + 1.96 * se, 6)
            assert ci_low <= mean <= ci_high, \
                f"Fallback CI [{ci_low}, {ci_high}] must contain mean {mean} for vals={vals}"
        print(f"[IMP-16] PASS: fallback always contains mean for {len(test_cases)} test cases")

    def test_condition_check_logic(self):
        """The check `_ci_low > _mean or _ci_high < _mean` must flag exactly
        the cases where the mean lies OUTSIDE the interval."""
        mean = 5.0
        # Mean below the interval → flagged.
        assert (6.0 > mean or 8.0 < mean), "Mean below CI not detected"
        # Mean above the interval → flagged.
        assert (1.0 > mean or 4.0 < mean), "Mean above CI not detected"
        # Mean strictly inside → not flagged.
        assert not (3.0 > mean or 7.0 < mean), "Mean inside CI incorrectly flagged"
        # Mean exactly on a boundary → not flagged.
        assert not (5.0 > mean or 7.0 < mean), "Mean at lower boundary incorrectly flagged"
        assert not (3.0 > mean or 5.0 < mean), "Mean at upper boundary incorrectly flagged"
        print("[IMP-16] PASS: condition check logic correct for all cases")

    def test_min_sample_size(self):
        """The real code requires len(vals) >= 3 for bootstrap; check n=3."""
        vals = [1.0, 2.0, 3.0]
        ci_low, ci_high, _ = self._compute_ci_with_fallback(vals)
        mean = statistics.mean(vals)
        assert ci_low <= mean <= ci_high, \
            f"CI [{ci_low}, {ci_high}] should contain mean {mean} for n=3"
        print(f"[IMP-16] PASS: n=3 works → CI=[{ci_low}, {ci_high}], mean={mean:.4f}")
# ============================================================
# Integration-style: Test the runner.py _package_deliverables
# cite-stripping + dedup pipeline end-to-end
# ============================================================
class TestIMP14_15_Integration:
    """End-to-end: IMP-15 bib dedup followed by IMP-14 cite stripping."""

    def test_dedup_then_strip(self):
        """Run dedup (IMP-15) then cite-strip (IMP-14) in sequence, as runner.py does."""
        bib_text = textwrap.dedent("""\
            @article{smith2024,
            author = {Smith},
            title = {Paper A},
            year = {2024},
            }
            @article{smith2024,
            author = {Smith},
            title = {Paper A dup},
            year = {2024},
            }
            @article{jones2023,
            author = {Jones},
            title = {Paper B},
            year = {2023},
            }
            """)
        tex_text = r"Results from \cite{smith2024, jones2023, ghost2024} show..."
        # --- Step 1: IMP-15 dedup (first occurrence of each key wins) ---
        seen_keys: set[str] = set()
        kept_entries: list[str] = []
        for entry in re.finditer(r"(@\w+\{([^,]+),.*?\n\})", bib_text, re.DOTALL):
            key = entry.group(2).strip()
            if key not in seen_keys:
                seen_keys.add(key)
                kept_entries.append(entry.group(1))
        total_entries = len(list(re.finditer(r"@\w+\{", bib_text)))
        if len(kept_entries) < total_entries:
            bib_text = "\n\n".join(kept_entries) + "\n"
        assert bib_text.count("smith2024") == 1, "Dedup failed for smith2024"
        # --- Step 2: IMP-14 strip cite keys that have no bib entry ---
        all_cite_keys: set[str] = set()
        for cm in re.finditer(r"\\cite\{([^}]+)\}", tex_text):
            all_cite_keys.update(k.strip() for k in cm.group(1).split(","))
        bib_keys = set(re.findall(r"@\w+\{([^,]+),", bib_text))
        missing = all_cite_keys - bib_keys
        assert missing == {"ghost2024"}, f"Expected only ghost2024 missing, got {missing}"

        def _filter_cite(m: re.Match[str]) -> str:
            # Drop keys with no bib entry; drop the whole \cite if none remain.
            kept = [k for k in (part.strip() for part in m.group(1).split(",")) if k not in missing]
            if not kept:
                return ""
            return "\\cite{" + ", ".join(kept) + "}"

        tex_text = re.sub(r"\\cite\{([^}]+)\}", _filter_cite, tex_text)
        tex_text = re.sub(r" +", " ", tex_text)
        tex_text = re.sub(r" ([.,;:)])", r"\1", tex_text)
        assert r"\cite{smith2024, jones2023}" in tex_text, \
            f"Expected valid keys kept, got: {tex_text!r}"
        assert "ghost2024" not in tex_text, \
            f"ghost2024 should be stripped: {tex_text!r}"
        print(f"[Integration] PASS: dedup + cite strip → {tex_text!r}")
# ============================================================
# Runner
# ============================================================
def run_all_tests():
    """Run all tests manually (fallback if pytest is not available).

    Instantiates each test class, invokes every ``test_*`` method in sorted
    order, and prints a pass/fail summary.

    Returns:
        bool: True iff every test passed.
    """
    test_classes = [
        TestIMP13_ExtractPaperTitle,
        TestIMP14_StripOrphanedCites,
        TestIMP15_BibDedup,
        TestIMP16_BootstrapCIFallback,
        TestIMP14_15_Integration,
    ]
    total = 0
    passed = 0
    failed = 0
    errors: list[str] = []
    for cls in test_classes:
        instance = cls()
        test_methods = [m for m in dir(instance) if m.startswith("test_")]
        for method_name in sorted(test_methods):
            total += 1
            method = getattr(instance, method_name)
            try:
                method()
                passed += 1
            except Exception as e:
                failed += 1
                err_msg = f"FAIL: {cls.__name__}.{method_name}: {e}"
                errors.append(err_msg)
                # BUG FIX: previously printed f" FAIL: {err_msg}", which emitted a
                # doubled "FAIL: FAIL:" prefix since err_msg already starts with "FAIL:".
                print(f" {err_msg}")
    print(f"\n{'='*60}")
    print(f"Results: {passed}/{total} passed, {failed} failed")
    if errors:
        print("Failures:")
        for e in errors:
            print(f" - {e}")
    print(f"{'='*60}")
    return failed == 0
if __name__ == "__main__":
    # Make the project root importable when this test file is run directly.
    _root = Path(__file__).resolve().parent.parent
    if str(_root) not in sys.path:
        sys.path.insert(0, str(_root))
    sys.exit(0 if run_all_tests() else 1)
================================================
FILE: tests/test_verified_registry.py
================================================
"""Tests for VerifiedRegistry — ground truth number whitelist."""
from __future__ import annotations
import json
import math
from pathlib import Path
import pytest
from researchclaw.pipeline.verified_registry import (
ConditionResult,
VerifiedRegistry,
_is_finite,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
ARTIFACTS = Path(__file__).resolve().parent.parent / "artifacts"
def _load_experiment_summary(run_id: str) -> dict:
    """Locate and parse stage-14/experiment_summary.json for ``run_id``.

    Skips the calling test when the artifact directory or summary is absent.
    """
    hits = sorted(ARTIFACTS.glob(f"rc-*-{run_id}"))
    if not hits:
        pytest.skip(f"Artifact {run_id} not found")
    summary = hits[0] / "stage-14" / "experiment_summary.json"
    if not summary.exists():
        pytest.skip(f"No experiment_summary for {run_id}")
    return json.loads(summary.read_text())
def _load_refinement_log(run_id: str) -> dict | None:
    """Return parsed stage-13/refinement_log.json for ``run_id``, or None if absent."""
    hits = sorted(ARTIFACTS.glob(f"rc-*-{run_id}"))
    if not hits:
        return None
    log_file = hits[0] / "stage-13" / "refinement_log.json"
    if not log_file.exists():
        return None
    return json.loads(log_file.read_text())
# ---------------------------------------------------------------------------
# Unit tests — ConditionResult
# ---------------------------------------------------------------------------
class TestConditionResult:
    """ConditionResult.compute_stats() across seed-count edge cases."""

    def test_compute_stats_multiple_seeds(self):
        result = ConditionResult(name="test", per_seed_values={0: 10.0, 1: 20.0, 2: 30.0})
        result.compute_stats()
        assert result.n_seeds == 3
        assert result.mean == pytest.approx(20.0)
        assert result.std == pytest.approx(10.0)

    def test_compute_stats_single_seed(self):
        # A lone seed has no spread, so std collapses to 0.
        result = ConditionResult(name="test", per_seed_values={0: 42.0})
        result.compute_stats()
        assert result.n_seeds == 1
        assert result.mean == pytest.approx(42.0)
        assert result.std == 0.0

    def test_compute_stats_with_nan(self):
        result = ConditionResult(
            name="test", per_seed_values={0: 10.0, 1: math.nan, 2: 30.0}
        )
        result.compute_stats()
        assert result.n_seeds == 2  # NaN excluded
        assert result.mean == pytest.approx(20.0)

    def test_compute_stats_empty(self):
        # No per-seed values at all → no seeds, undefined mean.
        result = ConditionResult(name="test")
        result.compute_stats()
        assert result.n_seeds == 0
        assert result.mean is None
# ---------------------------------------------------------------------------
# Unit tests — VerifiedRegistry core operations
# ---------------------------------------------------------------------------
class TestVerifiedRegistryCore:
    """Core add/lookup/verify operations of VerifiedRegistry."""

    def test_add_value(self):
        registry = VerifiedRegistry()
        registry.add_value(74.28, "test_source")
        assert registry.is_verified(74.28)
        # A rounded variant should still match within tolerance.
        assert registry.is_verified(74.3, tolerance=0.01)

    def test_percentage_conversion(self):
        """Value in [0,1] should also register value*100."""
        registry = VerifiedRegistry()
        registry.add_value(0.7428, "accuracy_fraction")
        assert registry.is_verified(0.7428)
        assert registry.is_verified(74.28)  # ×100 variant

    def test_reverse_percentage(self):
        """Value > 1 should also register value/100."""
        registry = VerifiedRegistry()
        registry.add_value(74.28, "accuracy_percent")
        assert registry.is_verified(74.28)
        assert registry.is_verified(0.7428)  # ÷100 variant

    def test_tolerance_matching(self):
        registry = VerifiedRegistry()
        registry.add_value(92.14, "test")
        assert registry.is_verified(92.14)
        # ~0.15% off — inside the 1% tolerance.
        assert registry.is_verified(92.0, tolerance=0.01)
        # Well outside the 1% tolerance.
        assert not registry.is_verified(95.0, tolerance=0.01)

    def test_zero_handling(self):
        registry = VerifiedRegistry()
        registry.add_value(0.0, "zero_metric")
        assert registry.is_verified(0.0)
        assert registry.is_verified(1e-8)  # effectively zero
        assert not registry.is_verified(0.01)  # not close enough

    def test_negative_values(self):
        registry = VerifiedRegistry()
        registry.add_value(-459.6, "bad_return")
        assert registry.is_verified(-459.6)
        assert registry.is_verified(-460.0, tolerance=0.01)

    def test_nan_inf_rejected(self):
        registry = VerifiedRegistry()
        registry.add_value(math.nan, "nan_metric")
        registry.add_value(math.inf, "inf_metric")
        assert not registry.is_verified(math.nan)
        assert not registry.is_verified(math.inf)
        # Neither non-finite value should have been stored at all.
        assert len(registry.values) == 0

    def test_lookup(self):
        registry = VerifiedRegistry()
        registry.add_value(42.0, "the_answer")
        assert registry.lookup(42.0) == "the_answer"
        assert registry.lookup(999.0) is None

    def test_verify_condition(self):
        registry = VerifiedRegistry()
        registry.condition_names = {"DQN", "DQN+Abstraction"}
        assert registry.verify_condition("DQN")
        assert not registry.verify_condition("PPO")
# ---------------------------------------------------------------------------
# Unit tests — from_experiment (synthetic data)
# ---------------------------------------------------------------------------
class TestFromExperiment:
    """Unit tests for VerifiedRegistry.from_experiment on a synthetic summary."""

    def _make_summary(self) -> dict:
        # Synthetic experiment summary: CondA has seeds {0: 80.0, 1: 85.0}
        # (mean 82.5, stdev ≈ 3.5355) and CondB a single seed {0: 70.0}.
        # primary_metric/primary_metric_std mirror CondA's aggregate stats so
        # the derived-statistics assertions below line up.
        return {
            "metrics_summary": {
                "CondA/0/metric": {"min": 80.0, "max": 80.0, "mean": 80.0, "count": 1},
                "CondA/1/metric": {"min": 85.0, "max": 85.0, "mean": 85.0, "count": 1},
                "CondB/0/metric": {"min": 70.0, "max": 70.0, "mean": 70.0, "count": 1},
                "primary_metric": {"min": 82.5, "max": 82.5, "mean": 82.5, "count": 1},
            },
            "best_run": {
                "metrics": {
                    "CondA/0/metric": 80.0,
                    "CondA/1/metric": 85.0,
                    "CondB/0/metric": 70.0,
                    "primary_metric": 82.5,
                    "primary_metric_std": 3.5355,
                    "total_elapsed_seconds": 1500.0,
                },
            },
            "condition_summaries": {
                "CondA": {"metrics": {"metric": 82.5}},
                "CondB": {"metrics": {"metric": 70.0}},
            },
            "condition_metrics": {
                "CondA": {"metrics": {"metric": 82.5}},
                "CondB": {"metrics": {"metric": 70.0}},
            },
            "total_conditions": 2,
        }

    def test_conditions_extracted(self):
        # Both condition names from the summary must appear in the registry.
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        assert "CondA" in reg.condition_names
        assert "CondB" in reg.condition_names
        assert len(reg.condition_names) == 2

    def test_per_seed_values(self):
        # "Cond/seed/metric" keys are parsed into per-condition per-seed dicts.
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        assert reg.conditions["CondA"].per_seed_values == {0: 80.0, 1: 85.0}
        assert reg.conditions["CondB"].per_seed_values == {0: 70.0}

    def test_condition_stats(self):
        # Aggregate stats are derived from the per-seed values above.
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        cond_a = reg.conditions["CondA"]
        assert cond_a.n_seeds == 2
        assert cond_a.mean == pytest.approx(82.5)
        assert cond_a.std == pytest.approx(3.5355, rel=0.01)

    def test_primary_metric(self):
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        assert reg.primary_metric == pytest.approx(82.5)
        assert reg.primary_metric_std == pytest.approx(3.5355)

    def test_all_values_registered(self):
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        # Core values must be verified
        assert reg.is_verified(80.0)
        assert reg.is_verified(85.0)
        assert reg.is_verified(70.0)
        assert reg.is_verified(82.5)
        assert reg.is_verified(3.5355, tolerance=0.01)

    def test_pairwise_differences(self):
        # Differences between condition means verify too (per the assertions
        # here), so "improvement of X points" claims can be checked.
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        diff = 82.5 - 70.0  # CondA.mean - CondB.mean
        assert reg.is_verified(diff)
        assert reg.is_verified(abs(diff))

    def test_fabricated_number_rejected(self):
        # Numbers that never occurred in the experiment must not verify.
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        assert not reg.is_verified(99.99)
        assert not reg.is_verified(60.51)

    def test_infra_keys_excluded(self):
        reg = VerifiedRegistry.from_experiment(self._make_summary())
        # total_elapsed_seconds goes to training_config, not values
        assert 1500.0 not in reg.values
        assert reg.training_config.get("total_elapsed_seconds") == 1500.0

    def test_with_refinement_log(self):
        # A refinement log alongside the summary must not break registration
        # of the summary's own values.
        summary = self._make_summary()
        ref_log = {
            "best_metric": 82.5,
            "best_version": "experiment_v1/",
            "iterations": [
                {
                    "version_dir": "experiment_v1/",
                    "metric": 82.5,
                    "sandbox": {"metrics": {"CondA/0/metric": 80.0}},
                }
            ],
        }
        reg = VerifiedRegistry.from_experiment(summary, ref_log)
        assert reg.is_verified(82.5)
# ---------------------------------------------------------------------------
# Integration tests — real artifact data
# ---------------------------------------------------------------------------
class TestRealArtifacts:
    """Test against actual pipeline output. Skipped if artifacts not present.

    NOTE(review): the literal metric values below are pinned to specific
    stored artifact runs; if those artifacts are regenerated, these
    constants must be re-checked against the new summaries.
    """

    def test_run_e57360_rl_exploration(self):
        """Run 38 (RL LACE) — 3 conditions, CartPole + Acrobot."""
        summary = _load_experiment_summary("e57360")
        ref_log = _load_refinement_log("e57360")
        reg = VerifiedRegistry.from_experiment(summary, ref_log)
        # Conditions that actually ran
        assert reg.verify_condition("DQN")
        assert reg.verify_condition("DQN+Abstraction")
        assert reg.verify_condition("DQN+RawCount")
        # Conditions that did NOT run (paper fabricated these)
        assert not reg.verify_condition("PPO")
        assert not reg.verify_condition("PPO+Abstraction")
        assert not reg.verify_condition("DQN+Autoencoder")
        # Real primary metric
        assert reg.is_verified(278.9333)
        assert reg.is_verified(146.4139, tolerance=0.01)
        # Fabricated number from paper (0.0 primary metric) — should NOT verify
        # unless 0.0 happens to be in the data for another reason
        # The paper claimed primary_metric=0.0 which is fabricated
        assert reg.primary_metric == pytest.approx(278.9333)

    def test_run_acbdfa_cnn_vs_ssm(self):
        """Run acbdfa (CTS) — ResNet vs S4D on CIFAR-100."""
        summary = _load_experiment_summary("acbdfa")
        reg = VerifiedRegistry.from_experiment(summary)
        # Real values from experiment
        assert reg.is_verified(69.99)
        assert reg.is_verified(69.93)
        assert reg.is_verified(58.66)
        assert reg.is_verified(2.75)
        # Primary metric
        assert reg.is_verified(66.1933, tolerance=0.01)

    def test_run_85fefc_contrastive_kd(self):
        """Run 85fefc (CRAFT) — contrastive KD."""
        summary = _load_experiment_summary("85fefc")
        ref_log = _load_refinement_log("85fefc")
        reg = VerifiedRegistry.from_experiment(summary, ref_log)
        # Should have conditions
        assert len(reg.condition_names) > 0
        # Primary metric should be registered
        assert reg.primary_metric is not None

    def test_run_8b4a1b_gard_lora(self):
        """Run 8b4a1b (GARD) — experiment failed, very few values."""
        summary = _load_experiment_summary("8b4a1b")
        reg = VerifiedRegistry.from_experiment(summary)
        # With empty metrics, registry should be sparse
        best_metrics = summary.get("best_run", {}).get("metrics", {})
        if not best_metrics:
            assert len(reg.values) == 0
# ---------------------------------------------------------------------------
# Unit tests — from_run_dir (merges multiple sources)
# ---------------------------------------------------------------------------
class TestFromRunDir:
    """VerifiedRegistry.from_run_dir: merging of summary files found on disk."""

    def _write_summary(self, path: Path, data: dict) -> None:
        # Helper: serialize a summary dict to disk, creating parent dirs.
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(data, indent=2), encoding="utf-8")

    def test_from_run_dir_merges_multiple_stage14(self, tmp_path: Path) -> None:
        """Two stage-14 dirs with different values → both present."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        # Stage-14 with CondA
        self._write_summary(
            run_dir / "stage-14" / "experiment_summary.json",
            {
                "best_run": {"metrics": {"CondA/0/metric": 80.0}},
                "condition_summaries": {"CondA": {"metrics": {"metric": 80.0}}},
                "metrics_summary": {},
            },
        )
        # Stage-14-v2 with CondB
        self._write_summary(
            run_dir / "stage-14-v2" / "experiment_summary.json",
            {
                "best_run": {"metrics": {"CondB/0/metric": 90.0}},
                "condition_summaries": {"CondB": {"metrics": {"metric": 90.0}}},
                "metrics_summary": {},
            },
        )
        # Default mode merges both stage dirs into one registry.
        reg = VerifiedRegistry.from_run_dir(run_dir)
        assert "CondA" in reg.condition_names
        assert "CondB" in reg.condition_names
        assert reg.is_verified(80.0)
        assert reg.is_verified(90.0)

    def test_from_run_dir_includes_best(self, tmp_path: Path) -> None:
        """experiment_summary_best.json values merged."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        # Only best summary at root level
        self._write_summary(
            run_dir / "experiment_summary_best.json",
            {
                "best_run": {"metrics": {"primary_metric": 0.95}},
                "condition_summaries": {"Proposed": {"metrics": {"acc": 0.95}}},
                "metrics_summary": {"acc": {"mean": 0.95, "min": 0.95, "max": 0.95}},
            },
        )
        reg = VerifiedRegistry.from_run_dir(run_dir)
        assert reg.is_verified(0.95)
        assert reg.is_verified(95.0)  # percentage variant
        assert "Proposed" in reg.condition_names

    def test_from_run_dir_empty_dir(self, tmp_path: Path) -> None:
        """Empty run dir → empty registry, no crash."""
        run_dir = tmp_path / "empty_run"
        run_dir.mkdir()
        reg = VerifiedRegistry.from_run_dir(run_dir)
        assert len(reg.values) == 0
        assert len(reg.condition_names) == 0

    # -----------------------------------------------------------------------
    # BUG-222: best_only mode — REFINE bypass prevention
    # -----------------------------------------------------------------------
    def test_best_only_uses_experiment_summary_best(self, tmp_path: Path) -> None:
        """best_only=True should use ONLY experiment_summary_best.json."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        # v1 (best): FeatureKD 74.52%
        self._write_summary(
            run_dir / "experiment_summary_best.json",
            {
                "best_run": {"metrics": {"FeatureKD/0/metric": 0.7452}},
                "condition_summaries": {"FeatureKD": {"metrics": {"metric": 0.7452}}},
                "metrics_summary": {"metric": {"mean": 0.7452}},
            },
        )
        # v3 (regressed): FeatureKD 69.30%
        self._write_summary(
            run_dir / "stage-14" / "experiment_summary.json",
            {
                "best_run": {"metrics": {"FeatureKD/0/metric": 0.6930}},
                "condition_summaries": {"FeatureKD": {"metrics": {"metric": 0.6930}}},
                "metrics_summary": {"metric": {"mean": 0.6930}},
            },
        )
        reg = VerifiedRegistry.from_run_dir(run_dir, best_only=True)
        # Should ONLY have v1 (best) data
        assert reg.is_verified(0.7452)
        assert reg.is_verified(74.52)  # percentage variant
        # Should NOT have v3 (regressed) data
        assert not reg.is_verified(0.6930)
        assert not reg.is_verified(69.30)

    def test_best_only_excludes_refinement_log(self, tmp_path: Path) -> None:
        """best_only=True should NOT merge refinement_log.json sandbox data."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        # Best summary
        self._write_summary(
            run_dir / "experiment_summary_best.json",
            {
                "best_run": {"metrics": {"primary_metric": 0.7452}},
                "condition_summaries": {"FeatureKD": {"metrics": {"metric": 0.7452}}},
                "metrics_summary": {"metric": {"mean": 0.7452}},
            },
        )
        # Refinement log with sandbox metrics from regressed iteration
        rl_dir = run_dir / "stage-13"
        rl_dir.mkdir(parents=True)
        (rl_dir / "refinement_log.json").write_text(json.dumps({
            "iterations": [
                {"sandbox": {"metrics": {"primary_metric": 0.6930, "best_metric": 0.6930}}}
            ]
        }), encoding="utf-8")
        reg = VerifiedRegistry.from_run_dir(run_dir, best_only=True)
        assert reg.is_verified(0.7452)
        assert not reg.is_verified(0.6930), "Refinement log sandbox values should NOT be in best_only registry"

    def test_best_only_falls_back_to_stage14(self, tmp_path: Path) -> None:
        """best_only=True without best.json falls back to stage-14/ (non-versioned)."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        self._write_summary(
            run_dir / "stage-14" / "experiment_summary.json",
            {
                "best_run": {"metrics": {"metric": 0.85}},
                "condition_summaries": {"Baseline": {"metrics": {"metric": 0.85}}},
                "metrics_summary": {"metric": {"mean": 0.85}},
            },
        )
        reg = VerifiedRegistry.from_run_dir(run_dir, best_only=True)
        assert reg.is_verified(0.85)
        assert "Baseline" in reg.condition_names

    def test_default_mode_still_merges_all(self, tmp_path: Path) -> None:
        """Default (best_only=False) preserves backward-compat merging."""
        run_dir = tmp_path / "run"
        run_dir.mkdir()
        self._write_summary(
            run_dir / "experiment_summary_best.json",
            {
                "best_run": {"metrics": {"FeatureKD/0/metric": 0.7452}},
                "condition_summaries": {"FeatureKD": {"metrics": {"metric": 0.7452}}},
                "metrics_summary": {},
            },
        )
        self._write_summary(
            run_dir / "stage-14" / "experiment_summary.json",
            {
                "best_run": {"metrics": {"FeatureKD/0/metric": 0.6930}},
                "condition_summaries": {"FeatureKD": {"metrics": {"metric": 0.6930}}},
                "metrics_summary": {},
            },
        )
        reg = VerifiedRegistry.from_run_dir(run_dir, best_only=False)
        # Both should be present in non-best_only mode
        assert reg.is_verified(0.7452)
        assert reg.is_verified(0.6930)
================================================
FILE: tests/test_web_crawler.py
================================================
"""Tests for researchclaw.web.crawler — WebCrawler."""
from __future__ import annotations
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.web.crawler import CrawlResult, WebCrawler
from researchclaw.web import check_url_ssrf
# ---------------------------------------------------------------------------
# CrawlResult dataclass
# ---------------------------------------------------------------------------
class TestCrawlResult:
    """has_content: per these cases, success alone is not enough — the
    markdown must be non-trivially long."""

    def test_has_content_true(self):
        result = CrawlResult(url="https://example.com", markdown="x" * 100, success=True)
        assert result.has_content

    def test_has_content_false_empty(self):
        result = CrawlResult(url="https://example.com", markdown="", success=True)
        assert not result.has_content

    def test_has_content_false_short(self):
        result = CrawlResult(url="https://example.com", markdown="too short", success=True)
        assert not result.has_content
# ---------------------------------------------------------------------------
# HTML → Markdown conversion (urllib fallback)
# ---------------------------------------------------------------------------
class TestHtmlToMarkdown:
    """Unit tests for WebCrawler._html_to_markdown (urllib fallback path).

    NOTE(review): the HTML fixture literals in this class were corrupted by a
    text-extraction pass that stripped all markup tags (leaving broken
    multi-line strings). They are reconstructed here from each test's
    assertions — confirm against the original file.
    """

    def test_strips_script_tags(self):
        # <script> bodies must not leak into the markdown output.
        html = "<html><body><script>alert('x')</script><p>Hello</p><p>World</p></body></html>"
        md = WebCrawler._html_to_markdown(html)
        assert "alert" not in md
        assert "Hello" in md
        assert "World" in md

    def test_converts_headings(self):
        html = "<h1>Title</h1><h2>Subtitle</h2><h3>Section</h3>"
        md = WebCrawler._html_to_markdown(html)
        assert "# Title" in md
        assert "## Subtitle" in md
        assert "### Section" in md

    def test_converts_paragraphs(self):
        html = "<p>First paragraph.</p><p>Second paragraph.</p>"
        md = WebCrawler._html_to_markdown(html)
        assert "First paragraph." in md
        assert "Second paragraph." in md

    def test_converts_links(self):
        html = '<a href="https://example.com">Click</a>'
        md = WebCrawler._html_to_markdown(html)
        assert "[Click](https://example.com)" in md

    def test_converts_list_items(self):
        html = "<ul><li>Item 1</li><li>Item 2</li></ul>"
        md = WebCrawler._html_to_markdown(html)
        assert "- Item 1" in md
        assert "- Item 2" in md

    def test_decodes_entities(self):
        html = "<p>A &amp; B &lt; C &gt; D</p>"
        md = WebCrawler._html_to_markdown(html)
        assert "A & B < C > D" in md

    def test_collapses_whitespace(self):
        # Runs of blank lines must be collapsed to at most one blank line.
        html = "<p>Hello</p>\n\n\n\n\n<p>World</p>"
        md = WebCrawler._html_to_markdown(html)
        assert "\n\n\n" not in md
# ---------------------------------------------------------------------------
# urllib fallback crawl
# ---------------------------------------------------------------------------
class TestCrawlUrllibFallback:
    """urllib fallback path of WebCrawler.

    NOTE(review): the HTML fixture literals here were mangled by a
    text-extraction pass that stripped markup tags; they are reconstructed
    from each test's assertions — confirm against the original file.
    """

    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_urllib_success(self, mock_urlopen):
        mock_resp = MagicMock()
        # Title must parse out as "Test" and the body text must survive.
        mock_resp.read.return_value = (
            b"<html><head><title>Test</title></head>"
            b"<body><p>Content here</p></body></html>"
        )
        mock_resp.headers = {"Content-Type": "text/html; charset=utf-8"}
        mock_urlopen.return_value = mock_resp
        crawler = WebCrawler()
        import time
        t0 = time.monotonic()
        result = crawler._crawl_with_urllib("https://example.com", t0)
        assert result.success
        assert result.title == "Test"
        assert "Content here" in result.markdown

    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_urllib_truncation(self, mock_urlopen):
        mock_resp = MagicMock()
        # 60k characters of body text against a 1000-char content cap.
        long_content = "<html><body><p>" + "x" * 60000 + "</p></body></html>"
        mock_resp.read.return_value = long_content.encode()
        mock_resp.headers = {"Content-Type": "text/html"}
        mock_urlopen.return_value = mock_resp
        crawler = WebCrawler(max_content_length=1000)
        import time
        t0 = time.monotonic()
        result = crawler._crawl_with_urllib("https://example.com", t0)
        assert len(result.markdown) <= 1100  # 1000 + truncation notice
# ---------------------------------------------------------------------------
# Sync crawl (goes through crawl4ai → urllib fallback chain)
# ---------------------------------------------------------------------------
class TestCrawlSync:
    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_sync_falls_back_to_urllib(self, mock_urlopen):
        """crawl_sync tries crawl4ai, then falls back to urllib."""
        # NOTE(review): the HTML fixture literal was mangled by text
        # extraction (tags stripped); reconstructed — confirm against the
        # original file.
        mock_resp = MagicMock()
        mock_resp.read.return_value = (
            b"<html><head><title>Sync</title></head>"
            b"<body><p>Works via urllib</p></body></html>"
        )
        mock_resp.headers = {"Content-Type": "text/html"}
        mock_urlopen.return_value = mock_resp
        crawler = WebCrawler()
        # Crawl4AI may or may not work in test env (no browser),
        # but urllib fallback should always work
        result = crawler.crawl_sync("https://example.com")
        assert result.success or result.error  # either crawl4ai or urllib
# ---------------------------------------------------------------------------
# Async crawl
# ---------------------------------------------------------------------------
class TestCrawlAsync:
    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_async_urllib_fallback(self, mock_urlopen):
        """When crawl4ai's browser isn't set up, async crawl falls back to urllib."""
        # NOTE(review): the HTML fixture literal was mangled by text
        # extraction (tags stripped); reconstructed — confirm against the
        # original file.
        mock_resp = MagicMock()
        mock_resp.read.return_value = (
            b"<html><head><title>Async</title></head>"
            b"<body><p>Works</p></body></html>"
        )
        mock_resp.headers = {"Content-Type": "text/html"}
        mock_urlopen.return_value = mock_resp
        crawler = WebCrawler()
        result = asyncio.run(crawler.crawl("https://example.com"))
        # Should succeed via either crawl4ai or urllib fallback
        assert isinstance(result, CrawlResult)
# ---------------------------------------------------------------------------
# SSRF validation: check_url_ssrf
# ---------------------------------------------------------------------------
class TestCheckUrlSsrf:
    """check_url_ssrf: None for allowed URLs, an error string for blocked ones."""

    def test_http_allowed(self):
        assert check_url_ssrf("http://example.com") is None

    def test_https_allowed(self):
        assert check_url_ssrf("https://arxiv.org/abs/2301.00001") is None

    def test_rejects_file_scheme(self):
        error = check_url_ssrf("file:///etc/passwd")
        assert error is not None
        assert "scheme" in error.lower()

    def test_rejects_ftp_scheme(self):
        assert check_url_ssrf("ftp://server/file") is not None

    def test_rejects_localhost(self):
        error = check_url_ssrf("http://localhost:8080")
        assert error is not None
        lowered = error.lower()
        assert any(word in lowered for word in ("internal", "private", "blocked"))

    def test_rejects_127(self):
        # Loopback address with a service port.
        assert check_url_ssrf("http://127.0.0.1:6379") is not None

    def test_rejects_10_range(self):
        assert check_url_ssrf("http://10.0.0.1") is not None

    def test_rejects_172_range(self):
        assert check_url_ssrf("http://172.16.0.1") is not None

    def test_rejects_192_range(self):
        assert check_url_ssrf("http://192.168.1.1") is not None

    def test_rejects_aws_metadata(self):
        # Cloud metadata endpoint — the classic SSRF target.
        assert check_url_ssrf("http://169.254.169.254/latest/meta-data") is not None

    def test_rejects_empty_hostname(self):
        assert check_url_ssrf("http://") is not None
# ---------------------------------------------------------------------------
# Crawler SSRF integration
# ---------------------------------------------------------------------------
class TestCrawlerSsrfIntegration:
    """The crawler must short-circuit SSRF-blocked URLs before any network call."""

    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_sync_rejects_private_url(self, mock_urlopen):
        result = WebCrawler().crawl_sync("http://127.0.0.1:8080")
        assert not result.success
        assert result.error
        # urlopen must never be reached for a blocked address.
        mock_urlopen.assert_not_called()

    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_sync_rejects_file_scheme(self, mock_urlopen):
        result = WebCrawler().crawl_sync("file:///etc/passwd")
        assert not result.success
        assert "scheme" in result.error.lower()
        mock_urlopen.assert_not_called()

    @patch("researchclaw.web.crawler.urlopen")
    def test_crawl_async_rejects_private_url(self, mock_urlopen):
        result = asyncio.run(WebCrawler().crawl("http://10.0.0.1:9200"))
        assert not result.success
        assert result.error
        mock_urlopen.assert_not_called()
================================================
FILE: tests/test_web_integration.py
================================================
"""Integration tests for researchclaw.web — WebSearchAgent end-to-end."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.web.agent import WebSearchAgent, WebSearchAgentResult
from researchclaw.web.crawler import CrawlResult
from researchclaw.web.search import SearchResult, WebSearchResponse
from researchclaw.web.scholar import ScholarPaper
# ---------------------------------------------------------------------------
# WebSearchAgentResult
# ---------------------------------------------------------------------------
class TestWebSearchAgentResult:
    """Aggregation and context-string formatting of WebSearchAgentResult."""

    def test_total_results(self):
        result = WebSearchAgentResult(
            topic="test",
            web_results=[SearchResult(title="A", url="u1")],
            scholar_papers=[ScholarPaper(title="B")],
        )
        # One web hit plus one scholar paper.
        assert result.total_results == 2

    def test_to_context_string_empty(self):
        context = WebSearchAgentResult(topic="test").to_context_string()
        assert isinstance(context, str)

    def test_to_context_string_with_results(self):
        hit = SearchResult(
            title="KD Survey",
            url="https://example.com/kd",
            snippet="A comprehensive survey on KD",
            source="tavily",
        )
        paper = ScholarPaper(
            title="Distilling Knowledge",
            authors=["Hinton", "Vinyals", "Dean"],
            year=2015,
            citation_count=5000,
            abstract="We propose a technique for model compression.",
        )
        result = WebSearchAgentResult(
            topic="knowledge distillation",
            web_results=[hit],
            scholar_papers=[paper],
            search_answer="KD is a model compression technique.",
        )
        context = result.to_context_string()
        # Every section of the result should surface in the rendered context.
        for fragment in ("AI Search Summary", "KD Survey", "Distilling Knowledge", "Hinton"):
            assert fragment in context

    def test_to_context_string_truncation(self):
        hits = [
            SearchResult(title=f"R{i}", url=f"u{i}", snippet="x" * 1000)
            for i in range(50)
        ]
        result = WebSearchAgentResult(topic="test", web_results=hits)
        context = result.to_context_string(max_length=5000)
        # A small overshoot beyond max_length is tolerated.
        assert len(context) <= 5100

    def test_to_dict(self):
        result = WebSearchAgentResult(
            topic="test",
            web_results=[SearchResult(title="A", url="u1")],
        )
        payload = result.to_dict()
        assert payload["topic"] == "test"
        assert payload["web_results_count"] == 1

    def test_to_context_with_crawled_pages(self):
        page = CrawlResult(
            url="https://blog.example.com",
            markdown="# Great Blog Post\n\nContent " * 50,
            title="Great Blog Post",
            success=True,
        )
        result = WebSearchAgentResult(topic="test", crawled_pages=[page])
        context = result.to_context_string()
        assert "Crawled Page Content" in context
        assert "Great Blog Post" in context
# ---------------------------------------------------------------------------
# WebSearchAgent — orchestration
# ---------------------------------------------------------------------------
class TestWebSearchAgent:
    """Orchestration-level tests for WebSearchAgent; all network access is mocked."""

    def test_generate_queries(self):
        queries = WebSearchAgent._generate_queries("knowledge distillation")
        assert len(queries) == 3
        assert "knowledge distillation" in queries
        assert any("survey" in q for q in queries)
        assert any("benchmark" in q for q in queries)

    def test_select_urls_to_crawl(self):
        agent = WebSearchAgent(max_crawl_urls=3)
        result = WebSearchAgentResult(
            topic="test",
            web_results=[
                SearchResult(title=f"R{i}", url=f"https://ex.com/{i}")
                for i in range(10)
            ],
        )
        urls = agent._select_urls_to_crawl(result)
        # Selection honors the max_crawl_urls cap and only keeps HTTPS URLs here.
        assert len(urls) <= 3
        assert all(url.startswith("https://") for url in urls)

    def test_select_urls_skips_pdf(self):
        agent = WebSearchAgent(max_crawl_urls=5)
        result = WebSearchAgentResult(
            topic="test",
            web_results=[
                SearchResult(title="Paper", url="https://ex.com/paper.pdf"),
                SearchResult(title="Blog", url="https://ex.com/blog"),
            ],
        )
        urls = agent._select_urls_to_crawl(result)
        # PDFs go to the PDF extractor path, not the HTML crawler.
        assert "https://ex.com/paper.pdf" not in urls
        assert "https://ex.com/blog" in urls

    def test_find_pdf_urls(self):
        result = WebSearchAgentResult(
            topic="test",
            web_results=[
                SearchResult(title="P1", url="https://ex.com/a.pdf"),
                SearchResult(title="P2", url="https://ex.com/b.html"),
                SearchResult(title="P3", url="https://ex.com/c.pdf"),
            ],
        )
        pdfs = WebSearchAgent._find_pdf_urls(result)
        assert len(pdfs) == 2
        assert all(u.endswith(".pdf") for u in pdfs)

    @patch("researchclaw.web.search.urlopen")
    @patch("researchclaw.web.scholar.scholarly")
    def test_search_and_extract_minimal(self, mock_scholarly, mock_urlopen):
        """End-to-end test with mocked HTTP — DuckDuckGo + mocked Scholar."""
        mock_resp = MagicMock()
        # Minimal HTML payload; the assertions below do not depend on parsed results.
        mock_resp.read.return_value = (
            b"<html><body>"
            b"<a>Paper About KD</a>"
            b"<p>A study on knowledge distillation</p>"
            b"</body></html>"
        )
        mock_urlopen.return_value = mock_resp
        # Mock scholarly to return empty (avoid network calls)
        mock_scholarly.search_pubs.return_value = iter([])
        agent = WebSearchAgent(
            enable_scholar=True,
            enable_crawling=False,
            enable_pdf=False,
        )
        result = agent.search_and_extract("knowledge distillation")
        assert result.topic == "knowledge distillation"
        assert result.elapsed_seconds > 0

    @patch("researchclaw.web.search.urlopen")
    @patch("researchclaw.web.scholar.scholarly")
    @patch("researchclaw.web.crawler.urlopen")
    def test_search_and_extract_with_crawling(
        self, mock_crawl_urlopen, mock_scholarly, mock_search_urlopen
    ):
        """Test with crawling enabled."""
        mock_search_resp = MagicMock()
        mock_search_resp.read.return_value = (
            b"<html><body><a>KD Tutorial</a><p>A tutorial</p></body></html>"
        )
        mock_search_urlopen.return_value = mock_search_resp
        mock_crawl_resp = MagicMock()
        # BUGFIX: the previous bytes literal was split across physical lines and
        # did not parse; rebuild it as a single well-formed HTML payload.
        mock_crawl_resp.read.return_value = (
            b"<html><head><title>KD Tutorial</title></head><body><p>"
            + b"Tutorial content about knowledge distillation. " * 20
            + b"</p></body></html>"
        )
        mock_crawl_resp.headers = {"Content-Type": "text/html"}
        mock_crawl_urlopen.return_value = mock_crawl_resp
        mock_scholarly.search_pubs.return_value = iter([])
        agent = WebSearchAgent(
            enable_scholar=False,
            enable_crawling=True,
            enable_pdf=False,
            max_crawl_urls=2,
        )
        result = agent.search_and_extract("knowledge distillation")
        assert result.elapsed_seconds > 0
# ---------------------------------------------------------------------------
# Config integration
# ---------------------------------------------------------------------------
class TestWebSearchConfig:
    """WebSearchConfig defaults and its registration as an RCConfig field."""

    def test_default_config(self):
        from researchclaw.config import WebSearchConfig

        defaults = WebSearchConfig()
        assert defaults.enabled is True
        assert defaults.max_web_results == 10
        assert defaults.enable_scholar is True

    def test_config_in_rcconfig(self):
        import dataclasses

        from researchclaw.config import RCConfig

        names = {field.name for field in dataclasses.fields(RCConfig)}
        assert "web_search" in names
================================================
FILE: tests/test_web_pdf_extractor.py
================================================
"""Tests for researchclaw.web.pdf_extractor — PDFExtractor."""
from __future__ import annotations
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.web.pdf_extractor import PDFContent, PDFExtractor
# ---------------------------------------------------------------------------
# PDFContent dataclass
# ---------------------------------------------------------------------------
class TestPDFContent:
    """has_content should require a minimum amount of extracted text."""

    def test_has_content_true(self):
        content = PDFContent(path="test.pdf", text="x" * 200, success=True)
        assert content.has_content

    def test_has_content_false_empty(self):
        content = PDFContent(path="test.pdf", text="", success=True)
        assert not content.has_content

    def test_has_content_false_short(self):
        content = PDFContent(path="test.pdf", text="short", success=True)
        assert not content.has_content
# ---------------------------------------------------------------------------
# PDFExtractor
# ---------------------------------------------------------------------------
class TestPDFExtractor:
    """PDFExtractor backend selection and text-analysis heuristics."""

    def test_backend_detection(self):
        assert PDFExtractor().backend == "pymupdf"  # PyMuPDF is now installed

    def test_extract_nonexistent_file(self, tmp_path):
        outcome = PDFExtractor().extract(tmp_path / "does_not_exist.pdf")
        assert not outcome.success or "not found" in outcome.error.lower() or outcome.error

    def test_extract_abstract_pattern(self):
        text = """
Some header text
Abstract
This paper presents a novel approach to knowledge distillation
that achieves state-of-the-art results on ImageNet.
1 Introduction
We begin by motivating our approach...
"""
        assert "knowledge distillation" in PDFExtractor._extract_abstract(text)

    def test_extract_abstract_no_match(self):
        extracted = PDFExtractor._extract_abstract("No abstract section here, just random text.")
        assert extracted == ""

    def test_detect_sections(self):
        text = """
1. Introduction
This is the introduction section with some content.
2. Related Work
This covers prior work in the field.
3. Method
Our proposed approach works as follows.
4. Experiments
We evaluate on several benchmarks.
"""
        sections = PDFExtractor._detect_sections(text)
        assert len(sections) >= 3
        headings = [entry["heading"] for entry in sections]
        assert any("Introduction" in heading for heading in headings)
        assert any("Related" in heading or "Method" in heading for heading in headings)

    def test_detect_sections_empty(self):
        assert PDFExtractor._detect_sections("No numbered sections here at all.") == []

    @patch("researchclaw.web.pdf_extractor.urlopen")
    def test_extract_from_url_failure(self, mock_urlopen):
        mock_urlopen.side_effect = Exception("404 Not Found")
        outcome = PDFExtractor().extract_from_url("https://example.com/paper.pdf")
        assert not outcome.success or outcome.error
================================================
FILE: tests/test_web_platform.py
================================================
"""Tests for Agent A — Web platform and user interface.
Covers: FastAPI routes, WebSocket, intents, dashboard collector, wizard, voice commands.
All tests run without external services (mocked LLM, mocked Whisper).
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
import tempfile
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
class TestServerConfig:
    """ServerConfig / DashboardConfig defaults, dict parsing, and RCConfig wiring."""

    def test_server_config_defaults(self) -> None:
        from researchclaw.config import ServerConfig

        defaults = ServerConfig()
        assert defaults.enabled is False
        assert defaults.host == "0.0.0.0"
        assert defaults.port == 8080
        assert defaults.cors_origins == ("*",)
        assert defaults.auth_token == ""
        assert defaults.voice_enabled is False

    def test_dashboard_config_defaults(self) -> None:
        from researchclaw.config import DashboardConfig

        defaults = DashboardConfig()
        assert defaults.enabled is True
        assert defaults.refresh_interval_sec == 5
        assert defaults.max_log_lines == 1000

    def test_parse_server_config(self) -> None:
        from researchclaw.config import _parse_server_config

        parsed = _parse_server_config(
            {"enabled": True, "host": "127.0.0.1", "port": 9090, "auth_token": "secret123"}
        )
        assert parsed.enabled is True
        assert parsed.host == "127.0.0.1"
        assert parsed.port == 9090
        assert parsed.auth_token == "secret123"

    def test_parse_server_config_empty(self) -> None:
        from researchclaw.config import _parse_server_config

        parsed = _parse_server_config({})
        assert parsed.enabled is False
        assert parsed.port == 8080

    def test_parse_dashboard_config(self) -> None:
        from researchclaw.config import _parse_dashboard_config

        parsed = _parse_dashboard_config({"refresh_interval_sec": 10, "max_log_lines": 500})
        assert parsed.refresh_interval_sec == 10
        assert parsed.max_log_lines == 500

    def test_rcconfig_has_server_and_dashboard(self) -> None:
        from researchclaw.config import DashboardConfig, RCConfig, ServerConfig

        # Smallest config dict RCConfig.from_dict accepts, plus the two new sections.
        raw = {
            "project": {"name": "test"},
            "research": {"topic": "test topic"},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "console"},
            "knowledge_base": {"root": "knowledge"},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost",
                "api_key_env": "TEST_KEY",
            },
            "server": {"enabled": True, "port": 9999},
            "dashboard": {"refresh_interval_sec": 3},
        }
        parsed = RCConfig.from_dict(raw, check_paths=False)
        assert isinstance(parsed.server, ServerConfig)
        assert parsed.server.enabled is True
        assert parsed.server.port == 9999
        assert isinstance(parsed.dashboard, DashboardConfig)
        assert parsed.dashboard.refresh_interval_sec == 3
# ---------------------------------------------------------------------------
# CLI tests
# ---------------------------------------------------------------------------
class TestCLI:
    """Verify the new CLI subcommands are registered with the argument parser."""

    def _assert_help_exits_zero(self, subcommand: str) -> None:
        """argparse exits with code 0 when --help is requested for a known subcommand."""
        from researchclaw.cli import main

        with pytest.raises(SystemExit) as excinfo:
            main([subcommand, "--help"])
        assert excinfo.value.code == 0

    def test_serve_subcommand_exists(self) -> None:
        self._assert_help_exits_zero("serve")

    def test_dashboard_subcommand_exists(self) -> None:
        self._assert_help_exits_zero("dashboard")

    def test_wizard_subcommand_exists(self) -> None:
        self._assert_help_exits_zero("wizard")
# ---------------------------------------------------------------------------
# Intent classification tests
# ---------------------------------------------------------------------------
class TestIntents:
    """Intent classification over English and Chinese utterances."""

    def _assert_intent(self, message: str, expected_name: str) -> None:
        """Classify *message* and assert it maps to the named Intent member."""
        from researchclaw.server.dialog.intents import Intent, classify_intent

        intent, _ = classify_intent(message)
        assert intent == getattr(Intent, expected_name)

    def test_help_intent(self) -> None:
        self._assert_intent("help", "HELP")

    def test_status_intent(self) -> None:
        self._assert_intent("What stage are we at?", "CHECK_STATUS")

    def test_start_intent(self) -> None:
        self._assert_intent("Start the pipeline", "START_PIPELINE")

    def test_topic_intent(self) -> None:
        self._assert_intent("Help me find a research direction", "TOPIC_SELECTION")

    def test_results_intent(self) -> None:
        self._assert_intent("What are the results?", "DISCUSS_RESULTS")

    def test_config_intent(self) -> None:
        self._assert_intent("Change the learning rate to 0.001", "MODIFY_CONFIG")

    def test_paper_intent(self) -> None:
        self._assert_intent("Edit the abstract", "EDIT_PAPER")

    def test_general_intent(self) -> None:
        self._assert_intent("Hello there", "GENERAL_CHAT")

    def test_chinese_status(self) -> None:
        self._assert_intent("现在到哪一步了", "CHECK_STATUS")

    def test_chinese_start(self) -> None:
        self._assert_intent("开始跑实验", "START_PIPELINE")

    def test_empty_message(self) -> None:
        from researchclaw.server.dialog.intents import Intent, classify_intent

        intent, confidence = classify_intent("")
        # Empty input falls back to general chat with zero confidence.
        assert intent == Intent.GENERAL_CHAT
        assert confidence == 0.0
# ---------------------------------------------------------------------------
# Session management tests
# ---------------------------------------------------------------------------
class TestSession:
    """ChatSession / SessionManager lifecycle, history limits, persistence."""

    def test_session_create(self) -> None:
        from researchclaw.server.dialog.session import SessionManager

        chat = SessionManager().get_or_create("client1")
        assert chat.client_id == "client1"
        assert len(chat.history) == 0

    def test_session_add_message(self) -> None:
        from researchclaw.server.dialog.session import SessionManager

        chat = SessionManager().get_or_create("client1")
        chat.add_message("user", "Hello")
        chat.add_message("assistant", "Hi!")
        assert len(chat.history) == 2
        assert chat.history[0].role == "user"

    def test_session_context(self) -> None:
        from researchclaw.server.dialog.session import SessionManager

        chat = SessionManager().get_or_create("client1")
        for index in range(20):
            chat.add_message("user", f"msg {index}")
        # get_context returns only the requested tail of the history.
        assert len(chat.get_context(last_n=5)) == 5

    def test_session_max_history(self) -> None:
        from researchclaw.server.dialog.session import ChatSession

        chat = ChatSession(client_id="test")
        for index in range(100):
            chat.add_message("user", f"msg {index}")
        assert len(chat.history) <= chat.MAX_HISTORY

    def test_session_persistence(self) -> None:
        from researchclaw.server.dialog.session import SessionManager

        with tempfile.TemporaryDirectory() as tmpdir:
            writer = SessionManager(persist_dir=tmpdir)
            chat = writer.get_or_create("persist-test")
            chat.add_message("user", "saved message")
            writer.save("persist-test")
            # A fresh manager must be able to re-hydrate the session from disk.
            restored = SessionManager(persist_dir=tmpdir).load("persist-test")
            assert restored is not None
            assert len(restored.history) == 1
            assert restored.history[0].content == "saved message"
# ---------------------------------------------------------------------------
# Dashboard collector tests
# ---------------------------------------------------------------------------
class TestDashboardCollector:
    """DashboardCollector reads run state out of an artifacts/ directory tree."""

    def test_collect_empty_dir(self) -> None:
        from researchclaw.dashboard.collector import DashboardCollector

        with tempfile.TemporaryDirectory() as tmpdir:
            assert DashboardCollector(artifacts_dir=tmpdir).collect_all() == []

    def test_collect_run_with_checkpoint(self) -> None:
        from researchclaw.dashboard.collector import DashboardCollector

        with tempfile.TemporaryDirectory() as tmpdir:
            run_dir = Path(tmpdir) / "rc-20260315-abc123"
            run_dir.mkdir()
            checkpoint = {"stage": 5, "stage_name": "LITERATURE_SCREEN", "status": "running"}
            (run_dir / "checkpoint.json").write_text(json.dumps(checkpoint))
            snapshots = DashboardCollector(artifacts_dir=tmpdir).collect_all()
            assert len(snapshots) == 1
            assert snapshots[0].current_stage == 5
            assert snapshots[0].current_stage_name == "LITERATURE_SCREEN"

    def test_collect_run_active_heartbeat(self) -> None:
        from researchclaw.dashboard.collector import DashboardCollector

        with tempfile.TemporaryDirectory() as tmpdir:
            run_dir = Path(tmpdir) / "rc-20260315-test01"
            run_dir.mkdir()
            # A just-written heartbeat marks the run as active.
            (run_dir / "heartbeat.json").write_text(json.dumps({"timestamp": time.time()}))
            snapshots = DashboardCollector(artifacts_dir=tmpdir).collect_all()
            assert len(snapshots) == 1
            assert snapshots[0].is_active is True

    def test_collect_run_stale_heartbeat(self) -> None:
        from researchclaw.dashboard.collector import DashboardCollector

        with tempfile.TemporaryDirectory() as tmpdir:
            run_dir = Path(tmpdir) / "rc-20260315-stale1"
            run_dir.mkdir()
            # A heartbeat two minutes old must be treated as stale.
            stale = {"timestamp": time.time() - 120}
            (run_dir / "heartbeat.json").write_text(json.dumps(stale))
            snapshots = DashboardCollector(artifacts_dir=tmpdir).collect_all()
            assert snapshots[0].is_active is False

    def test_collect_stage_directories(self) -> None:
        from researchclaw.dashboard.collector import DashboardCollector

        with tempfile.TemporaryDirectory() as tmpdir:
            run_dir = Path(tmpdir) / "rc-20260315-stages"
            run_dir.mkdir()
            for stage_name in ("stage-01", "stage-02", "stage-03"):
                (run_dir / stage_name).mkdir()
            snapshots = DashboardCollector(artifacts_dir=tmpdir).collect_all()
            assert len(snapshots[0].stages_completed) == 3

    def test_collect_metrics(self) -> None:
        from researchclaw.dashboard.collector import DashboardCollector

        with tempfile.TemporaryDirectory() as tmpdir:
            run_dir = Path(tmpdir) / "rc-20260315-metric"
            run_dir.mkdir()
            payload = {"accuracy": 0.85, "loss": 0.12}
            (run_dir / "results.json").write_text(json.dumps(payload))
            snapshots = DashboardCollector(artifacts_dir=tmpdir).collect_all()
            assert snapshots[0].metrics["accuracy"] == 0.85

    def test_snapshot_to_dict(self) -> None:
        from researchclaw.dashboard.collector import RunSnapshot

        serialized = RunSnapshot(run_id="test-1", path="/tmp/test").to_dict()
        assert serialized["run_id"] == "test-1"
        assert "current_stage" in serialized
# ---------------------------------------------------------------------------
# Metrics tests
# ---------------------------------------------------------------------------
class TestMetrics:
    """aggregate_metrics and extract_training_curve helpers."""

    def test_aggregate_empty(self) -> None:
        from researchclaw.dashboard.metrics import aggregate_metrics

        assert aggregate_metrics([])["total_runs"] == 0

    def test_aggregate_mixed(self) -> None:
        from researchclaw.dashboard.metrics import aggregate_metrics

        summary = aggregate_metrics([
            {"is_active": True, "status": "running", "current_stage": 10},
            {"is_active": False, "status": "completed", "current_stage": 23},
            {"is_active": False, "status": "failed", "current_stage": 5},
        ])
        assert summary["total_runs"] == 3
        assert summary["active_runs"] == 1
        assert summary["completed_runs"] == 1
        assert summary["failed_runs"] == 1

    def test_extract_training_curve(self) -> None:
        from researchclaw.dashboard.metrics import extract_training_curve

        curve = extract_training_curve({
            "training_log": [
                {"epoch": 1, "loss": 0.5, "accuracy": 0.7},
                {"epoch": 2, "loss": 0.3, "accuracy": 0.85},
            ]
        })
        assert len(curve) == 2
        assert curve[1]["loss"] == 0.3
# ---------------------------------------------------------------------------
# Voice command tests
# ---------------------------------------------------------------------------
class TestVoiceCommands:
    """parse_voice_input keyword matching, English and Chinese."""

    def _parsed(self, utterance):
        """Run the parser on one utterance and return the parse result."""
        from researchclaw.voice.commands import parse_voice_input

        return parse_voice_input(utterance)

    def test_start_command(self) -> None:
        from researchclaw.voice.commands import VoiceCommand

        assert self._parsed("start experiment").command == VoiceCommand.START

    def test_stop_command(self) -> None:
        from researchclaw.voice.commands import VoiceCommand

        assert self._parsed("stop").command == VoiceCommand.STOP

    def test_chinese_start(self) -> None:
        from researchclaw.voice.commands import VoiceCommand

        assert self._parsed("开始实验").command == VoiceCommand.START

    def test_chinese_pause(self) -> None:
        from researchclaw.voice.commands import VoiceCommand

        assert self._parsed("暂停").command == VoiceCommand.PAUSE

    def test_not_a_command(self) -> None:
        from researchclaw.voice.commands import VoiceCommand

        assert self._parsed("What about the neural network?").command == VoiceCommand.NONE

    def test_status_command(self) -> None:
        from researchclaw.voice.commands import VoiceCommand

        assert self._parsed("查看进度").command == VoiceCommand.STATUS
# ---------------------------------------------------------------------------
# Wizard tests
# ---------------------------------------------------------------------------
class TestWizard:
    """Wizard templates, web-mode answer handling, and environment detection."""

    def test_list_templates(self) -> None:
        from researchclaw.wizard.templates import list_templates

        available = list_templates()
        assert len(available) >= 3
        template_names = [entry["name"] for entry in available]
        assert "quick-demo" in template_names
        assert "standard-cv" in template_names

    def test_get_template(self) -> None:
        from researchclaw.wizard.templates import get_template

        template = get_template("quick-demo")
        assert template is not None
        assert template["experiment.mode"] == "simulated"

    def test_get_template_missing(self) -> None:
        from researchclaw.wizard.templates import get_template

        assert get_template("nonexistent") is None

    def test_wizard_web_mode(self) -> None:
        from researchclaw.wizard.quickstart import QuickStartWizard

        answers = [
            {"key": "project_name", "value": "test-proj"},
            {"key": "topic", "value": "neural scaling laws"},
            {"key": "mode", "value": "docker"},
        ]
        generated = QuickStartWizard().run_web(answers)
        assert generated.get("project", {}).get("name") == "test-proj"
        assert generated.get("research", {}).get("topic") == "neural scaling laws"

    def test_environment_detection(self) -> None:
        from researchclaw.wizard.validator import detect_environment

        report = detect_environment()
        # This test runs under Python, so the detector must at least see Python.
        assert report.has_python is True
        assert report.python_version != ""
        serialized = report.to_dict()
        assert "has_gpu" in serialized
        assert "recommendations" in serialized
# ---------------------------------------------------------------------------
# WebSocket events tests
# ---------------------------------------------------------------------------
class TestEvents:
    """Event (de)serialization for the WebSocket channel."""

    def test_event_serialization(self) -> None:
        from researchclaw.server.websocket.events import Event, EventType

        encoded = Event(type=EventType.STAGE_COMPLETE, data={"stage": 5}).to_json()
        decoded = json.loads(encoded)
        assert decoded["type"] == "stage_complete"
        assert decoded["data"]["stage"] == 5

    def test_event_deserialization(self) -> None:
        from researchclaw.server.websocket.events import Event, EventType

        wire = json.dumps(
            {"type": "heartbeat", "data": {"active_clients": 3}, "timestamp": 1234567890.0}
        )
        event = Event.from_json(wire)
        assert event.type == EventType.HEARTBEAT
        assert event.data["active_clients"] == 3

    def test_event_types_enum(self) -> None:
        from researchclaw.server.websocket.events import EventType

        assert EventType.CONNECTED.value == "connected"
        assert EventType.STAGE_START.value == "stage_start"
        assert EventType.CHAT_RESPONSE.value == "chat_response"
# ---------------------------------------------------------------------------
# Dialog router tests
# ---------------------------------------------------------------------------
class TestDialogRouter:
    """route_message async entry point of the dialog subsystem."""

    @pytest.mark.asyncio
    async def test_route_help_message(self) -> None:
        from researchclaw.server.dialog.router import route_message

        reply = await route_message("help", "test-client")
        assert "help" in reply.lower() or "I can" in reply

    @pytest.mark.asyncio
    async def test_route_json_message(self) -> None:
        from researchclaw.server.dialog.router import route_message

        envelope = json.dumps({"message": "help me"})
        reply = await route_message(envelope, "test-client-2")
        assert isinstance(reply, str)
        assert len(reply) > 0

    @pytest.mark.asyncio
    async def test_route_status_message(self) -> None:
        from researchclaw.server.dialog.router import route_message

        reply = await route_message("What's the current progress?", "test-client-3")
        assert isinstance(reply, str)
# ---------------------------------------------------------------------------
# FastAPI app tests (requires fastapi + httpx)
# ---------------------------------------------------------------------------
class TestFastAPIApp:
    """HTTP-level tests against the FastAPI app (skipped without fastapi/httpx)."""

    @pytest.fixture
    def _skip_if_no_fastapi(self) -> None:
        try:
            import fastapi  # noqa: F401
            import httpx  # noqa: F401
        except ImportError:
            pytest.skip("fastapi/httpx not installed")

    @pytest.fixture
    def app(self, _skip_if_no_fastapi: None) -> object:
        from researchclaw.config import RCConfig

        raw = {
            "project": {"name": "test"},
            "research": {"topic": "test"},
            "runtime": {"timezone": "UTC"},
            "notifications": {"channel": "console"},
            "knowledge_base": {"root": "knowledge"},
            "llm": {
                "provider": "openai-compatible",
                "base_url": "http://localhost",
                "api_key_env": "TEST",
            },
        }
        config = RCConfig.from_dict(raw, check_paths=False)
        from researchclaw.server.app import create_app

        return create_app(config)

    async def _request(self, app: object, method: str, path: str):
        """Issue one in-process request against the ASGI app and return the response."""
        from httpx import ASGITransport, AsyncClient

        transport = ASGITransport(app=app)  # type: ignore[arg-type]
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.request(method, path)

    @pytest.mark.asyncio
    async def test_health_endpoint(self, app: object) -> None:
        resp = await self._request(app, "GET", "/api/health")
        assert resp.status_code == 200
        assert resp.json()["status"] == "ok"

    @pytest.mark.asyncio
    async def test_config_endpoint(self, app: object) -> None:
        resp = await self._request(app, "GET", "/api/config")
        assert resp.status_code == 200
        assert resp.json()["project"] == "test"

    @pytest.mark.asyncio
    async def test_pipeline_status_idle(self, app: object) -> None:
        resp = await self._request(app, "GET", "/api/pipeline/status")
        assert resp.status_code == 200
        assert resp.json()["status"] == "idle"

    @pytest.mark.asyncio
    async def test_pipeline_stages(self, app: object) -> None:
        resp = await self._request(app, "GET", "/api/pipeline/stages")
        assert resp.status_code == 200
        # The pipeline exposes all 23 stages.
        assert len(resp.json()["stages"]) == 23

    @pytest.mark.asyncio
    async def test_runs_list(self, app: object) -> None:
        resp = await self._request(app, "GET", "/api/runs")
        assert resp.status_code == 200
        assert "runs" in resp.json()

    @pytest.mark.asyncio
    async def test_projects_list(self, app: object) -> None:
        resp = await self._request(app, "GET", "/api/projects")
        assert resp.status_code == 200
        assert "projects" in resp.json()

    @pytest.mark.asyncio
    async def test_stop_pipeline_404_when_idle(self, app: object) -> None:
        resp = await self._request(app, "POST", "/api/pipeline/stop")
        assert resp.status_code == 404
================================================
FILE: tests/test_web_scholar.py
================================================
"""Tests for researchclaw.web.scholar — GoogleScholarClient."""
from __future__ import annotations
import time
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.web.scholar import GoogleScholarClient, ScholarPaper
# ---------------------------------------------------------------------------
# ScholarPaper dataclass
# ---------------------------------------------------------------------------
class TestScholarPaper:
    """ScholarPaper serialization and conversion to the literature model."""

    def test_to_dict(self):
        paper = ScholarPaper(
            title="Attention Is All You Need",
            authors=["Vaswani", "Shazeer"],
            year=2017,
            citation_count=50000,
        )
        serialized = paper.to_dict()
        assert serialized["title"] == "Attention Is All You Need"
        assert serialized["year"] == 2017
        assert serialized["source"] == "google_scholar"

    def test_to_literature_paper(self):
        paper = ScholarPaper(
            title="Test Paper",
            authors=["Author One", "Author Two"],
            year=2024,
            abstract="An abstract.",
            citation_count=100,
            url="https://example.com",
        )
        converted = paper.to_literature_paper()
        assert converted.title == "Test Paper"
        assert converted.source == "google_scholar"
        assert len(converted.authors) == 2
        assert converted.authors[0].name == "Author One"
# ---------------------------------------------------------------------------
# GoogleScholarClient
# ---------------------------------------------------------------------------
class TestGoogleScholarClient:
    """GoogleScholarClient publication parsing, rate limiting, and mocked searches."""

    @patch("researchclaw.web.scholar.HAS_SCHOLARLY", True)
    def test_available_always_true(self):
        """scholarly is now an installed dependency, always available."""
        assert GoogleScholarClient().available

    def test_parse_pub_full(self):
        """_parse_pub with a fully populated publication dict."""
        raw_pub = {
            "bib": {
                "title": "Deep Learning",
                "author": ["LeCun", "Bengio", "Hinton"],
                "pub_year": "2015",
                "abstract": "Deep learning review.",
                "venue": "Nature",
            },
            "num_citations": 30000,
            "pub_url": "https://nature.com/dl",
            "cites_id": ["abc123"],
        }
        parsed = GoogleScholarClient._parse_pub(raw_pub)
        assert parsed.title == "Deep Learning"
        assert parsed.year == 2015
        assert parsed.citation_count == 30000
        assert "LeCun" in parsed.authors
        assert parsed.venue == "Nature"

    def test_parse_pub_string_authors(self):
        # scholarly sometimes returns authors as a single "A and B" string.
        raw_pub = {
            "bib": {
                "title": "Paper",
                "author": "Smith and Jones",
                "pub_year": "2023",
            },
            "num_citations": 10,
            "pub_url": "https://example.com",
        }
        parsed = GoogleScholarClient._parse_pub(raw_pub)
        assert parsed.title == "Paper"
        assert "Smith" in parsed.authors
        assert "Jones" in parsed.authors

    def test_parse_pub_missing_fields(self):
        parsed = GoogleScholarClient._parse_pub({"bib": {}, "num_citations": 0})
        assert parsed.title == ""
        assert parsed.year == 0
        assert parsed.authors == []

    @patch("researchclaw.web.scholar.HAS_SCHOLARLY", True)
    def test_rate_limiting(self):
        client = GoogleScholarClient(inter_request_delay=0.01)
        started = time.monotonic()
        client._rate_limit()
        client._rate_limit()
        # Two consecutive calls must be separated by at least the configured delay.
        assert time.monotonic() - started >= 0.01

    @patch("researchclaw.web.scholar.HAS_SCHOLARLY", True)
    @patch("researchclaw.web.scholar.scholarly")
    def test_search_with_mocked_scholarly(self, mock_scholarly):
        """search() should surface results produced by the scholarly iterator."""
        fake_pub = {
            "bib": {
                "title": "Test Paper",
                "author": ["Author A"],
                "pub_year": "2024",
            },
            "num_citations": 5,
            "pub_url": "https://example.com",
        }
        mock_scholarly.search_pubs.return_value = iter([fake_pub])
        papers = GoogleScholarClient(inter_request_delay=0.0).search("test query", limit=5)
        assert len(papers) == 1
        assert papers[0].title == "Test Paper"

    @patch("researchclaw.web.scholar.HAS_SCHOLARLY", True)
    @patch("researchclaw.web.scholar.scholarly")
    def test_search_error_graceful(self, mock_scholarly):
        """Search should return empty list on error, not raise."""
        mock_scholarly.search_pubs.side_effect = Exception("Rate limited")
        assert GoogleScholarClient(inter_request_delay=0.0).search("test query") == []
================================================
FILE: tests/test_web_search.py
================================================
"""Tests for researchclaw.web.search — WebSearchClient."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from researchclaw.web.search import SearchResult, WebSearchClient, WebSearchResponse
# ---------------------------------------------------------------------------
# SearchResult dataclass
# ---------------------------------------------------------------------------
class TestSearchResult:
    """Behaviour of the SearchResult dataclass."""

    def test_to_dict(self):
        """to_dict should expose constructor fields under plain string keys."""
        result = SearchResult(
            title="Test", url="https://example.com", snippet="A snippet", source="tavily"
        )
        as_dict = result.to_dict()
        expectations = (
            ("title", "Test"),
            ("url", "https://example.com"),
            ("source", "tavily"),
        )
        for key, expected in expectations:
            assert as_dict[key] == expected
# ---------------------------------------------------------------------------
# WebSearchResponse dataclass
# ---------------------------------------------------------------------------
class TestWebSearchResponse:
    """Behaviour of the WebSearchResponse dataclass."""

    def test_has_results_true(self):
        """has_results is truthy once at least one SearchResult is attached."""
        response = WebSearchResponse(
            query="test", results=[SearchResult(title="A", url="u")],
        )
        assert response.has_results

    def test_has_results_false(self):
        """has_results is falsy for a response with no results."""
        empty_response = WebSearchResponse(query="test")
        assert not empty_response.has_results
# ---------------------------------------------------------------------------
# DuckDuckGo HTML parsing
# ---------------------------------------------------------------------------
class TestDDGParsing:
    """Tests for WebSearchClient._parse_ddg_html (DuckDuckGo HTML scraping).

    NOTE(review): the HTML fixtures below look truncated — the anchor/snippet
    markup that _parse_ddg_html would extract is missing from the strings,
    yet the assertions expect parsed results. Confirm these fixtures against
    the original test file before relying on them.
    """

    def test_parse_ddg_html_basic(self):
        # NOTE(review): fixture appears empty; expected two result anchors
        # ("Title One" / "Snippet one here") — markup presumably lost.
        html = """
"""
        results = WebSearchClient._parse_ddg_html(html, limit=10)
        assert len(results) == 2
        assert results[0].title == "Title One"
        assert results[0].url == "https://example.com/1"
        assert results[0].snippet == "Snippet one here"

    def test_parse_ddg_html_skips_ddg_links(self):
        # NOTE(review): fixture should contain one duckduckgo-internal link
        # ("DDG Link") and one external link ("Real") — markup presumably lost.
        html = """
DDG Link
Real
"""
        results = WebSearchClient._parse_ddg_html(html, limit=10)
        # Internal DDG navigation links must be filtered out of results.
        assert len(results) == 1
        assert results[0].url == "https://example.com/real"

    def test_parse_ddg_html_respects_limit(self):
        # Builds 20 candidate entries; parser must cap output at limit=5.
        # NOTE(review): the per-entry markup around 'T{i}' looks stripped.
        html = ""
        for i in range(20):
            html += f'T{i}\n'
        results = WebSearchClient._parse_ddg_html(html, limit=5)
        assert len(results) == 5
# ---------------------------------------------------------------------------
# WebSearchClient.search
# ---------------------------------------------------------------------------
class TestWebSearchClient:
    """Tests for WebSearchClient.search / _search_tavily / search_multi.

    NOTE(review): the byte-string HTML payloads below appear to have lost
    their markup during extraction; the DDG-path assertions may depend on
    fixtures that are no longer intact — confirm against the original file.
    """

    @patch("researchclaw.web.search.urlopen")
    def test_search_ddg_fallback_no_api_key(self, mock_urlopen):
        """When no API key is set, uses DuckDuckGo fallback."""
        mock_resp = MagicMock()
        # Fake HTTP response body returned by urlopen().read().
        mock_resp.read.return_value = b"""
Paper Title
About the paper
"""
        mock_urlopen.return_value = mock_resp
        client = WebSearchClient(api_key="")  # No API key
        response = client.search("test query")
        # Without a Tavily key the client must fall through to DDG scraping.
        assert response.source == "duckduckgo"

    @patch("researchclaw.web.search.urlopen")
    def test_search_ddg_error_graceful(self, mock_urlopen):
        # A network failure must yield an empty DDG response, not an exception.
        mock_urlopen.side_effect = Exception("Network error")
        client = WebSearchClient(api_key="")
        response = client.search("test query")
        assert response.source == "duckduckgo"
        assert len(response.results) == 0

    def test_search_tavily_with_mock(self):
        """Test Tavily search with mocked SDK."""
        mock_client_instance = MagicMock()
        # Shape mirrors the Tavily SDK's search() return payload.
        mock_client_instance.search.return_value = {
            "results": [
                {
                    "title": "Tavily Result",
                    "url": "https://tavily.com/r1",
                    "content": "Content from Tavily",
                    "score": 0.95,
                }
            ],
            "answer": "AI summary answer",
        }
        mock_tavily_module = MagicMock()
        mock_tavily_module.TavilyClient.return_value = mock_client_instance
        # Inject the fake tavily module so the lazy `import tavily` inside
        # the client resolves to our mock.
        with patch.dict("sys.modules", {"tavily": mock_tavily_module}):
            client = WebSearchClient(api_key="test-key")
            import time
            # NOTE(review): positional args presumed to be
            # (query, limit, include_domains, exclude_domains, start_time)
            # — confirm against researchclaw/web/search.py.
            response = client._search_tavily("test query", 10, None, None, time.monotonic())
        assert response.source == "tavily"
        assert len(response.results) == 1
        assert response.results[0].title == "Tavily Result"
        assert response.answer == "AI summary answer"

    @patch("researchclaw.web.search.urlopen")
    def test_search_multi_deduplication(self, mock_urlopen):
        # Both queries return the identical page, so any URL seen in the
        # first response must be absent from the second (deduplication).
        mock_resp = MagicMock()
        mock_resp.read.return_value = b"""
Same Result
"""
        mock_urlopen.return_value = mock_resp
        client = WebSearchClient(api_key="")
        responses = client.search_multi(["query1", "query2"], inter_query_delay=0.0)
        assert len(responses) == 2
        # Second query should have same URL deduped
        if responses[0].results:
            assert all(
                r.url != responses[0].results[0].url
                for r in responses[1].results
            )
================================================
FILE: website/features.html
================================================
Features — AutoResearchClaw
Features
Everything you need for autonomous research paper generation, built for reliability and quality.
Multi-Source Literature Search
Searches OpenAlex (primary, 10K/day), Semantic Scholar, and arXiv in parallel. Intelligent source fallback ensures results even when individual APIs are rate-limited.
Rate Limit Defense
Five-layer defense: adaptive rate limiter, three-state circuit breaker, multi-source fallback, intelligent caching with per-source TTL, and request optimization via S2 batch API.
Docker Sandbox with GPU
Experiments run in isolated Docker containers based on nvidia/cuda:12.4.1 with PyTorch, GPU passthrough, network sandboxing, and pre-cached datasets (CIFAR-10, FashionMNIST).
Hardware-Aware Design
Automatically detects available GPU memory and adjusts experiment parameters (batch size, model size, training epochs) to fit within hardware constraints.
Multi-Agent Peer Review
Simulated conference-style peer review with multiple LLM reviewer personas providing structured feedback on technical soundness, methodology, and clarity.
Pivot / Refine / Proceed
After analyzing experiment results, the pipeline autonomously decides whether to proceed with paper writing, refine the experiment, or pivot to a new hypothesis (max 2 pivots).
Experiment Charts
Automatically generates publication-quality comparison charts, filtering out timing/meta metrics. Supports bar charts, learning curves, and ablation visualizations.
Conference-Ready LaTeX
Outputs publication-quality LaTeX with proper bibliography, figure placement, and conference formatting (NeurIPS, ICLR, ICML templates).
Citation Verification
Every cited paper is verified against real academic databases (CrossRef, OpenAlex, arXiv, Semantic Scholar) in optimized order to minimize API pressure.
Result Caching
Per-source cache TTL: arXiv results cached 24h (daily metadata updates), S2/OpenAlex 3 days, citation verification results cached permanently.
Seminal Paper Library
Built-in seed library of foundational ML papers (normalization, ResNets, transformers, etc.) injected during literature search to ensure key references are cited.
Code Security Validation
Generated experiment code is validated for security (no network access, no subprocess calls, no file system writes outside workspace) before Docker execution.
Contradiction Detection
Automatically detects contradictions in experiment results: null findings, negative results, and cases where control outperforms proposed method.
Quality Assessment
Built-in quality scoring across novelty, soundness, significance, clarity, and reproducibility. Papers below threshold trigger rewriting.
Knowledge Archive
Research findings are archived in a persistent knowledge base (Markdown-backed) for cross-project knowledge transfer and future reference.
LLM Fine-Tuning
Optional QLoRA/LoRA fine-tuning support for adapting language models to specific research domains and writing styles.
How We Compare
AutoResearchClaw vs. other autonomous research tools
Feature
AutoResearchClaw
PaperClaw
Sibyl
Idea2Paper
Literature search
3 APIs + cache
2 APIs
arXiv only
Offline KG
Rate limit handling
Circuit breaker + fallback
Exponential backoff
None
N/A
Code execution
Docker + GPU
No
No
No
Peer review
Multi-agent
No
Single agent
No
Citation verification
4 API sources
No
No
No
Pipeline stages
23
~8
~5
~6
================================================
FILE: website/getting-started.html
================================================
Get Started — AutoResearchClaw
Get Started
From zero to your first autonomous research paper in minutes.
0 Prerequisites
- ☑ Python 3.10+
- ☑ Docker with NVIDIA Container Toolkit (for GPU experiments)
- ☑ An OpenAI-compatible API key (Azure OpenAI, OpenAI, or local LLM)
- ☑ NVIDIA GPU with 8GB+ VRAM (optional, for Docker sandbox)
1 Clone the Repository
git clone https://github.com/aiming-lab/AutoResearchClaw.git
cd AutoResearchClaw
2 Install Dependencies
pip install -e .
This installs the researchclaw package and all required dependencies.
3 Configure Your LLM
Create a YAML config file (e.g., config.yaml) with your LLM settings:
# config.yaml
project:
name: "my-first-paper"
mode: "docs-first"
research:
topic: "Your research topic here"
llm:
provider: "openai-compatible"
base_url: "https://api.openai.com/v1"
api_key_env: "OPENAI_API_KEY"
experiment:
backend: "docker" # or "subprocess" for local
timeout_sec: 1800
4 Set Your API Key
export OPENAI_API_KEY="sk-your-key-here"
5 Build the Docker Image (Optional)
If using the Docker sandbox backend for GPU-accelerated experiments:
docker build -t researchclaw-sandbox -f researchclaw/docker/Dockerfile .
6 Run Your First Paper
python -m researchclaw run --config config.yaml
The pipeline will execute all 23 stages autonomously. Output will be saved
to the output/ directory including the paper PDF, LaTeX source,
experiment code, and charts.
7 Review Your Paper
After the pipeline completes, find your generated paper at:
output/<run-id>/
paper.pdf # Final PDF
paper.tex # LaTeX source
references.bib # Bibliography
code/main.py # Experiment code
charts/ # Generated figures
results.json # Experiment metrics
Tips
- Use GPT-4.1 or newer for best paper quality
- Set
timeout_sec: 3600 for complex experiments
- For Azure OpenAI, set
provider: "azure_openai" and configure your endpoint
- The pipeline caches literature results, so re-runs with the same topic are faster
- Run
python -m pytest tests/ -v to verify your installation
================================================
FILE: website/index.html
================================================
AutoResearchClaw — Autonomous Research Paper Generation
Chat an Idea.
Get a Paper.
AutoResearchClaw is a fully autonomous 23-stage pipeline that transforms a research topic
into a conference-ready paper — with real experiments, GPU-accelerated code,
and verified citations.
# one command, one paper
python -m researchclaw run --topic "your research idea"
23
Autonomous Stages
1117
Tests Passing
3
Literature APIs
GPU
Docker Sandbox
From Idea to Paper in 23 Steps
Eight autonomous phases transform a research topic into a publication-ready manuscript.
A: Research Scoping
Topic initialization, problem decomposition, and scope definition.
B: Literature Discovery
Multi-source paper search via OpenAlex, Semantic Scholar, and arXiv with quality screening.
C: Knowledge Synthesis
Gap analysis, trend synthesis, and novel hypothesis generation.
D: Experiment Design
Methodology design, code generation, and resource planning with hardware awareness.
E: Experiment Execution
GPU-accelerated Docker sandbox execution with iterative refinement.
F: Analysis & Decision
Result analysis with pivot/refine/proceed decisions.
G: Paper Writing
Structured drafting, multi-agent peer review, and iterative revision.
H: Finalization
Quality gate, knowledge archival, LaTeX export, and citation verification.
Key Features
Built for serious research, engineered for reliability.
Real Literature Search
Multi-source search across OpenAlex, Semantic Scholar, and arXiv with circuit breakers, rate limiting, and intelligent caching.
Docker Sandbox + GPU
Experiments run in isolated Docker containers with NVIDIA GPU passthrough, network sandboxing, and automatic dependency management.
Multi-Agent Peer Review
Simulated conference-style peer review with multiple reviewer personas providing structured feedback for revision.
Iterative Refinement
Automatic pivot/refine/proceed decisions with rollback to any previous stage based on experiment outcomes.
Conference-Ready LaTeX
Publication-quality LaTeX output with proper citations, experiment charts, and structured abstracts.
Citation Verification
All citations verified against CrossRef, OpenAlex, and arXiv APIs to ensure bibliography accuracy.
Showcase Papers
Papers generated entirely by the pipeline, from topic to camera-ready PDF.
📄
Curriculum Learning with Adaptive Difficulty Scheduling for Image Classification
Investigates adaptive curriculum strategies on CIFAR-10/100 benchmarks, demonstrating improved convergence speed and final accuracy compared to standard training.
📄
Test-Time Adaptation via Batch Normalization Statistics for Distribution Shift
Explores test-time adaptation methods using batch normalization statistics to handle distribution shift on CIFAR-10-C corruption benchmarks.
📄
Entropy-Guided Exploration Bonuses for Sparse-Reward Continuous Control
Proposes entropy-guided intrinsic reward bonuses to improve exploration in sparse-reward MuJoCo locomotion environments.
System Architecture
End-to-end pipeline architecture from topic input to published paper.
Ready to Generate Your First Paper?
Clone the repo, configure your LLM API key, and run your first autonomous research paper.
================================================
FILE: website/papers.html
================================================
Showcase Papers — AutoResearchClaw
Showcase Papers
Every paper below was generated entirely by AutoResearchClaw — from a single topic prompt to a complete research manuscript with real experiments.
📊
Curriculum Learning with Adaptive Difficulty Scheduling for Image Classification
Investigates adaptive curriculum learning strategies on CIFAR-10/100 benchmarks.
Proposes a difficulty-aware scheduling mechanism that dynamically adjusts training
sample ordering to improve convergence speed and final accuracy.
💬
Prompt-Length-Aware Routing for Mixture-of-LoRA Experts in Instruction-Following
Proposes a routing mechanism for Mixture-of-LoRA experts that considers prompt
length characteristics. Fine-tunes Qwen-2.5-3B with QLoRA to demonstrate
improved instruction-following across varying input lengths.
🧬
Graph Attention Networks with Learnable Edge Features for Molecular Property Prediction
Extends graph attention networks with learnable edge feature transformations
for molecular property prediction on the OGB-MolHIV benchmark, achieving
competitive performance with existing specialized architectures.
🎮
Entropy-Guided Exploration Bonuses for Sparse-Reward Continuous Control
Proposes entropy-guided intrinsic reward bonuses to improve exploration
efficiency in sparse-reward MuJoCo locomotion environments. Demonstrates
improved sample efficiency over baseline algorithms.
🎨
Spectral Normalization Effects on Mode Collapse in Conditional GANs for CIFAR-10
Systematically studies the effect of spectral normalization on mode collapse
in conditional GANs trained on CIFAR-10, providing both visual and quantitative
analysis (FID, IS) of generation diversity.
🔄
Test-Time Adaptation via Batch Normalization Statistics for Distribution Shift
Explores lightweight test-time adaptation methods that update batch normalization
statistics to handle distribution shift on CIFAR-10-C corruption benchmarks,
demonstrating practical robustness improvements.
Papers Coming Soon
We're generating showcase papers across diverse ML subfields. Each paper will include a downloadable PDF, LaTeX source, experiment code, and quality assessment. Check back soon!
================================================
FILE: website/pipeline.html
================================================
Pipeline — AutoResearchClaw
The 23-Stage Pipeline
Click any stage to expand its description. Yellow badges mark gate stages that require quality checks before proceeding.
A
Research Scoping
1
Topic Initialization
Define research topic, scope, and target conference
Takes a user-provided topic prompt and generates a structured research plan including target conference, research questions, and expected contributions. Emphasizes novelty and alignment with recent conference trends.
LLM
2
Problem Decomposition
Break research into sub-problems and objectives
Decomposes the research topic into concrete sub-problems, defines evaluation criteria, and identifies the key technical challenges to address.
LLM
B
Literature Discovery
3
Search Strategy
Generate search queries and select paper sources
Generates targeted search queries from the research plan, selects which APIs to query (OpenAlex, Semantic Scholar, arXiv), and defines inclusion/exclusion criteria.
LLM
4
Literature Collect
Search OpenAlex, Semantic Scholar, and arXiv
Executes multi-source literature search with intelligent caching, circuit breakers, and rate limiting. Deduplicates results across sources and injects seminal papers from the seed library.
API
5
Literature Screen
Quality and relevance screening (Gate)
Gate Stage. LLM reviews each collected paper for relevance, quality, and domain match. Cross-domain false positives are explicitly rejected. Papers below threshold are filtered out.
Gate
LLM
6
Knowledge Extract
Extract key insights and methodologies from papers
Extracts structured knowledge from screened papers: key contributions, methods, results, limitations, and open questions. Builds a knowledge graph for synthesis.
LLM
C
Knowledge Synthesis
7
Synthesis
Gap analysis and research trend synthesis
Clusters extracted knowledge by topic, identifies research gaps, and synthesizes trends. Produces a structured literature review summary that informs hypothesis generation.
LLM
8
Hypothesis Generation
Generate testable research hypotheses
Generates novel, testable hypotheses that address gaps not covered by existing literature. Each hypothesis includes expected outcomes, evaluation metrics, and ablation dimensions.
LLM
D
Experiment Design
9
Experiment Design
Methodology design and validation (Gate)
Gate Stage. Designs the complete experimental methodology: baselines, ablations, metrics, datasets, and statistical tests. Requires modern benchmarks and real datasets (CIFAR-10, etc.).
Gate
LLM
10
Code Generation
Generate executable experiment code
Generates complete Python experiment code (main.py) with dataset loading, model definition, training loop, evaluation, and results output. Includes security validation, import checking, and code review.
LLM
New
11
Resource Planning
Estimate compute budget and time allocation
Estimates GPU memory requirements, training time, and compute budget. Configures Docker sandbox resource limits and timeout values based on available hardware.
LLM
E
Experiment Execution
12
Experiment Run
Execute experiments in Docker sandbox
Runs generated code inside an isolated Docker container with NVIDIA GPU passthrough. Captures stdout metrics, timing data, and exit codes. Pre-cached datasets available at /workspace/data.
Docker
13
Iterative Refinement
Fix errors and improve experiment code
If experiment fails or produces poor results, automatically diagnoses issues and generates refined code. Checks ablation effectiveness (>5% difference from baseline). Up to 3 refinement iterations.
LLM
Docker
New
F
Analysis & Decision
14
Result Analysis
Statistical analysis of experiment outcomes
Parses experiment outputs, computes statistical significance, generates comparison charts, and produces structured results summaries. Detects result contradictions and null findings.
LLM
15
Research Decision
Pivot, refine, or proceed based on results
Evaluates experiment results and decides: Proceed (results support hypothesis), Refine (re-run with improvements), or Pivot (discard hypothesis, generate new one). Max 2 pivots to prevent infinite loops.
LLM
G
Paper Writing
16
Paper Outline
Structure paper sections and arguments
Creates a detailed paper outline with section-by-section arguments, key claims, and figure placements. Follows conference template structure (abstract, intro, related work, method, experiments, conclusion).
LLM
17
Paper Draft
Write the full paper draft
Generates the complete paper in Markdown/LaTeX with structured writing rules: 150-200 word abstract, no number repetition across sections, proper citation of original papers for all discussed techniques.
LLM
18
Peer Review
Multi-agent simulated conference review
Multiple LLM reviewer personas evaluate the paper: one technical reviewer, one methodology expert, and one clarity/presentation reviewer. Each provides structured feedback with scores.
LLM
19
Paper Revision
Revise based on peer review feedback
Addresses reviewer comments systematically: fixes technical issues, improves writing clarity, adds missing comparisons, and strengthens the narrative. Produces a revised draft.
LLM
New
H
Finalization
20
Quality Gate
Final quality assessment (Gate)
Gate Stage. Comprehensive quality assessment scoring the paper on novelty, soundness, significance, clarity, and reproducibility. Papers below threshold are sent back for rewriting.
Gate
LLM
21
Knowledge Archive
Archive findings to knowledge base
Stores research findings, methodology, and results in the persistent knowledge base for future reference and cross-project knowledge transfer. Non-critical: failure doesn't abort pipeline.
LLM
22
Export & Publish
Generate LaTeX PDF and final output
Converts the paper to conference-ready LaTeX, compiles to PDF, generates BibTeX bibliography, and produces the final output package (paper.pdf, main.tex, references.bib, charts/).
LaTeX
23
Citation Verify
Verify all citations against real databases
Verifies every cited paper exists in real academic databases. Checks DOI via CrossRef, title via OpenAlex, arXiv ID via arXiv API, and falls back to Semantic Scholar. Non-critical: failure doesn't abort pipeline.
API
================================================
FILE: website/style.css
================================================
/* ============================================================
AutoResearchClaw — Showcase Website Styles
Pure CSS, no build step. Tailwind-inspired utility patterns.
============================================================ */
/* ---------- Reset & Variables ---------- */
/* Design tokens: dark slate theme, sky-blue primary, violet accent. */
:root {
  --color-bg: #0f172a;
  --color-bg-alt: #1e293b;
  --color-surface: #334155;
  --color-border: #475569;
  --color-text: #e2e8f0;
  --color-text-muted:#94a3b8;
  --color-primary: #38bdf8;
  --color-primary-d: #0284c7; /* darker primary, used for hover states */
  --color-accent: #a78bfa;
  --color-accent-d: #7c3aed;
  --color-success: #4ade80;
  --color-warning: #fbbf24; /* also marks "gate" pipeline stages */
  --color-danger: #f87171;
  --color-white: #f8fafc;
  --font-sans: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  --font-mono: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace;
  --radius: 0.75rem;
  --radius-lg: 1rem;
  --shadow: 0 4px 6px -1px rgba(0,0,0,.3), 0 2px 4px -2px rgba(0,0,0,.2);
  --shadow-lg: 0 10px 15px -3px rgba(0,0,0,.4), 0 4px 6px -4px rgba(0,0,0,.3);
  --transition: 0.2s ease;
}
/* Minimal reset: border-box sizing, zeroed margins/padding everywhere. */
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
html { scroll-behavior: smooth; font-size: 16px; }
body {
  font-family: var(--font-sans);
  background: var(--color-bg);
  color: var(--color-text);
  line-height: 1.7;
  min-height: 100vh;
}
/* Links swap from primary to accent color on hover. */
a { color: var(--color-primary); text-decoration: none; transition: color var(--transition); }
a:hover { color: var(--color-accent); }
/* Fluid images by default. */
img { max-width: 100%; height: auto; display: block; }
/* ---------- Layout ---------- */
.container { max-width: 1200px; margin: 0 auto; padding: 0 1.5rem; }
.section { padding: 5rem 0; }
/* ---------- Navigation ---------- */
/* Fixed, translucent, backdrop-blurred top bar. */
.navbar {
  position: fixed; top: 0; left: 0; right: 0; z-index: 100;
  background: rgba(15, 23, 42, 0.85);
  backdrop-filter: blur(12px);
  border-bottom: 1px solid var(--color-border);
  padding: 0.75rem 0;
}
.navbar .container {
  display: flex; align-items: center; justify-content: space-between;
}
.nav-brand {
  display: flex; align-items: center; gap: 0.75rem;
  font-weight: 700; font-size: 1.15rem; color: var(--color-white);
}
.nav-brand img { height: 32px; width: 32px; border-radius: 6px; }
.nav-links { display: flex; gap: 1.5rem; list-style: none; }
.nav-links a {
  color: var(--color-text-muted); font-size: 0.9rem; font-weight: 500;
  padding: 0.4rem 0; transition: color var(--transition);
}
.nav-links a:hover, .nav-links a.active { color: var(--color-primary); }
/* Pill-shaped GitHub call-to-action button in the navbar. */
.nav-github {
  display: inline-flex; align-items: center; gap: 0.4rem;
  background: var(--color-surface); color: var(--color-white);
  padding: 0.45rem 1rem; border-radius: 9999px; font-size: 0.85rem;
  font-weight: 600; transition: background var(--transition);
}
.nav-github:hover { background: var(--color-primary-d); color: var(--color-white); }
/* Mobile nav toggle */
/* Hamburger is hidden on desktop; at <=768px the links collapse into a
   dropdown that JS opens by toggling the .open class. */
.nav-toggle { display: none; background: none; border: none; color: var(--color-text); font-size: 1.5rem; cursor: pointer; }
@media (max-width: 768px) {
  .nav-toggle { display: block; }
  .nav-links {
    display: none; flex-direction: column; position: absolute;
    top: 100%; left: 0; right: 0; background: var(--color-bg-alt);
    padding: 1rem 1.5rem; gap: 0.5rem; border-bottom: 1px solid var(--color-border);
  }
  .nav-links.open { display: flex; }
}
/* ---------- Hero ---------- */
/* Extra top padding clears the fixed navbar; subtle blue glow fades down. */
.hero {
  padding: 10rem 0 5rem;
  text-align: center;
  background: linear-gradient(180deg, rgba(56,189,248,0.08) 0%, transparent 60%);
}
.hero h1 {
  font-size: clamp(2rem, 5vw, 3.5rem);
  font-weight: 800;
  line-height: 1.15;
  margin-bottom: 1rem;
}
/* Gradient-filled heading text via background-clip. */
.hero h1 .gradient {
  background: linear-gradient(135deg, var(--color-primary), var(--color-accent));
  -webkit-background-clip: text; background-clip: text;
  -webkit-text-fill-color: transparent;
}
.hero .tagline {
  font-size: clamp(1.05rem, 2vw, 1.35rem);
  color: var(--color-text-muted);
  max-width: 640px; margin: 0 auto 2rem;
}
.hero-actions { display: flex; gap: 1rem; justify-content: center; flex-wrap: wrap; }
/* Shared pill button base; variants below set the color scheme. */
.btn {
  display: inline-flex; align-items: center; gap: 0.5rem;
  padding: 0.75rem 1.75rem; border-radius: 9999px;
  font-weight: 600; font-size: 0.95rem; cursor: pointer;
  border: none; transition: all var(--transition);
}
.btn-primary {
  background: linear-gradient(135deg, var(--color-primary), var(--color-accent));
  color: #0f172a;
}
.btn-primary:hover { transform: translateY(-2px); box-shadow: var(--shadow-lg); color: #0f172a; }
.btn-outline {
  background: transparent; color: var(--color-primary);
  border: 2px solid var(--color-primary);
}
.btn-outline:hover { background: rgba(56,189,248,0.1); color: var(--color-primary); }
/* Inline terminal-style snippet under the hero copy. */
.hero-code {
  margin-top: 2.5rem;
  display: inline-block;
  background: var(--color-bg-alt);
  border: 1px solid var(--color-border);
  border-radius: var(--radius);
  padding: 0.8rem 1.5rem;
  font-family: var(--font-mono);
  font-size: 0.9rem;
  color: var(--color-success);
}
/* ---------- Stats ---------- */
/* Auto-fitting grid of headline metric tiles. */
.stats {
  display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr));
  gap: 1.5rem; padding: 3rem 0;
}
.stat {
  text-align: center; padding: 1.5rem;
  background: var(--color-bg-alt); border-radius: var(--radius);
  border: 1px solid var(--color-border);
}
/* Big gradient number inside each tile. */
.stat-value {
  font-size: 2rem; font-weight: 800;
  background: linear-gradient(135deg, var(--color-primary), var(--color-accent));
  -webkit-background-clip: text; background-clip: text;
  -webkit-text-fill-color: transparent;
}
.stat-label { font-size: 0.85rem; color: var(--color-text-muted); margin-top: 0.25rem; }
/* ---------- Section headings ---------- */
.section-header {
  text-align: center; margin-bottom: 3rem;
}
.section-header h2 {
  font-size: clamp(1.5rem, 3vw, 2.25rem);
  font-weight: 700; margin-bottom: 0.5rem;
}
.section-header p {
  color: var(--color-text-muted); max-width: 600px; margin: 0 auto;
}
/* ---------- Pipeline Overview (Landing) ---------- */
.pipeline-preview {
display: grid; grid-template-columns: repeat(auto-fit, minmax(240px, 1fr));
gap: 1rem;
}
.phase-card {
background: var(--color-bg-alt); border-radius: var(--radius);
border: 1px solid var(--color-border);
padding: 1.5rem; transition: all var(--transition);
}
.phase-card:hover {
border-color: var(--color-primary);
transform: translateY(-3px); box-shadow: var(--shadow-lg);
}
.phase-card .phase-icon { font-size: 1.75rem; margin-bottom: 0.75rem; }
.phase-card h3 { font-size: 1rem; font-weight: 600; margin-bottom: 0.4rem; }
.phase-card p { font-size: 0.85rem; color: var(--color-text-muted); line-height: 1.5; }
.phase-card .phase-stages {
margin-top: 0.75rem; display: flex; gap: 0.35rem; flex-wrap: wrap;
}
.stage-dot {
width: 8px; height: 8px; border-radius: 50%;
background: var(--color-primary); opacity: 0.5;
}
.stage-dot.gate { background: var(--color-warning); opacity: 1; }
/* ---------- Paper Cards ---------- */
.paper-grid {
display: grid; grid-template-columns: repeat(auto-fit, minmax(320px, 1fr));
gap: 1.5rem;
}
.paper-card {
background: var(--color-bg-alt); border-radius: var(--radius-lg);
border: 1px solid var(--color-border);
overflow: hidden; transition: all var(--transition);
}
.paper-card:hover {
border-color: var(--color-primary);
transform: translateY(-3px); box-shadow: var(--shadow-lg);
}
.paper-thumb {
height: 180px; background: linear-gradient(135deg, var(--color-surface), var(--color-bg));
display: flex; align-items: center; justify-content: center;
font-size: 3rem; color: var(--color-text-muted);
}
.paper-body { padding: 1.25rem; }
.paper-body h3 { font-size: 1rem; font-weight: 600; line-height: 1.4; margin-bottom: 0.5rem; }
.paper-body .paper-meta {
display: flex; gap: 0.5rem; flex-wrap: wrap; margin-bottom: 0.75rem;
}
.badge {
display: inline-flex; align-items: center; gap: 0.25rem;
padding: 0.2rem 0.6rem; border-radius: 9999px;
font-size: 0.72rem; font-weight: 600;
}
.badge-domain { background: rgba(56,189,248,0.15); color: var(--color-primary); }
.badge-score { background: rgba(74,222,128,0.15); color: var(--color-success); }
.badge-pending { background: rgba(251,191,36,0.15); color: var(--color-warning); }
.paper-body .paper-abstract {
font-size: 0.82rem; color: var(--color-text-muted);
line-height: 1.55; display: -webkit-box;
-webkit-line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden;
}
/* ---------- Feature Cards ---------- */
.feature-grid {
display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
gap: 1.5rem;
}
.feature-card {
background: var(--color-bg-alt); border-radius: var(--radius);
border: 1px solid var(--color-border);
padding: 1.75rem; transition: all var(--transition);
}
.feature-card:hover {
border-color: var(--color-accent);
transform: translateY(-2px);
}
.feature-card .feature-icon {
width: 48px; height: 48px; border-radius: 12px;
display: flex; align-items: center; justify-content: center;
font-size: 1.4rem; margin-bottom: 1rem;
}
.feature-card .feature-icon.bg-blue { background: rgba(56,189,248,0.15); }
.feature-card .feature-icon.bg-purple { background: rgba(167,139,250,0.15); }
.feature-card .feature-icon.bg-green { background: rgba(74,222,128,0.15); }
.feature-card .feature-icon.bg-amber { background: rgba(251,191,36,0.15); }
.feature-card .feature-icon.bg-red { background: rgba(248,113,113,0.15); }
.feature-card h3 { font-size: 1.05rem; font-weight: 600; margin-bottom: 0.4rem; }
.feature-card p { font-size: 0.85rem; color: var(--color-text-muted); line-height: 1.5; }
/* ---------- Pipeline Page (full) ---------- */
/* Centered column holding the complete pipeline listing. */
.pipeline-full {
  max-width: 860px;
  margin: 0 auto;
}
.phase-group {
  margin-bottom: 3rem;
}
/* Header row for each phase: letter chip + title, underlined. */
.phase-group-header {
  display: flex;
  align-items: center;
  gap: 0.75rem;
  margin-bottom: 1rem;
  padding-bottom: 0.5rem;
  border-bottom: 2px solid var(--color-border);
}
/* Gradient square chip displaying the phase letter. */
.phase-group-header .phase-letter {
  width: 36px;
  height: 36px;
  border-radius: 10px;
  display: flex;
  align-items: center;
  justify-content: center;
  font-weight: 800;
  font-size: 0.85rem;
  background: linear-gradient(135deg, var(--color-primary), var(--color-accent));
  color: #0f172a;
}
.phase-group-header h3 {
  font-size: 1.15rem;
  font-weight: 600;
}
/* Vertical stack of stage cards. */
.stage-list {
  display: flex;
  flex-direction: column;
  gap: 0.5rem;
}
/* One clickable stage card. */
.stage-item {
  display: flex;
  align-items: flex-start;
  gap: 1rem;
  padding: 1rem 1.25rem;
  background: var(--color-bg-alt);
  border-radius: var(--radius);
  border: 1px solid var(--color-border);
  cursor: pointer;
  transition: all var(--transition);
}
.stage-item:hover {
  border-color: var(--color-primary);
}
/* Expanded state: primary border plus a faint blue tint. */
.stage-item.expanded {
  border-color: var(--color-primary);
  background: rgba(56,189,248,0.04);
}
/* Numbered chip at the left edge of each stage card. */
.stage-number {
  flex-shrink: 0;
  width: 32px;
  height: 32px;
  border-radius: 8px;
  display: flex;
  align-items: center;
  justify-content: center;
  font-weight: 700;
  font-size: 0.8rem;
  background: var(--color-surface);
  color: var(--color-text);
}
/* Gate stages get an amber chip instead of the neutral surface. */
.stage-number.gate {
  background: rgba(251,191,36,0.2);
  color: var(--color-warning);
}
.stage-info {
  flex: 1;
}
.stage-info h4 {
  font-size: 0.95rem;
  font-weight: 600;
  margin-bottom: 0.15rem;
}
.stage-info .stage-subtitle {
  font-size: 0.8rem;
  color: var(--color-text-muted);
}
/* Detail panel, hidden until the parent card carries .expanded. */
.stage-detail {
  display: none;
  margin-top: 0.75rem;
  padding-top: 0.75rem;
  border-top: 1px solid var(--color-border);
  font-size: 0.85rem;
  color: var(--color-text-muted);
  line-height: 1.6;
}
.stage-item.expanded .stage-detail {
  display: block;
}
/* Row of pill badges inside a stage. */
.stage-badges {
  display: flex;
  gap: 0.35rem;
  flex-wrap: wrap;
  margin-top: 0.5rem;
}
.stage-badge {
  padding: 0.15rem 0.5rem;
  border-radius: 9999px;
  font-size: 0.7rem;
  font-weight: 600;
}
/* Badge color variants: translucent fill with matching accent text. */
.stage-badge-gate { background: rgba(251,191,36,0.15); color: var(--color-warning); }
.stage-badge-new { background: rgba(167,139,250,0.15); color: var(--color-accent); }
.stage-badge-llm { background: rgba(56,189,248,0.15); color: var(--color-primary); }
.stage-badge-docker { background: rgba(74,222,128,0.15); color: var(--color-success); }
/* ---------- Footer ---------- */
/* Centered footer separated from the page body by a top border. */
.footer {
  padding: 3rem 0;
  text-align: center;
  border-top: 1px solid var(--color-border);
}
.footer p {
  color: var(--color-text-muted);
  font-size: 0.85rem;
}
/* Horizontal, centered link list above the footer text. */
.footer-links {
  display: flex;
  justify-content: center;
  gap: 1.5rem;
  list-style: none;
  margin-bottom: 1rem;
}
.footer-links a {
  color: var(--color-text-muted);
  font-size: 0.85rem;
}
.footer-links a:hover {
  color: var(--color-primary);
}
/* ---------- Getting Started ---------- */
/* Narrow, centered column for the setup walkthrough. */
.getting-started-content {
  max-width: 780px;
  margin: 0 auto;
}
/* Card for one numbered setup step. */
.step-block {
  background: var(--color-bg-alt);
  border-radius: var(--radius);
  border: 1px solid var(--color-border);
  padding: 1.5rem;
  margin-bottom: 1rem;
}
.step-block h3 {
  display: flex;
  align-items: center;
  gap: 0.75rem;
  font-size: 1.05rem;
  font-weight: 600;
  margin-bottom: 0.75rem;
}
/* Circular number chip preceding the step title. */
.step-num {
  width: 28px;
  height: 28px;
  border-radius: 50%;
  display: inline-flex;
  align-items: center;
  justify-content: center;
  font-size: 0.8rem;
  font-weight: 700;
  background: var(--color-primary);
  color: #0f172a;
}
/* Terminal-styled snippet: dark GitHub-like background, green text,
   horizontal scroll for long commands. */
.code-block {
  background: #0d1117;
  border-radius: 8px;
  padding: 1rem 1.25rem;
  margin-top: 0.5rem;
  font-family: var(--font-mono);
  font-size: 0.85rem;
  color: var(--color-success);
  overflow-x: auto;
  border: 1px solid #21262d;
}
.code-block .comment {
  color: var(--color-text-muted);
}
/* ---------- Connector Line (pipeline page) ---------- */
/* Wrapper that centers the short vertical line between pipeline items. */
.connector {
  display: flex;
  justify-content: center;
  padding: 0.5rem 0;
}
/* Thin faded gradient stroke running top to bottom. */
.connector-line {
  width: 2px;
  height: 24px;
  background: linear-gradient(180deg, var(--color-primary), var(--color-accent));
  opacity: 0.4;
}
/* ---------- Coming Soon Overlay ---------- */
/* Centered placeholder shown for not-yet-built pages. */
.coming-soon {
  display: flex;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  padding: 4rem 2rem;
  text-align: center;
}
.coming-soon .cs-icon {
  font-size: 3rem;
  margin-bottom: 1rem;
}
.coming-soon h3 {
  font-size: 1.25rem;
  margin-bottom: 0.5rem;
}
.coming-soon p {
  color: var(--color-text-muted);
  max-width: 400px;
}
/* ---------- Utilities ---------- */
/* Single-purpose helper classes (Tailwind-style naming: mt-* = margin-top,
   mb-* = margin-bottom, numeric suffix maps N -> N * 0.25rem). */
.text-center { text-align: center; }
.mt-2 { margin-top: 0.5rem; }
.mt-4 { margin-top: 1rem; }
.mt-8 { margin-top: 2rem; }
/* FIX: .mb-4 previously declared `margin-top: 1rem`, making it a duplicate of
   .mt-4; a margin-bottom utility must set margin-bottom. */
.mb-4 { margin-bottom: 1rem; }
.gap-2 { gap: 0.5rem; }
.flex { display: flex; }
.items-center { align-items: center; }
.justify-center { justify-content: center; }