Repository: facebookresearch/DocAgent
Branch: main
Commit: f27f68574ee1
Files: 109
Total size: 781.2 KB

Directory structure:
gitextract_774lw2f5/
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── INSTALL.md
├── LICENSE
├── README.md
├── config/
│   └── example_config.yaml
├── data/
│   ├── raw_test_repo/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── example.py
│   │   ├── inventory/
│   │   │   ├── __init__.py
│   │   │   └── inventory_manager.py
│   │   ├── models/
│   │   │   ├── __init__.py
│   │   │   └── product.py
│   │   ├── payment/
│   │   │   ├── __init__.py
│   │   │   └── payment_processor.py
│   │   └── vending_machine.py
│   └── raw_test_repo_simple/
│       ├── helper.py
│       ├── inner/
│       │   └── inner_functions.py
│       ├── main.py
│       ├── processor.py
│       └── test_file.py
├── eval_completeness.py
├── generate_docstrings.py
├── output/
│   └── dependency_graphs/
│       └── raw_test_repo_dependency_graph.json
├── run_web_ui.py
├── setup.py
├── src/
│   ├── DocstringGenerator.egg-info/
│   │   ├── PKG-INFO
│   │   ├── SOURCES.txt
│   │   ├── dependency_links.txt
│   │   ├── requires.txt
│   │   └── top_level.txt
│   ├── __init__.py
│   ├── agent/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── llm/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── claude_llm.py
│   │   │   ├── factory.py
│   │   │   ├── gemini_llm.py
│   │   │   ├── huggingface_llm.py
│   │   │   ├── openai_llm.py
│   │   │   └── rate_limiter.py
│   │   ├── orchestrator.py
│   │   ├── reader.py
│   │   ├── searcher.py
│   │   ├── tool/
│   │   │   ├── README.md
│   │   │   ├── ast.py
│   │   │   ├── internal_traverse.py
│   │   │   └── perplexity_api.py
│   │   ├── verifier.py
│   │   ├── workflow.py
│   │   └── writer.py
│   ├── analyze_helpfulness_significance.py
│   ├── data/
│   │   └── parse/
│   │       ├── data_process.py
│   │       ├── downloader.py
│   │       └── repo_tree.py
│   ├── dependency_analyzer/
│   │   ├── __init__.py
│   │   ├── ast_parser.py
│   │   └── topo_sort.py
│   ├── evaluate_helpfulness.py
│   ├── evaluator/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── completeness.py
│   │   ├── evaluation_common.py
│   │   ├── helper/
│   │   │   └── context_finder.py
│   │   ├── helpfulness_attributes.py
│   │   ├── helpfulness_description.py
│   │   ├── helpfulness_evaluator.py
│   │   ├── helpfulness_evaluator_ablation.py
│   │   ├── helpfulness_examples.py
│   │   ├── helpfulness_parameters.py
│   │   ├── helpfulness_summary.py
│   │   ├── segment.py
│   │   └── truthfulness.py
│   ├── visualizer/
│   │   ├── __init__.py
│   │   ├── progress.py
│   │   ├── status.py
│   │   └── web_bridge.py
│   ├── web/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── app.py
│   │   ├── config_handler.py
│   │   ├── process_handler.py
│   │   ├── run.py
│   │   ├── static/
│   │   │   ├── css/
│   │   │   │   └── style.css
│   │   │   └── js/
│   │   │       ├── completeness.js
│   │   │       ├── config.js
│   │   │       ├── log-handler.js
│   │   │       ├── main.js
│   │   │       ├── repo-structure.js
│   │   │       └── status-visualizer.js
│   │   ├── templates/
│   │   │   └── index.html
│   │   └── visualization_handler.py
│   └── web_eval/
│       ├── README.md
│       ├── app.py
│       ├── helpers.py
│       ├── requirements.txt
│       ├── start_server.sh
│       ├── static/
│       │   └── css/
│       │       └── style.css
│       ├── templates/
│       │   ├── index.html
│       │   └── results.html
│       └── test_docstring_parser.py
└── tool/
    ├── remove_docstrings.py
    ├── remove_docstrings.sh
    └── serve_local_llm.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
config/agent_config.yaml
tool/add_header.sh

================================================
FILE: CHANGELOG.md
================================================
0.0.1 (April 17, 2025)

### First Version

Include web UI, CLI for DocAgent.
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq

================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to DocAgent

We want to make contributing to this project as easy and transparent as possible.

## Pull Requests

We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects.

Complete your CLA here:

## Issues

We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://bugbounty.meta.com/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue.

## Coding Style

* 2 spaces for indentation rather than tabs
* 80 character line length
* Use [Black](https://github.com/psf/black) for code formatting.
* Use [Flake8](https://flake8.pycqa.org/en/latest/) for linting.
* Follow [PEP 8](https://www.python.org/dev/peps/pep-0008/) style guidelines.
* Use snake_case for variable and function names.
* Use PascalCase for class names.
* Write docstrings for all public modules, classes, functions, and methods using Google style.
* Use type hints for function signatures.
* Keep imports organized: standard library first, then third-party libraries, then local application/library specific imports, each group separated by a blank line. Use [isort](https://pycqa.github.io/isort/) to automate this.

## License

By contributing to DocAgent, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.

================================================
FILE: INSTALL.md
================================================
# Installation Guide

This guide details how to set up the environment for DocAgent.

## Option 1: Installation with pip (Recommended)

### Basic Installation

To install the basic package with core dependencies:

```bash
# For all dependencies
pip install -e ".[all]"
```

## Development Setup

For development, we recommend installing in editable mode with dev dependencies:

```bash
# Install the package in editable mode with dev dependencies
pip install -e ".[dev]"

# Run tests
pytest
```

## Troubleshooting

### GraphViz Dependencies

For visualization components, you may need to install system-level dependencies for GraphViz:

```bash
# Ubuntu/Debian
sudo apt-get install graphviz graphviz-dev

# CentOS/RHEL
sudo yum install graphviz graphviz-devel

# macOS
brew install graphviz
```

### CUDA Support

If you're using CUDA for accelerated processing, ensure you have the correct CUDA toolkit installed that matches your PyTorch version.

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# DocAgent: Agentic Hierarchical Docstring Generation System

DocAgent is a system designed to generate high-quality, context-aware docstrings for Python codebases using a multi-agent approach and hierarchical processing.

## Citation

If you use DocAgent in your research, please cite our paper:

```bibtex
@misc{yang2025docagent,
      title={DocAgent: A Multi-Agent System for Automated Code Documentation Generation},
      author={Dayu Yang and Antoine Simoulin and Xin Qian and Xiaoyi Liu and Yuwei Cao and Zhaopu Teng and Grey Yang},
      year={2025},
      eprint={2504.08725},
      archivePrefix={arXiv},
      primaryClass={cs.SE}
}
```

You can find the paper on arXiv: [https://arxiv.org/abs/2504.08725](https://arxiv.org/abs/2504.08725)

## Table of Contents

- [Motivation](#motivation)
- [Methodology](#methodology)
- [Installation](#installation)
- [Components](#components)
- [Configuration](#configuration)
- [Usage](#usage)
- [Running the Evaluation System](#running-the-evaluation-system)
- [Optional: Using a Local LLM](#optional-using-a-local-llm)

## Motivation

High-quality docstrings are crucial for code readability, usability, and maintainability, especially in large repositories. They should explain the purpose, parameters, returns, exceptions, and usage within the broader context. Current LLMs often struggle with this, producing superficial or redundant comments and failing to capture essential context or rationale. DocAgent aims to address these limitations by generating informative, concise, and contextually aware docstrings.

## Methodology

DocAgent employs two key strategies:

1. **Hierarchical Traversal**: Processes code components by analyzing dependencies, starting with files having fewer dependencies. This builds a documented foundation before tackling more complex code, addressing the challenge of documenting context that itself lacks documentation.
2. **Agentic System**: Utilizes a team of specialized agents (`Reader`, `Searcher`, `Writer`, `Verifier`) coordinated by an `Orchestrator`. This system gathers context (internal and external), drafts docstrings according to standards, and verifies their quality in an iterative process.

For more details on the agentic framework, see the [Agent Component README](./src/agent/README.md).

## Installation

1. Clone the repository:

   ```bash
   git clone
   cd DocAgent
   ```

2. Install the necessary dependencies. It's recommended to use a virtual environment:

   ```bash
   python -m venv venv
   source venv/bin/activate  # if you use venv, you can also use conda
   pip install -e .
   ```

   *Note: For optional features like development tools, web UI components, or specific hardware support (e.g., CUDA), refer to the comments in `setup.py` and install extras as needed (e.g., `pip install -e ".[dev,web]"`).*

## Components

DocAgent is composed of several key parts:

- **[Core Agent Framework](./src/agent/README.md)**: Implements the multi-agent system (Reader, Searcher, Writer, Verifier, Orchestrator) responsible for the generation logic.
- **[Docstring Evaluator](./src/evaluator/README.md)**: Provides tools for evaluating docstring quality, primarily focusing on completeness based on static code analysis (AST). *Note: Evaluation is run separately, see its README.*
- **[Generation Web UI](./src/web/README.md)**: A web interface for configuring, running, and monitoring the docstring *generation* process in real-time.

## Configuration

Before running DocAgent, you **must** create a configuration file named `config/agent_config.yaml`. This file specifies crucial parameters for the agents, such as the LLM endpoints, API keys (if required), model names, and generation settings.

1. **Copy the Example**: An example configuration file is provided at `config/example_config.yaml`. Copy this file to `config/agent_config.yaml`:

   ```bash
   cp config/example_config.yaml config/agent_config.yaml
   ```

2. **Edit the Configuration**: Open `config/agent_config.yaml` in a text editor and modify the settings according to your environment and requirements. Pay close attention to the LLM provider, model selection, and any necessary API credentials.

## Usage

You can run the docstring generation process using either the command line or the web UI.

**1. Command Line Interface (CLI)**

This is the primary method for running the generation process directly.

```bash
# Example: Run on a test repo (remove existing docstrings first if desired)
./tool/remove_docstrings.sh data/raw_test_repo
python generate_docstrings.py --repo-path data/raw_test_repo
```

Use `python generate_docstrings.py --help` to see available options, such as specifying different configurations or test modes.

**2. Generation Web UI**

The web UI provides a graphical interface to configure, run, and monitor the process.

- Note: when entering the repository path, always provide the complete absolute path.

```bash
# Launch the web UI server
python run_web_ui.py --host 0.0.0.0 --port 5000
```

Then, access the UI in your web browser, typically at `http://localhost:5000`. If running the server remotely, you might need to set up SSH tunneling (see instructions below or the [Web UI README](./src/web/README.md)).

*Basic SSH Tunneling (if running server remotely):*

```bash
# In your local terminal
ssh -L 5000:localhost:5000 @
# Then access http://localhost:5000 in your local browser
```

## Running the Evaluation System

DocAgent includes a separate web-based interface for evaluating the quality of generated docstrings.

**1. Running Locally**

To run the evaluation system on your local machine:

```bash
python src/web_eval/app.py
```

Then, access the evaluation UI in your web browser at `http://localhost:5001`.

**2. Running on a Remote Server**

To run the evaluation system on a remote server:

```bash
python src/web_eval/app.py --host 0.0.0.0 --port 5001
```

Then, set up SSH tunneling to access the remote server from your local machine:

```bash
ssh -L 5001:localhost:5001 @
```

Once the tunnel is established, access the evaluation UI in your local web browser at `http://localhost:5001`.

## Optional: Using a Local LLM

If you prefer to use a local LLM (e.g., one hosted via Hugging Face), you can configure DocAgent to interact with it via an API endpoint.

1. **Serve the Local LLM**: Use a tool like `vllm` to serve your model. A convenience script is provided:

   ```bash
   # Ensure vllm is installed: pip install vllm
   bash tool/serve_local_llm.sh
   ```

   This script will likely start an OpenAI-compatible API server (check the script details). Note the URL where the model is served (e.g., `http://localhost:8000/v1`).

2. **Configure DocAgent**: Update your `config/agent_config.yaml` to point to the local LLM API endpoint. You'll typically need to set:
   - The `provider` to `openai` (if using an OpenAI-compatible server like vllm's default).
   - The `api_base` or equivalent URL parameter to your local server address (e.g., `http://localhost:8000/v1`).
   - The `model_name` to the appropriate identifier for your local model.
   - The `api_key` to `None` or an empty string if no key is required by your local server.

3. **Run DocAgent**: Run the generation process as usual (CLI or Web UI). DocAgent will now send requests to your local LLM.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
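As a footnote to the Methodology section, the dependency-first ordering behind hierarchical traversal can be sketched in a few lines. This is only an illustration, not DocAgent's own implementation: it relies on Python's standard `graphlib` module (Python 3.9+), and both the `documentation_order` helper and the dependency map are hypothetical, loosely modeled on the modules in `data/raw_test_repo`.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Hypothetical dependency map (an assumption for illustration):
# each key depends on every module listed in its set.
deps = {
    "models/product.py": set(),
    "payment/payment_processor.py": set(),
    "inventory/inventory_manager.py": {"models/product.py"},
    "vending_machine.py": {
        "models/product.py",
        "inventory/inventory_manager.py",
        "payment/payment_processor.py",
    },
    "example.py": {"models/product.py", "vending_machine.py"},
}


def documentation_order(dependencies):
    """Return components ordered so that each one appears after its dependencies."""
    # static_order() yields a node only after all of its predecessors.
    return list(TopologicalSorter(dependencies).static_order())


order = documentation_order(deps)
for name in order:
    print(name)
```

Components with no dependencies come out first, so by the time a module such as `vending_machine.py` is reached, everything it imports could already carry a docstring. `TopologicalSorter` raises `graphlib.CycleError` on circular imports, so a real pipeline would need some cycle-handling step before sorting.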
================================================ FILE: config/example_config.yaml ================================================ # Example configuration file for DocAgent # Copy this file to agent_config.yaml and add your own API keys # LLM configuration for all agents llm: # Choose ONE of the following LLM provider configurations by uncommenting # Option 1: Claude (Anthropic) type: "claude" api_key: "your-anthropic-api-key-here" model: "claude-3-5-haiku-latest" # Options: claude-3-5-sonnet, claude-3-opus, etc. temperature: 0.1 max_output_tokens: 4096 max_input_tokens: 100000 # Maximum number of tokens for input context # Option 2: OpenAI # type: "openai" # api_key: "your-openai-api-key-here" # model: "gpt-4o" # Options: gpt-4o, gpt-4-turbo, gpt-3.5-turbo, etc. # temperature: 0.1 # max_output_tokens: 4096 # max_input_tokens: 100000 # Option 3: Gemini # type: "gemini" # api_key: "your-gemini-api-key-here" # model: "gemini-1.5-pro" # temperature: 0.1 # max_output_tokens: 4096 # max_input_tokens: 100000 # Option 4: HuggingFace (for local models) # type: "huggingface" # model: "codellama/CodeLlama-34b-Instruct-hf" # api_base: "http://localhost:8000/v1" # Local API endpoint # api_key: "EMPTY" # Can be empty for local models # device: "cuda" # Options: cuda, cpu # torch_dtype: "float16" # temperature: 0.1 # max_output_tokens: 4096 # max_input_tokens: 32000 # Rate limit settings for different LLM providers # These are default values - adjust based on your specific API tier rate_limits: # Claude rate limits claude: requests_per_minute: 50 input_tokens_per_minute: 20000 output_tokens_per_minute: 8000 input_token_price_per_million: 3.0 output_token_price_per_million: 15.0 # OpenAI rate limits openai: requests_per_minute: 500 input_tokens_per_minute: 200000 output_tokens_per_minute: 100000 input_token_price_per_million: 0.15 output_token_price_per_million: 0.60 # Gemini rate limits gemini: requests_per_minute: 60 input_tokens_per_minute: 30000 output_tokens_per_minute: 
10000 input_token_price_per_million: 0.125 output_token_price_per_million: 0.375 # Flow control parameters flow_control: max_reader_search_attempts: 2 # Maximum times reader can call searcher max_verifier_rejections: 1 # Maximum times verifier can reject a docstring status_sleep_time: 1 # Time to sleep between status updates (seconds) # Docstring generation options docstring_options: overwrite_docstrings: false # Whether to overwrite existing docstrings (default: false) # Perplexity API configuration (for web search capability) perplexity: api_key: "your-perplexity-api-key-here" # Replace with your actual Perplexity API key model: "sonar" # Default model temperature: 0.1 max_output_tokens: 250 ================================================ FILE: data/raw_test_repo/README.md ================================================ # Vending Machine Test Repository A comprehensive vending machine implementation in Python that demonstrates various programming concepts, design patterns, and documentation styles. This repository serves as a test bed for docstring generation systems and code documentation analysis. ## Project Structure ``` test_repo_vm/ ├── __init__.py # Main package initialization ├── example.py # Example usage demonstration ├── vending_machine.py # Main vending machine implementation ├── models/ # Data models │ ├── __init__.py │ └── product.py # Product class definition ├── payment/ # Payment processing │ ├── __init__.py │ └── payment_processor.py # Payment-related classes └── inventory/ # Inventory management ├── __init__.py └── inventory_manager.py # Inventory tracking system ``` ## Components ### 1. Product Management (`models/product.py`) - `Product` class with attributes like ID, name, price, quantity, and expiry date - Methods for checking availability and managing stock ### 2. 
Payment Processing (`payment/payment_processor.py`) - Abstract `PaymentMethod` base class for different payment types - `CashPayment` implementation for handling cash transactions - `PaymentTransaction` class for tracking payment status - `PaymentStatus` enum for transaction states ### 3. Inventory Management (`inventory/inventory_manager.py`) - `InventoryManager` class for product storage and retrieval - Slot-based product organization - Stock level tracking - Product availability checking ### 4. Main Vending Machine (`vending_machine.py`) - `VendingMachine` class that coordinates all components - Product selection and purchase workflow - Payment processing and change calculation - Exception handling for error cases ## Code Features This repository demonstrates various Python programming features: 1. **Object-Oriented Design** - Abstract base classes - Inheritance - Encapsulation - Interface definitions 2. **Modern Python Features** - Type hints - Dataclasses - Enums - Optional types - Package organization 3. **Documentation** - Comprehensive docstrings - Type annotations - Code organization - Exception documentation 4. **Best Practices** - SOLID principles - Clean code architecture - Error handling - Modular design ## Usage Example ```python from decimal import Decimal from vending_machine import VendingMachine from models.product import Product # Create a vending machine vm = VendingMachine() # Add products product = Product( id="COLA001", name="Cola Classic", price=1.50, quantity=10, category="drinks" ) vm.inventory.add_product(product, slot=0) # Insert money vm.insert_money(Decimal('2.00')) # Purchase product product, change = vm.purchase_product(slot=0) print(f"Purchased: {product.name}") print(f"Change: ${change:.2f}") ``` ## Running the Example To run the example implementation: ```bash python example.py ``` This will demonstrate: 1. Creating a vending machine 2. Adding products to inventory 3. Displaying available products 4. Making a purchase 5. 
Handling change 6. Updating inventory ## Testing Documentation Generation This repository is structured to test various aspects of documentation generation: 1. **Complex Imports** - Cross-module dependencies - Package-level imports - Relative imports 2. **Documentation Styles** - Function documentation - Class documentation - Module documentation - Package documentation 3. **Code Complexity** - Multiple inheritance - Abstract classes - Type annotations - Exception hierarchies ## Requirements - Python 3.7+ - No external dependencies required ## License This project is open source and available under the MIT License. ================================================ FILE: data/raw_test_repo/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Vending Machine Package A comprehensive vending machine implementation with: - Product management - Inventory tracking - Payment processing - Transaction handling """ ================================================ FILE: data/raw_test_repo/example.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates from decimal import Decimal from datetime import datetime, timedelta from models.product import Item from vending_machine import Sys, SysErr def main(): s = Sys() items = [Item(code='D1', label='Drink1', val=1.5, count=10, grp='d', exp=datetime.now() + timedelta(days=90)), Item(code='S1', label= 'Snack1', val=1.0, count=15, grp='s', exp=datetime.now() + timedelta(days=30)), Item(code='S2', label='Snack2', val=2.0, count =8, grp='s', exp=datetime.now() + timedelta(days=60))] for i, item in enumerate(items): s.store.put(item, i) try: print('Items:') for pos, item in s.ls(): print(f'Pos {pos}: {item.label} - ${item.val:.2f}') pos = 0 print('\nAdding $2.00...') s.add_money(Decimal('2.00')) item, ret = s.buy(pos) print(f'\nBought: {item.label}') if ret: print(f'Return: ${ret:.2f}') print('\nUpdated Items:') for pos, item in s.ls(): print( f'Pos {pos}: {item.label} - ${item.val:.2f} (Count: {item.count})' ) except SysErr as e: print(f'Err: {str(e)}') if __name__ == '__main__': main() ================================================ FILE: data/raw_test_repo/inventory/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """Inventory management package for product stock tracking.""" ================================================ FILE: data/raw_test_repo/inventory/inventory_manager.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates from typing import Dict, List, Optional from ..models.product import Item class Store: def __init__(self, cap: int=20): self.cap = cap self._data: Dict[str, Item] = {} self._map: Dict[int, str] = {} def put(self, obj: Item, pos: Optional[int]=None) ->bool: if obj.code in self._data: curr = self._data[obj.code] curr.count += obj.count return True if pos is not None: if pos < 0 or pos >= self.cap: return False if pos in self._map: return False self._map[pos] = obj.code else: for i in range(self.cap): if i not in self._map: self._map[i] = obj.code break else: return False self._data[obj.code] = obj return True def rm(self, code: str) ->bool: if code not in self._data: return False for k, v in list(self._map.items()): if v == code: del self._map[k] del self._data[code] return True def get(self, code: str) ->Optional[Item]: return self._data.get(code) def get_at(self, pos: int) ->Optional[Item]: if pos not in self._map: return None code = self._map[pos] return self._data.get(code) def ls(self) ->List[Item]: return [obj for obj in self._data.values() if obj.check()] def find(self, code: str) ->Optional[int]: for k, v in self._map.items(): if v == code: return k return None ================================================ FILE: data/raw_test_repo/models/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """Models package for data structures used in the vending machine.""" ================================================ FILE: data/raw_test_repo/models/product.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass from typing import Optional from datetime import datetime @dataclass class Item: """ Summary: Represents an item with associated attributes for tracking and management in various contexts. Description: This class serves as a blueprint for creating items that can be tracked and managed within a system. 
Each item has attributes such as a unique code, a label, a value, a count, an optional expiration date, and a group classification. The primary motivation behind this class is to facilitate resource management, inventory tracking, or any scenario where items need to be monitored for validity and availability. Use this class when you need to represent items that may have a limited lifespan or quantity, such as in inventory systems, gaming resources, or token management. It provides methods to check the validity of an item and to modify its count, ensuring that operations on the item are safe and consistent. The class fits into larger systems by allowing for easy integration with resource management workflows, enabling developers to track item states and manage their lifecycle effectively. Example: ```python from datetime import datetime, timedelta # Create an item with a specific expiration date item = Item(code='A123', label='Sample Item', val=10.0, count=5, exp=datetime.now() + timedelta(days=1)) # Check if the item is valid is_valid = item.check() # Returns True if count > 0 and not expired # Modify the count of the item item.mod(2) # Decreases count by 2, returns True ``` Parameters: - code (str): A unique identifier for the item. - label (str): A descriptive name for the item. - val (float): The value associated with the item, representing its worth. - count (int): The quantity of the item available. Must be a non-negative integer. - exp (Optional[datetime]): An optional expiration date for the item. If set, the item will be considered invalid after this date. - grp (str): A classification group for the item, defaulting to 'misc'. Attributes: - code (str): The unique identifier for the item. - label (str): The name or description of the item. - val (float): The monetary or functional value of the item. - count (int): The current quantity of the item available, must be non-negative. - exp (Optional[datetime]): The expiration date of the item, if applicable. 
- grp (str): The group classification of the item, useful for categorization. """ code: str label: str val: float count: int exp: Optional[datetime] = None grp: str = 'misc' def check(self) -> bool: """ Validates the current object's state based on count and expiration. Checks whether the object is still valid by verifying two key conditions: 1. The object's count is greater than zero 2. The object has not exceeded its expiration timestamp This method is typically used to determine if an object is still usable or has become stale/invalid. It provides a quick state validation check that can be used in resource management, token validation, or lifecycle tracking scenarios. Returns: bool: True if the object is valid (count > 0 and not expired), False otherwise. """ if self.count <= 0: return False if self.exp and datetime.now() > self.exp: return False return True def mod(self, n: int=1) -> bool: """ Summary: Determines if the current count can be decremented by a specified value. Description: This method checks if the `count` attribute is greater than or equal to the provided integer `n`. If so, it decrements `count` by `n` and returns `True`. If `count` is less than `n`, it returns `False`, indicating that the operation could not be performed. Use this function when managing resources or operations that require a controlled decrement of a count, ensuring that the count does not drop below zero. This is particularly useful in scenarios such as resource allocation, gaming mechanics, or iterative processes. The method is integral to classes that require precise control over a count, allowing for safe decrements while maintaining the integrity of the count value. Args: n (int, optional): The value to decrement from `count`. Must be a positive integer that does not exceed the current `count`. Default is 1. Returns: bool: Returns `True` if the decrement was successful (i.e., `count` was greater than or equal to `n`), otherwise returns `False`. 
        Raises:
            No exceptions are raised by this method. The method does not validate `n`, so callers must ensure that `n` is a positive integer; a zero or negative `n` is accepted and leaves `count` unchanged or increases it.

        Examples:
            ```python
            obj = Item(code='A123', label='Sample Item', val=10.0, count=5)
            result = obj.mod(2)   # result will be True, obj.count will be 3
            result = obj.mod(4)   # result will be False, obj.count remains 3
            result = obj.mod(0)   # result will be True, obj.count remains 3 (n is not validated)
            result = obj.mod(-1)  # result will be True, obj.count becomes 4 (a negative n increases count)
            ```
        """
        if self.count >= n:
            self.count -= n
            return True
        return False


================================================
FILE: data/raw_test_repo/payment/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates
"""Payment processing package for handling different payment methods."""


================================================
FILE: data/raw_test_repo/payment/payment_processor.py
================================================
# Copyright (c) Meta Platforms, Inc.
and affiliates from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from typing import Optional from decimal import Decimal class TxStatus(Enum): WAIT = 'pending' DONE = 'completed' ERR = 'failed' RET = 'refunded' @dataclass class Tx: id: str amt: Decimal st: TxStatus mth: str msg: Optional[str] = None class Handler(ABC): @abstractmethod def proc(self, amt: Decimal) ->Tx: pass @abstractmethod def rev(self, tx: Tx) ->bool: pass class Cash(Handler): def __init__(self): self.bal: Decimal = Decimal('0.00') def add(self, amt: Decimal) ->None: self.bal += amt def proc(self, amt: Decimal) ->Tx: if self.bal >= amt: self.bal -= amt return Tx(id=f'C_{id(self)}', amt=amt, st=TxStatus.DONE, mth='cash' ) return Tx(id=f'C_{id(self)}', amt=amt, st=TxStatus.ERR, mth='cash', msg='insufficient') def rev(self, tx: Tx) ->bool: if tx.st == TxStatus.DONE: self.bal += tx.amt tx.st = TxStatus.RET return True return False def ret(self) ->Decimal: tmp = self.bal self.bal = Decimal('0.00') return tmp ================================================ FILE: data/raw_test_repo/vending_machine.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates from decimal import Decimal from typing import Optional, List, Tuple from .models.product import Item from .payment.payment_processor import Handler, Tx, TxStatus, Cash from .inventory.inventory_manager import Store class SysErr(Exception): pass class Sys: def __init__(self, h: Optional[Handler]=None): self.store = Store() self.h = h or Cash() self._tx: Optional[Tx] = None def ls(self) ->List[Tuple[int, Item]]: items = [] for item in self.store.ls(): pos = self.store.find(item.code) if pos is not None: items.append((pos, item)) return sorted(items, key=lambda x: x[0]) def pick(self, pos: int) ->Optional[Item]: item = self.store.get_at(pos) if not item: raise SysErr('invalid pos') if not item.check(): raise SysErr('unavailable') return item def add_money(self, amt: Decimal) ->None: if not isinstance(self.h, Cash): raise SysErr('cash not supported') self.h.add(amt) def buy(self, pos: int) ->Tuple[Item, Optional[Decimal]]: item = self.pick(pos) tx = self.h.proc(Decimal(str(item.val))) self._tx = tx if tx.st != TxStatus.DONE: raise SysErr(tx.msg or 'tx failed') if not item.mod(): self.h.rev(tx) raise SysErr('dispense failed') ret = None if isinstance(self.h, Cash): ret = self.h.ret() return item, ret def cancel(self) ->Optional[Decimal]: if not self._tx: raise SysErr('no tx') ok = self.h.rev(self._tx) if not ok: raise SysErr('rev failed') ret = None if isinstance(self.h, Cash): ret = self.h.ret() self._tx = None return ret ================================================ FILE: data/raw_test_repo_simple/helper.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates class HelperClass: """ Represents a utility for managing and processing data. The `HelperClass` is designed to facilitate data processing tasks by leveraging the `DataProcessor` class. It serves as an intermediary that manages the workflow of data processing, making it easier to handle data updates and retrievals within a system. 
This class is particularly useful in scenarios where data needs to be processed and accessed in a structured manner. The `HelperClass` fits into the larger system architecture as a component that coordinates data processing tasks. It achieves its purpose by using the `DataProcessor` to perform the actual data processing and then managing the processed data internally. Example: # Initialize the HelperClass helper = HelperClass() # Process data using the helper helper.process_data() # Retrieve the processed data result result = helper.get_result() print(result) # Output: '[1, 2, 3]' Attributes: data (list): Stores the processed data, initially an empty list. """ def __init__(self): self.data = [] def process_data(self): """ Processes and updates the internal data. This method orchestrates the data processing workflow by invoking the `DataProcessor.process()` method to perform the main data processing task. It then calls `_internal_process()` to finalize the processing and update the internal `data` attribute. Use this method when you need to refresh or initialize the data within the `HelperClass` instance. Returns: None: This method updates the internal state and does not return a value. """ self.data = DataProcessor.process() self._internal_process() def _internal_process(self): """ No docstring provided. """ return self.data def get_result(self): """ No docstring provided. """ return str(self.data) class DataProcessor: ''' """Handles basic data processing tasks within a system. This class is designed to perform simple data processing operations, providing utility methods that can be used in various scenarios where basic data manipulation is required. It is particularly useful in contexts where a straightforward list of integers is needed for further processing or testing. The `DataProcessor` class fits into the larger system architecture as a utility component, offering static and internal methods to handle specific processing tasks. 
It achieves its purpose by providing a static method for general use and an internal method for class-specific operations. Example: # Initialize the DataProcessor class processor = DataProcessor() # Use the static method to process data result = DataProcessor.process() print(result) # Output: [1, 2, 3] # Use the internal method for internal processing internal_result = processor._internal_process() print(internal_result) # Output: 'processed' """ ''' @staticmethod def process(): ''' """Processes data and returns a list of integers. This static method is designed to perform a basic data processing task and return a predefined list of integers. It can be used whenever a simple list of integers is required for further operations or testing purposes. Returns: list of int: A list containing the integers [1, 2, 3]. """ ''' return [1, 2, 3] def _internal_process(self): ''' """Processes internal data and returns a status message. This method is used internally within the `DataProcessor` class to perform specific data processing tasks that are not exposed publicly. It is typically called by other methods within the class to handle intermediate processing steps. Returns: str: A string indicating the processing status, specifically 'processed'. """ ''' return 'processed' ================================================ FILE: data/raw_test_repo_simple/inner/inner_functions.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates def inner_function(): """ Returns a greeting message from an inner function. This function is designed to return a simple greeting message, which can be used in nested or internal function calls to verify execution flow or for debugging purposes. It is typically used in development environments where confirming the execution of specific code paths is necessary. Returns: str: A greeting message stating 'Hello from inner function!' 
Example: >>> message = inner_function() >>> print(message) 'Hello from inner function!' """ return 'Hello from inner function!' def get_random_quote(): """ Fetches a predefined inspirational quote. This function is designed to provide users with a motivational quote, which can be used in applications that aim to inspire or uplift users. It is particularly useful in scenarios where a quick, positive message is needed to enhance user experience. Returns: str: A quote string stating 'The best way to predict the future is to create it.' Example: >>> quote = get_random_quote() >>> print(quote) 'The best way to predict the future is to create it.' """ return 'The best way to predict the future is to create it.' def generate_timestamp(): """ Generates and returns a static timestamp. This function provides a hardcoded timestamp string, which can be used in scenarios where a consistent and predictable timestamp is required for testing or logging purposes. It fits into workflows where a fixed date and time representation is needed without relying on the current system time. Returns: str: A string representing the static timestamp '2023-05-15 14:30:22'. """ return '2023-05-15 14:30:22' def get_system_status(): """ Provides a static message indicating the operational status of systems. This function is used to retrieve a fixed status message that confirms all systems are functioning correctly. It is useful in monitoring dashboards or status pages where a quick confirmation of system health is required. Returns: str: A status message stating 'All systems operational.' Example: >>> status = get_system_status() >>> print(status) 'All systems operational' """ return 'All systems operational' def fetch_user_message(): ''' """Fetches a predefined user message indicating notifications. This function is used to retrieve a static message that informs the user about the number of notifications they have. 
It is typically used in scenarios where a quick status update is needed for user engagement. Returns: str: A message string stating 'Welcome back! You have 3 notifications.' Example: >>> message = fetch_user_message() >>> print(message) 'Welcome back! You have 3 notifications.' """ ''' return 'Welcome back! You have 3 notifications.' ================================================ FILE: data/raw_test_repo_simple/main.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from helper import HelperClass from inner.inner_functions import inner_function, get_random_quote, generate_timestamp, get_system_status, fetch_user_message def main_function(): """ Executes data processing and utility operations, returning the processed data as a string. This function initializes a `HelperClass` instance to manage and process data, invokes a utility function to provide a placeholder value, and generates a static timestamp for consistency in logging or testing scenarios. The function is useful when a complete data processing sequence is needed, integrating utility operations to produce a final result. Returns: str: The processed data result as a string, derived from the `HelperClass` instance after executing the data processing and utility functions. Example: # Execute the main function to process data and retrieve the result result = main_function() print(result) # Output: '[1, 2, 3]' """ helper = HelperClass() helper.process_data() utility_function() generate_timestamp() return helper.get_result() def utility_function(): """ Returns a utility string. This function provides a simple utility string, which can be used in various contexts where a placeholder or a generic return value is needed. It is typically used within workflows that require a consistent return value for testing or demonstration purposes. Returns: str: The string 'utility', serving as a generic utility value. 
""" return 'utility' ================================================ FILE: data/raw_test_repo_simple/processor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from helper import HelperClass from processor import DataProcessor from main import utility_function class AdvancedProcessor: """ Facilitates advanced data processing by coordinating multiple processing components. The `AdvancedProcessor` class is designed to manage and execute complex data processing workflows by integrating the functionalities of `HelperClass` and `DataProcessor`. It is ideal for scenarios where a comprehensive processing sequence is needed, providing a streamlined approach to handle data operations and produce a final result. This class fits into the larger system architecture as a high-level orchestrator of data processing tasks, ensuring that each component's capabilities are effectively utilized to achieve the desired outcome. Example: # Initialize the AdvancedProcessor processor = AdvancedProcessor() # Execute the processing workflow result = processor.run() print(result) # Output: 'utility' Attributes: helper (HelperClass): An instance of `HelperClass` used to manage data processing tasks. data_processor (DataProcessor): An instance of `DataProcessor` used to perform specific data processing operations. """ def __init__(self): self.helper = HelperClass() self.data_processor = DataProcessor() def run(self): """ Executes the complete data processing workflow and returns the result. This method coordinates the data processing tasks by utilizing both the `HelperClass` and `DataProcessor` to perform necessary operations. It is designed to be used when a full processing sequence is required, culminating in a final result that indicates the completion of these tasks. Returns: str: The result of the processing workflow, typically a utility string indicating successful completion. 
Example: # Create an instance of AdvancedProcessor processor = AdvancedProcessor() # Run the processing workflow result = processor.run() print(result) # Output: 'utility' """ self.helper.process_data() self.data_processor._internal_process() return self.process_result() def process_result(self): """ Returns a utility string as the result of processing. This method is part of the `AdvancedProcessor` class workflow, providing a consistent utility value after processing operations. It is typically used when a placeholder or generic result is needed following the execution of data processing tasks within the class. Returns: str: The string 'utility', serving as a generic utility value to indicate the completion of processing tasks. """ return utility_function() ================================================ FILE: data/raw_test_repo_simple/test_file.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates def test_function(): """ Returns a boolean value indicating a successful test condition. This function is typically used in scenarios where a simple, consistent boolean value is required to represent a successful outcome or condition. It can be integrated into workflows that need a straightforward pass/fail indicator for testing or validation purposes. Returns: bool: The boolean value `True`, indicating a successful or positive condition. """ return True ================================================ FILE: eval_completeness.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import ast import os from typing import Dict, Any, List, Union from pathlib import Path from evaluator.completeness import ClassCompletenessEvaluator, FunctionCompletenessEvaluator from tabulate import tabulate def run_docstring_tests(source_file: str) -> Dict[str, Any]: """ Run comprehensive docstring evaluation tests on a Python source file. 
    This function reads a Python file and evaluates docstrings for all classes, functions, and methods found within. It provides detailed evaluation results using separate evaluators for classes and for functions/methods.

    Args:
        source_file: Path to the Python file to analyze

    Returns:
        Dictionary containing evaluation results for each found element

    Example:
        >>> results = run_docstring_tests('my_module.py')
        >>> print(results['functions'][0]['completeness_score'])
        1.0
    """
    with open(source_file, 'r', encoding='utf-8') as f:
        source = f.read()

    try:
        tree = ast.parse(source)
    except SyntaxError as e:
        return {
            'status': 'error',
            'message': f'Failed to parse {source_file}: {str(e)}'
        }

    results = {
        'status': 'success',
        'file': source_file,
        'classes': [],
        'functions': [],
        'debug_info': {}
    }

    # Instantiate evaluators
    class_evaluator = ClassCompletenessEvaluator()
    func_evaluator = FunctionCompletenessEvaluator()

    # Process all nodes in the AST
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, ast.ClassDef):
            class_result = {
                'name': node.name,
                'type': 'class',
                'completeness_score': class_evaluator.evaluate(node),
                'completeness_elements': class_evaluator.element_scores,
                'element_required': class_evaluator.element_required
            }
            results['classes'].append(class_result)

            # Evaluate methods within the class
            for method in [n for n in ast.iter_child_nodes(node) if isinstance(n, ast.FunctionDef)]:
                # Skip __init__ methods
                if method.name == '__init__':
                    continue

                method_result = {
                    'name': f"{node.name}.{method.name}",
                    'type': 'method',
                    'completeness_score': func_evaluator.evaluate(method),
                    'completeness_elements': func_evaluator.element_scores,
                    'element_required': func_evaluator.element_required
                }
                results['functions'].append(method_result)

        elif isinstance(node, ast.FunctionDef):  # Only process top-level functions
            func_result = {
                'name': node.name,
                'type': 'function',
                'completeness_score': func_evaluator.evaluate(node),
                'completeness_elements': func_evaluator.element_scores,
                'element_required': func_evaluator.element_required
            }
            results['functions'].append(func_result)

    # Add overall statistics
    results['statistics'] = {
        'total_classes': len(results['classes']),
        'total_functions': len(results['functions']),
        'average_class_score': sum(r['completeness_score'] for r in results['classes']) / max(1, len(results['classes'])),
        'average_function_score': sum(r['completeness_score'] for r in results['functions']) / max(1, len(results['functions']))
    }

    return results


def process_directory(directory_path: str) -> Dict[str, Any]:
    """
    Process all Python files in a directory and its subdirectories.

    Args:
        directory_path: Path to the directory to analyze

    Returns:
        Dictionary containing aggregated evaluation results for all files
    """
    directory = Path(directory_path)

    # Initialize aggregate results
    aggregate_results = {
        'status': 'success',
        'directory': str(directory),
        'files': [],
        'file_results': [],
        'classes': [],
        'functions': [],
        'statistics': {
            'total_files': 0,
            'successful_files': 0,
            'failed_files': 0,
            'total_classes': 0,
            'total_functions': 0,
            'average_class_score': 0.0,
            'average_function_score': 0.0,
            'overall_average_score': 0.0
        }
    }

    # Find all Python files recursively
    python_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.py'):
                python_files.append(os.path.join(root, file))

    if not python_files:
        aggregate_results['status'] = 'error'
        aggregate_results['message'] = f'No Python files found in {directory_path}'
        return aggregate_results

    aggregate_results['statistics']['total_files'] = len(python_files)

    # Process each Python file
    all_class_scores = []
    all_function_scores = []

    for py_file in python_files:
        file_result = run_docstring_tests(py_file)

        if file_result['status'] == 'success':
            # Count the file under 'statistics', which is what the summary printer reads
            aggregate_results['statistics']['successful_files'] += 1
            aggregate_results['file_results'].append(file_result)
            aggregate_results['files'].append(py_file)

            # Accumulate classes and functions with file path context
            for class_result in
file_result['classes']: class_result['file'] = py_file aggregate_results['classes'].append(class_result) all_class_scores.append(class_result['completeness_score']) for func_result in file_result['functions']: func_result['file'] = py_file aggregate_results['functions'].append(func_result) all_function_scores.append(func_result['completeness_score']) # Update statistics aggregate_results['statistics']['total_classes'] += file_result['statistics']['total_classes'] aggregate_results['statistics']['total_functions'] += file_result['statistics']['total_functions'] else: aggregate_results['statistics']['failed_files'] += 1 # Calculate average scores if all_class_scores: aggregate_results['statistics']['average_class_score'] = sum(all_class_scores) / len(all_class_scores) if all_function_scores: aggregate_results['statistics']['average_function_score'] = sum(all_function_scores) / len(all_function_scores) # Calculate overall average score (classes and functions combined) all_scores = all_class_scores + all_function_scores if all_scores: aggregate_results['statistics']['overall_average_score'] = sum(all_scores) / len(all_scores) return aggregate_results def print_evaluation_results(results: Dict[str, Any]) -> None: """ Pretty print the evaluation results in a readable format with colors. 
Args: results: Dictionary containing evaluation results from run_docstring_tests """ # ANSI color codes GREEN = '\033[92m' RED = '\033[91m' BLUE = '\033[94m' YELLOW = '\033[93m' BOLD = '\033[1m' ENDC = '\033[0m' # Check if this is a directory result or a file result is_directory = 'directory' in results if is_directory: # Print directory path print(f"\n{BOLD}Evaluating Python files in directory: {results['directory']}{ENDC}") print("=" * 80) # Print file summary print(f"\n{BLUE}{BOLD}FILE SUMMARY:{ENDC}") stats_data = [ ['Total Files', results['statistics']['total_files']], ['Successfully Processed Files', results['statistics']['successful_files']], ['Failed Files', results['statistics']['failed_files']] ] print(tabulate(stats_data, tablefmt='simple')) # Print overall statistics print(f"\n{BLUE}{BOLD}OVERALL STATISTICS:{ENDC}") # Add colored statistics class_score = results['statistics']['average_class_score'] if class_score >= 0.8: class_score_str = f"{GREEN}{class_score:.2f}{ENDC}" elif class_score >= 0.5: class_score_str = f"{YELLOW}{class_score:.2f}{ENDC}" else: class_score_str = f"{RED}{class_score:.2f}{ENDC}" func_score = results['statistics']['average_function_score'] if func_score >= 0.8: func_score_str = f"{GREEN}{func_score:.2f}{ENDC}" elif func_score >= 0.5: func_score_str = f"{YELLOW}{func_score:.2f}{ENDC}" else: func_score_str = f"{RED}{func_score:.2f}{ENDC}" overall_score = results['statistics']['overall_average_score'] if overall_score >= 0.8: overall_score_str = f"{GREEN}{overall_score:.2f}{ENDC}" elif overall_score >= 0.5: overall_score_str = f"{YELLOW}{overall_score:.2f}{ENDC}" else: overall_score_str = f"{RED}{overall_score:.2f}{ENDC}" stats_data = [ ['Total Classes', results['statistics']['total_classes']], ['Total Functions/Methods', results['statistics']['total_functions']], ['Average Class Score', class_score_str], ['Average Function Score', func_score_str], ['Overall Average Score', overall_score_str] ] print(tabulate(stats_data, 
tablefmt='simple')) # Ask if the user wants to see details for individual files print(f"\nUse python {os.path.basename(__file__)} to see detailed results for a specific file.") else: # Original single file display logic # Print file path print(f"\n{BOLD}Evaluating Python file: {results['file']}{ENDC}") print("=" * 80) # Print class results table if results['classes']: print(f"\n{BLUE}{BOLD}CLASSES:{ENDC}") headers = ['Class Name', 'Score'] elements = list(results['classes'][0]['completeness_elements'].keys()) headers.extend(elements) table_data = [] for class_result in results['classes']: row = [class_result['name']] score = class_result['completeness_score'] # Color the score based on value if score >= 0.8: score_str = f"{GREEN}{score:.2f}{ENDC}" elif score >= 0.5: score_str = f"{YELLOW}{score:.2f}{ENDC}" else: score_str = f"{RED}{score:.2f}{ENDC}" row.append(score_str) for element in elements: required = class_result['element_required'][element] has_element = class_result['completeness_elements'][element] if has_element: check = f"{GREEN}✓{ENDC}" else: check = f"{RED}✗{ENDC}" cell = f"{YELLOW if required else '-'}{'R' if required else ''}{ENDC if required else ''} | {check}" row.append(cell) table_data.append(row) print(tabulate(table_data, headers=headers, tablefmt='grid')) # Print function/method results table if results['functions']: print(f"\n{BLUE}{BOLD}FUNCTIONS/METHODS:{ENDC}") headers = ['Function Name', 'Type', 'Score'] elements = list(results['functions'][0]['completeness_elements'].keys()) headers.extend(elements) table_data = [] for func_result in results['functions']: row = [func_result['name'], func_result['type']] score = func_result['completeness_score'] # Color the score based on value if score >= 0.8: score_str = f"{GREEN}{score:.2f}{ENDC}" elif score >= 0.5: score_str = f"{YELLOW}{score:.2f}{ENDC}" else: score_str = f"{RED}{score:.2f}{ENDC}" row.append(score_str) for element in elements: required = func_result['element_required'][element] 
has_element = func_result['completeness_elements'][element] if has_element: check = f"{GREEN}✓{ENDC}" else: check = f"{RED}✗{ENDC}" cell = f"{YELLOW if required else '-'}{'R' if required else ''}{ENDC if required else ''} | {check}" row.append(cell) table_data.append(row) print(tabulate(table_data, headers=headers, tablefmt='grid')) # Print overall statistics print(f"\n{BLUE}{BOLD}OVERALL STATISTICS:{ENDC}") stats_data = [] # Add colored statistics class_score = results['statistics']['average_class_score'] if class_score >= 0.8: class_score_str = f"{GREEN}{class_score:.2f}{ENDC}" elif class_score >= 0.5: class_score_str = f"{YELLOW}{class_score:.2f}{ENDC}" else: class_score_str = f"{RED}{class_score:.2f}{ENDC}" func_score = results['statistics']['average_function_score'] if func_score >= 0.8: func_score_str = f"{GREEN}{func_score:.2f}{ENDC}" elif func_score >= 0.5: func_score_str = f"{YELLOW}{func_score:.2f}{ENDC}" else: func_score_str = f"{RED}{func_score:.2f}{ENDC}" stats_data = [ ['Total Classes', results['statistics']['total_classes']], ['Total Functions/Methods', results['statistics']['total_functions']], ['Average Class Score', class_score_str], ['Average Function Score', func_score_str] ] print(tabulate(stats_data, tablefmt='simple')) if __name__ == "__main__": # Example usage import sys if len(sys.argv) < 2: print("Usage: python eval_completeness.py ") sys.exit(1) path = sys.argv[1] if not Path(path).exists(): print(f"Error: Path not found: {path}") sys.exit(1) if Path(path).is_dir(): # Process directory results = process_directory(path) if results['status'] == 'success': print_evaluation_results(results) else: print(f"Error: {results['message']}") else: # Process single file results = run_docstring_tests(path) if results['status'] == 'success': print_evaluation_results(results) else: print(f"Error: {results['message']}") ================================================ FILE: generate_docstrings.py ================================================ 
#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates """ Docstring Generator with Dependency-Based Ordering This script generates docstrings for Python code components (functions, classes, methods) using a DFS-based approach that starts from components with no dependencies. Key features: 1. Parses Python code to identify components and their dependencies 2. Builds a dependency graph where A→B means "A depends on B" 3. Performs DFS traversal starting from components with no dependencies 4. Processes dependencies before the components that depend on them 5. Ensures classes depend on their methods, not vice versa 6. Skips __init__ methods as they typically don't need separate docstrings 7. Provides visual representation of progress in the terminal Usage: python generate_docstrings.py --repo-path PATH --config-path PATH [--test-mode] """ import os import sys import time import ast import json import argparse import logging import random from pathlib import Path from typing import Dict, List, Set, Optional, Any from collections import defaultdict import tiktoken # Add this import for token counting # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger("docstring_generator") # Import dependency analyzer modules from src.dependency_analyzer import ( CodeComponent, DependencyParser, dependency_first_dfs, build_graph_from_components ) from src.visualizer import ProgressVisualizer from src.agent.orchestrator import Orchestrator def generate_test_docstring(component: CodeComponent) -> str: """ Generate a placeholder docstring for test mode. Args: component: The code component to generate a placeholder docstring for. Returns: A placeholder docstring based on the component type. 
""" comp_type = component.component_type name = component.id.split(".")[-1] if comp_type == "function": return f""" Test docstring for function '{name}'. This is a placeholder docstring generated in test mode. In normal mode, this would be replaced with an AI-generated docstring. Args: arg1: Description of first argument arg2: Description of second argument Returns: Description of return value """ elif comp_type == "class": return f""" Test docstring for class '{name}'. This is a placeholder docstring generated in test mode. In normal mode, this would be replaced with an AI-generated docstring. Attributes: attr1: Description of first attribute attr2: Description of second attribute """ elif comp_type == "method": class_name = component.id.split(".")[-2] return f""" Test docstring for method '{name}' in class '{class_name}'. This is a placeholder docstring generated in test mode. In normal mode, this would be replaced with an AI-generated docstring. Args: arg1: Description of first argument arg2: Description of second argument Returns: Description of return value """ else: return f""" Test docstring for {comp_type} '{name}'. This is a placeholder docstring generated in test mode. """ def generate_docstring_for_component(component: CodeComponent, orchestrator: Optional[Orchestrator], test_mode: str = 'none', dependency_graph: Optional[Dict[str, List[str]]] = None) -> str: """ Generate a docstring for a single component. Args: component: The component to generate a docstring for. orchestrator: The orchestrator instance. test_mode: The test mode to use. dependency_graph: Optional dependency graph. Returns: The generated docstring. 
""" # do not use try/except here, we want to fail if there is an error if not orchestrator: return "" file_path = component.file_path # Get the component code component_code = component.source_code # Estimate token count of the focal component encoding = tiktoken.get_encoding("cl100k_base") # Default OpenAI encoding token_consume_focal = len(encoding.encode(component_code)) # Skip if the component is too large (> 10000 tokens) if token_consume_focal > 10000: # truncate the component code to 10000 tokens component_code = encoding.decode(encoding.encode(component_code)[:10000]) # Parse the file with open(file_path, "r", encoding="utf-8") as f: file_content = f.read() ast_tree = ast.parse(file_content) ast_node = None # Locate the AST node for the component component_parts = component.id.split(".") component_name = component_parts[-1] if component.component_type == "function": # Find top-level function for node in ast.iter_child_nodes(ast_tree): if (isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == component_name): ast_node = node break elif component.component_type == "class": # Find class for node in ast.iter_child_nodes(ast_tree): if isinstance(node, ast.ClassDef) and node.name == component_name: ast_node = node break elif component.component_type == "method": # Find method inside class class_name, method_name = component_parts[-2:] for node in ast.iter_child_nodes(ast_tree): if isinstance(node, ast.ClassDef) and node.name == class_name: for item in node.body: if (isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name): ast_node = item break break try: # Pass component.id as the focal_node_dependency_path docstring = orchestrator.process( focal_component=component_code, file_path=file_path, ast_node=ast_node, ast_tree=ast_tree, dependency_graph=dependency_graph, focal_node_dependency_path=component.id, token_consume_focal=token_consume_focal # Pass token count to orchestrator ) return docstring except 
Exception as e: print(f"Error generating docstring for {component.id}: {str(e)}") return "" def set_docstring_in_file(file_path: str, component: CodeComponent, docstring: str) -> bool: """ Update a Python file with a newly generated docstring for a component. Args: file_path: Path to the file to update. component: The component to update with a docstring. docstring: The docstring to insert. Returns: True if successful, False otherwise. """ # Do not use Try/Except here, we want to fail if there is an error # Read the file with open(file_path, "r", encoding="utf-8") as f: source = f.read() # Parse the file tree = ast.parse(source) # Find the component in the parsed AST component_node = None component_parts = component.id.split(".") component_name = component_parts[-1] if component.component_type == "function": # Find top-level function for node in ast.iter_child_nodes(tree): if (isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == component_name): component_node = node break elif component.component_type == "class": # Find class for node in ast.iter_child_nodes(tree): if isinstance(node, ast.ClassDef) and node.name == component_name: component_node = node break elif component.component_type == "method": # Find method inside class class_name, method_name = component_parts[-2:] for node in ast.iter_child_nodes(tree): if isinstance(node, ast.ClassDef) and node.name == class_name: for item in node.body: if (isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name): component_node = item break break if not component_node: logger.error(f"Could not find component {component.id} in {file_path}") return False # Set the docstring set_node_docstring(component_node, docstring) # Unparse the AST back to source code if hasattr(ast, "unparse"): new_source = ast.unparse(tree) else: try: import astor new_source = astor.to_source(tree) except ImportError: logger.error( "Error: You need to install 'astor' or use Python 3.9+ to unparse 
the AST. " f"Skipping file: {file_path}" ) return False # Write back to the file with open(file_path, "w", encoding="utf-8") as f: f.write(new_source) return True def set_node_docstring(node: ast.AST, docstring: str): """ Safely set or update the docstring on an AST node (ClassDef, FunctionDef, etc.). Also adjusts indentation relative to the node's existing indentation level, ensuring both the opening and closing triple quotes are properly aligned. Args: node: The AST node to modify (ClassDef, FunctionDef, etc.). docstring: The new docstring (as a plain string) to insert. """ import textwrap # 1. Strip leading/trailing empty lines in the provided docstring # to avoid spurious blank lines. stripped_docstring = docstring.strip('\n') if not stripped_docstring: # If empty or all whitespace, provide a placeholder stripped_docstring = "No docstring provided." # 2. Dedent possible indentation in docstring (so a multiline docstring # doesn't carry undesired left margins). dedented = textwrap.dedent(stripped_docstring) # 3. Determine how many spaces to indent for doc lines plus triple quotes. existing_indent = getattr(node, 'col_offset', 0) doc_indent_str = ' ' * (existing_indent + 4) # 4. Build the final string: # - Start with a newline (so triple quotes appear on a new line). # - Indent all docstring lines. # - End with a newline+same indentation (so the closing triple quotes # line also has the doc_indent_str). prepared_docstring = ( "\n" + textwrap.indent(dedented, doc_indent_str) + "\n" + doc_indent_str ) # 5. Create an AST Expr node to store this docstring as a constant. docstring_node = ast.Expr(value=ast.Constant(value=prepared_docstring, kind=None)) # If there's no body, just make one containing our new docstring. if not hasattr(node, "body") or not isinstance(node.body, list) or len(node.body) == 0: node.body = [docstring_node] else: # If the first statement is an existing docstring, replace it; # otherwise, insert the new docstring as the first statement. 
first_stmt = node.body[0] if ( isinstance(first_stmt, ast.Expr) and isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str) ): node.body[0] = docstring_node else: node.body.insert(0, docstring_node) def main(): """ Main entry point for the docstring generation script with flexible component ordering. The script supports different ordering modes through the --order-mode flag: - 'topo' (default): Dependency-based ordering using a DFS-based approach: 1. If A depends on B, the graph has an edge A→B (meaning "A depends on B") 2. Root nodes (nodes with no dependencies) are processed first 3. Dependencies are always processed before the components that depend on them 4. This ensures proper docstring generation order - 'random_node': Randomly shuffles all Python components, ignoring dependencies - 'random_file': Processes files in random order, but preserves component order within files Class methods are processed before the classes that depend on them (not vice versa) in 'topo' mode, ensuring proper docstring generation order. Special __init__ methods are skipped as they typically don't need separate docstrings. The script provides options to skip or overwrite existing docstrings: - By default, components with existing docstrings are skipped - With --overwrite-docstrings flag, existing docstrings will be overwritten - This behavior can also be configured in the config.yaml file under docstring_options.overwrite_docstrings Web interface integration: - With --enable-web flag, the script enables integration with the web UI - This allows visualization of the docstring generation process in a web browser - Run the web UI separately using the run_web_ui.py script """ # Parse command line arguments parser = argparse.ArgumentParser( description='Generate docstrings for Python components in dependency order.' 
) parser.add_argument( '--repo-path', type=str, default='data/raw_test_repo', help='Path to the repository (default: data/raw_test_repo)' ) parser.add_argument( '--config-path', type=str, default='config/agent_config.yaml', help='Path to the configuration file (default: config/agent_config.yaml)' ) parser.add_argument( '--test-mode', type=str, choices=['placeholder', 'context_print', 'none'], default='none', help='Test mode to run: "placeholder" for placeholder docstrings (no LLM calls), "context_print" to print context before writer calls, "none" for normal operation' ) parser.add_argument( '--order-mode', type=str, choices=['topo', 'random_node', 'random_file'], default='topo', help='Order mode for docstring generation: "topo" follows dependency order (default), "random_node" selects random Python nodes, "random_file" processes files in random order' ) parser.add_argument( '--enable-web', action='store_true', help='Enable integration with the web interface' ) parser.add_argument( '--overwrite-docstrings', action='store_true', help='Overwrite existing docstrings instead of skipping them (default: False)' ) args = parser.parse_args() repo_path = args.repo_path config_path = args.config_path test_mode = args.test_mode order_mode = args.order_mode overwrite_docstrings = args.overwrite_docstrings # Create output directory for dependency graph output_dir = os.path.join("output", "dependency_graphs") os.makedirs(output_dir, exist_ok=True) # Extract repository name from path for creating a unique filename repo_name = os.path.basename(os.path.normpath(repo_path)) # Create a sanitized version of the repo name (remove special characters) sanitized_repo_name = ''.join(c if c.isalnum() else '_' for c in repo_name) dependency_graph_path = os.path.join(output_dir, f"{sanitized_repo_name}_dependency_graph.json") # Initialize the orchestrator for docstring generation orchestrator = None # Initialize orchestrator unless we're in placeholder test mode if test_mode != 'placeholder': 
logger.info(f"Initializing orchestrator with config: {config_path}") # Pass the test_mode to the orchestrator if it's "context_print" orchestrator_test_mode = test_mode if test_mode != 'none' else None orchestrator = Orchestrator(repo_path=repo_path, config_path=config_path, test_mode=orchestrator_test_mode) # Check if the overwrite_docstrings option is in the config file # If it's there, it overrides the command-line argument if hasattr(orchestrator, 'config'): docstring_options = orchestrator.config.get('docstring_options', {}) config_overwrite = docstring_options.get('overwrite_docstrings') if config_overwrite is not None: overwrite_docstrings = config_overwrite logger.info(f"Using config file setting for overwrite_docstrings: {overwrite_docstrings}") else: logger.info("Running in PLACEHOLDER TEST MODE with placeholder docstrings (no LLM calls)") # Parse the repository to build the dependency graph logger.info(f"Parsing repository: {repo_path}") parser = DependencyParser(repo_path) components = parser.parse_repository() # Save the dependency graph for future reference parser.save_dependency_graph(dependency_graph_path) logger.info(f"Dependency graph saved to: {dependency_graph_path}") # Build the graph for traversal graph = build_graph_from_components(components) # Create a dependency graph in the format expected by the orchestrator # Dictionary mapping component paths to their dependencies dependency_graph = {} for component_id, deps in graph.items(): dependency_graph[component_id] = list(deps) # Perform DFS-based traversal logger.info("Performing DFS traversal on the dependency graph (starting from nodes with no dependencies)") sorted_components = dependency_first_dfs(graph) logger.info(f"Sorted {len(sorted_components)} components for processing") # Apply the selected ordering mode if order_mode == 'random_node': # Randomly shuffle all components logger.info("Using random node ordering mode - shuffling all components") random.shuffle(sorted_components) elif 
order_mode == 'random_file': # Group components by file path logger.info("Using random file ordering mode - processing files in random order") # Group components by file file_to_components = defaultdict(list) for component_id in sorted_components: component = components.get(component_id) if component: file_to_components[component.file_path].append(component_id) # Randomly shuffle the file order but maintain the order of components within each file file_paths = list(file_to_components.keys()) random.shuffle(file_paths) # Create a new ordering based on randomly shuffled files sorted_components = [] for file_path in file_paths: sorted_components.extend(file_to_components[file_path]) else: # Default to topological order (already set in sorted_components) logger.info("Using topological ordering mode - processing components based on dependencies") # Check if web interface is enabled if args.enable_web: try: from src.visualizer.web_bridge import patch_visualizers logger.info("Web interface integration enabled") patch_visualizers() except ImportError as e: logger.warning(f"Failed to enable web interface integration: {e}") logger.warning("Make sure you have installed the required web dependencies") # Initialize the progress visualizer visualizer = ProgressVisualizer(components, sorted_components) visualizer.initialize() # Show dependency statistics visualizer.show_dependency_stats() # Process components in order determined by DFS traversal for component_id in sorted_components: component = components.get(component_id) if not component: logger.warning(f"Component {component_id} not found in parsed components") continue # Skip __init__ methods as they don't need docstrings if component.component_type == "method" and component_id.endswith(".__init__"): logger.info(f"Skipping {component_id} - __init__ methods don't need docstrings") visualizer.update(component_id, "completed") continue # compute the length of docstring if exists (using white space as delimiter) docstring_length 
= len(component.docstring.split()) if component.has_docstring else 0 # Skip components that already have a docstring longer than 10 words (unless overwrite_docstrings is True); shorter docstrings are regenerated if component.has_docstring and not overwrite_docstrings and docstring_length > 10: logger.info(f"Skipping {component_id} - already has docstring") visualizer.update(component_id, "completed") continue elif component.has_docstring and overwrite_docstrings: logger.info(f"Overwriting existing docstring for {component_id}") # Update the visualizer visualizer.update(component_id, "processing") # Log the component type comp_type = component.component_type logger.info(f"Processing {comp_type}: {component_id}") # Generate the docstring logger.info(f"Generating docstring for {component_id}") docstring = generate_docstring_for_component(component, orchestrator, test_mode, dependency_graph) # Update the file with the new docstring file_path = component.file_path success = set_docstring_in_file(file_path, component, docstring) if success: logger.info(f"Successfully updated docstring for {component_id}") visualizer.update(component_id, "completed") else: logger.error(f"Failed to update docstring for {component_id}") visualizer.update(component_id, "error") # Re-parse in case the line numbers changed due to docstring insertion (note: parse_repository() is called on the whole repository, not just this file) # This is only necessary if there are more components from the same file same_file_components = [ comp_id for comp_id in sorted_components if comp_id != component_id and components[comp_id].file_path == file_path ] if same_file_components: logger.info(f"Re-parsing file {file_path} for updated line numbers") parser = DependencyParser(repo_path) updated_components = parser.parse_repository() # Update the components dictionary with new line numbers for comp_id, comp in updated_components.items(): if comp_id in components: components[comp_id] = comp # Finalize the visualization visualizer.finalize() # Create a more descriptive mode message based on the test mode if test_mode == 'placeholder':
mode_str = "PLACEHOLDER TEST MODE (no LLM calls)" elif test_mode == 'context_print': mode_str = "CONTEXT PRINT TEST MODE (with context debugging)" else: mode_str = "normal mode" # Add ordering mode to the completion message order_mode_str = { 'topo': 'topological ordering', 'random_node': 'random node ordering', 'random_file': 'random file ordering' }.get(order_mode, 'unknown ordering') logger.info(f"Docstring generation complete ({mode_str}, {order_mode_str})") # Print usage statistics for LLM providers if available if orchestrator: try: # Access the rate limiters from agents rate_limiters = [] for agent_name in ['reader', 'writer', 'verifier']: agent = getattr(orchestrator, agent_name, None) if agent and hasattr(agent, 'llm') and hasattr(agent.llm, 'rate_limiter'): rate_limiters.append(agent.llm.rate_limiter) # Print statistics for each rate limiter if rate_limiters: logger.info("=" * 50) logger.info("TOKEN USAGE AND COST STATISTICS") logger.info("=" * 50) for limiter in rate_limiters: limiter.print_usage_stats() # Calculate total cost across all limiters total_cost = sum(limiter.total_cost for limiter in rate_limiters) logger.info("=" * 50) logger.info(f"TOTAL COST: ${total_cost:.6f}") logger.info("=" * 50) except Exception as e: logger.warning(f"Could not print token usage statistics: {e}") if __name__ == "__main__": main() ================================================ FILE: output/dependency_graphs/raw_test_repo_dependency_graph.json ================================================ { "helper.HelperClass": { "id": "helper.HelperClass", "component_type": "class", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [ "helper.HelperClass.get_result", "helper.DataProcessor", "helper.HelperClass._internal_process", "helper.HelperClass.process_data" ], "start_line": 1, "end_line": 14, "has_docstring": false, "docstring": "" }, "helper.HelperClass.__init__": { "id": "helper.HelperClass.__init__", 
"component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [], "start_line": 3, "end_line": 4, "has_docstring": false, "docstring": "" }, "helper.HelperClass.process_data": { "id": "helper.HelperClass.process_data", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [ "helper.DataProcessor" ], "start_line": 6, "end_line": 8, "has_docstring": false, "docstring": "" }, "helper.HelperClass._internal_process": { "id": "helper.HelperClass._internal_process", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [], "start_line": 10, "end_line": 11, "has_docstring": false, "docstring": "" }, "helper.HelperClass.get_result": { "id": "helper.HelperClass.get_result", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [], "start_line": 13, "end_line": 14, "has_docstring": false, "docstring": "" }, "helper.DataProcessor": { "id": "helper.DataProcessor", "component_type": "class", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [ "helper.DataProcessor.process", "helper.DataProcessor._internal_process" ], "start_line": 16, "end_line": 72, "has_docstring": true, "docstring": "\n \"\"\"Handles basic data processing tasks within a system.\n\n This class is designed to perform simple data processing operations, providing\n utility methods that can be used in various scenarios where basic data manipulation\n is required. 
It is particularly useful in contexts where a straightforward list of\n integers is needed for further processing or testing.\n\n The `DataProcessor` class fits into the larger system architecture as a utility\n component, offering static and internal methods to handle specific processing tasks.\n It achieves its purpose by providing a static method for general use and an internal\n method for class-specific operations.\n\n Example:\n # Initialize the DataProcessor class\n processor = DataProcessor()\n\n # Use the static method to process data\n result = DataProcessor.process()\n print(result) # Output: [1, 2, 3]\n\n # Use the internal method for internal processing\n internal_result = processor._internal_process()\n print(internal_result) # Output: 'processed'\n \"\"\"\n " }, "helper.DataProcessor.process": { "id": "helper.DataProcessor.process", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [], "start_line": 45, "end_line": 57, "has_docstring": true, "docstring": "\n \"\"\"Processes data and returns a list of integers.\n\n This static method is designed to perform a basic data processing task\n and return a predefined list of integers. It can be used whenever a simple\n list of integers is required for further operations or testing purposes.\n\n Returns:\n list of int: A list containing the integers [1, 2, 3].\n \"\"\"\n " }, "helper.DataProcessor._internal_process": { "id": "helper.DataProcessor._internal_process", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/helper.py", "relative_path": "helper.py", "depends_on": [], "start_line": 59, "end_line": 72, "has_docstring": true, "docstring": "\n \"\"\"Processes internal data and returns a status message.\n\n This method is used internally within the `DataProcessor` class to perform\n specific data processing tasks that are not exposed publicly. 
It is typically\n called by other methods within the class to handle intermediate processing\n steps.\n\n Returns:\n str: A string indicating the processing status, specifically 'processed'.\n \"\"\"\n " }, "main.main_function": { "id": "main.main_function", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/main.py", "relative_path": "main.py", "depends_on": [ "helper.HelperClass", "inner.inner_functions.generate_timestamp", "main.utility_function" ], "start_line": 5, "end_line": 10, "has_docstring": false, "docstring": "" }, "main.utility_function": { "id": "main.utility_function", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/main.py", "relative_path": "main.py", "depends_on": [], "start_line": 13, "end_line": 14, "has_docstring": false, "docstring": "" }, "processor.AdvancedProcessor": { "id": "processor.AdvancedProcessor", "component_type": "class", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/processor.py", "relative_path": "processor.py", "depends_on": [ "processor.AdvancedProcessor.run", "helper.HelperClass", "processor.AdvancedProcessor.process_result", "main.utility_function", "processor.DataProcessor" ], "start_line": 6, "end_line": 18, "has_docstring": false, "docstring": "" }, "processor.AdvancedProcessor.__init__": { "id": "processor.AdvancedProcessor.__init__", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/processor.py", "relative_path": "processor.py", "depends_on": [ "helper.HelperClass", "processor.DataProcessor" ], "start_line": 8, "end_line": 10, "has_docstring": false, "docstring": "" }, "processor.AdvancedProcessor.run": { "id": "processor.AdvancedProcessor.run", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/processor.py", "relative_path": "processor.py", "depends_on": [], "start_line": 12, "end_line": 15, "has_docstring": false, "docstring": "" }, 
"processor.AdvancedProcessor.process_result": { "id": "processor.AdvancedProcessor.process_result", "component_type": "method", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/processor.py", "relative_path": "processor.py", "depends_on": [ "main.utility_function" ], "start_line": 17, "end_line": 18, "has_docstring": false, "docstring": "" }, "test_file.test_function": { "id": "test_file.test_function", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/test_file.py", "relative_path": "test_file.py", "depends_on": [], "start_line": 1, "end_line": 2, "has_docstring": false, "docstring": "" }, "inner.inner_functions.inner_function": { "id": "inner.inner_functions.inner_function", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/inner/inner_functions.py", "relative_path": "inner/inner_functions.py", "depends_on": [], "start_line": 1, "end_line": 15, "has_docstring": true, "docstring": "\n Returns a greeting message from an inner function.\n\n This function is designed to return a simple greeting message, which can be used in nested or internal function calls to verify execution flow or for debugging purposes. 
It is typically used in development environments where confirming the execution of specific code paths is necessary.\n\n Returns:\n str: A greeting message stating 'Hello from inner function!'\n\n Example:\n >>> message = inner_function()\n >>> print(message)\n 'Hello from inner function!'\n " }, "inner.inner_functions.get_random_quote": { "id": "inner.inner_functions.get_random_quote", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/inner/inner_functions.py", "relative_path": "inner/inner_functions.py", "depends_on": [], "start_line": 17, "end_line": 31, "has_docstring": true, "docstring": "\n Fetches a predefined inspirational quote.\n\n This function is designed to provide users with a motivational quote, which can be used in applications that aim to inspire or uplift users. It is particularly useful in scenarios where a quick, positive message is needed to enhance user experience.\n\n Returns:\n str: A quote string stating 'The best way to predict the future is to create it.'\n\n Example:\n >>> quote = get_random_quote()\n >>> print(quote)\n 'The best way to predict the future is to create it.'\n " }, "inner.inner_functions.generate_timestamp": { "id": "inner.inner_functions.generate_timestamp", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/inner/inner_functions.py", "relative_path": "inner/inner_functions.py", "depends_on": [], "start_line": 33, "end_line": 34, "has_docstring": false, "docstring": "" }, "inner.inner_functions.get_system_status": { "id": "inner.inner_functions.get_system_status", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/inner/inner_functions.py", "relative_path": "inner/inner_functions.py", "depends_on": [], "start_line": 36, "end_line": 50, "has_docstring": true, "docstring": "\n Provides a static message indicating the operational status of systems.\n\n This function is used to retrieve a fixed status message that 
confirms all systems are functioning correctly. It is useful in monitoring dashboards or status pages where a quick confirmation of system health is required.\n\n Returns:\n str: A status message stating 'All systems operational.'\n\n Example:\n >>> status = get_system_status()\n >>> print(status)\n 'All systems operational'\n " }, "inner.inner_functions.fetch_user_message": { "id": "inner.inner_functions.fetch_user_message", "component_type": "function", "file_path": "/home/dayuyang/DocAgent/data/raw_test_repo/inner/inner_functions.py", "relative_path": "inner/inner_functions.py", "depends_on": [], "start_line": 52, "end_line": 67, "has_docstring": true, "docstring": "\n \"\"\"Fetches a predefined user message indicating notifications.\n\n This function is used to retrieve a static message that informs the user about the number of notifications they have. It is typically used in scenarios where a quick status update is needed for user engagement.\n\n Returns:\n str: A message string stating 'Welcome back! You have 3 notifications.'\n\n Example:\n >>> message = fetch_user_message()\n >>> print(message)\n 'Welcome back! You have 3 notifications.'\n \"\"\"\n " } } ================================================ FILE: run_web_ui.py ================================================ #!/usr/bin/env python3 import eventlet eventlet.monkey_patch() # Copyright (c) Meta Platforms, Inc. and affiliates """ Web UI Launcher for DocAgent Docstring Generator This script launches the web-based user interface for the docstring generation tool. The UI provides a more interactive and visual way to use the docstring generator, with real-time feedback and progress tracking. 
Usage: python run_web_ui.py [--host HOST] [--port PORT] [--debug] """ import argparse import os import sys import logging from pathlib import Path # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger("docstring_web") # Add the current directory to the path sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) def check_dependencies(): """Check if all required dependencies are installed.""" try: import flask import flask_socketio import eventlet import yaml import tabulate import colorama return True except ImportError as e: missing_module = str(e).split("'")[1] logger.error(f"Missing dependency: {missing_module}") logger.error("Please install all required dependencies with:") logger.error("pip install -r requirements-web.txt") return False def main(): """Parse command line arguments and start the web UI.""" parser = argparse.ArgumentParser(description='Launch the DocAgent Web UI') parser.add_argument('--host', default='127.0.0.1', help='Host to bind the server to') parser.add_argument('--port', type=int, default=5000, help='Port to bind the server to') parser.add_argument('--debug', action='store_true', help='Run in debug mode') args = parser.parse_args() # Check dependencies if not check_dependencies(): return 1 # Print banner print("\n" + "=" * 80) print("DocAgent Web Interface".center(80)) print("=" * 80) # Import and run the web app try: # First try to import eventlet to ensure it's properly initialized import eventlet eventlet.monkey_patch() from src.web.app import create_app app, socketio = create_app(debug=args.debug) logger.info(f"Starting DocAgent Web UI at: http://{args.host}:{args.port}") logger.info("Press Ctrl+C to stop the server") # Start the server socketio.run(app, host=args.host, port=args.port, debug=args.debug, allow_unsafe_werkzeug=True) return 0 except ImportError as e: logger.error(f"Error 
importing web application: {e}") logger.error("Make sure the src/web directory exists and contains the necessary files.") return 1 except Exception as e: logger.error(f"Error running web application: {e}") return 1 if __name__ == '__main__': try: sys.exit(main()) except KeyboardInterrupt: print("\nServer stopped.") sys.exit(0) ================================================ FILE: setup.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from setuptools import setup, find_packages # Read the contents of README file from pathlib import Path this_directory = Path(__file__).parent long_description = (this_directory / "README.md").read_text() # Prepare all extras dev_requires = [ "pytest>=8.3.4", "pytest-cov>=2.0", "black>=22.0", "flake8>=3.9", ] web_requires = [ "flask>=3.1.0", "flask-socketio>=5.5.1", "eventlet>=0.39.0", "python-socketio>=5.12.1", "python-engineio>=4.11.2", "bidict>=0.23.0", "dnspython>=2.7.0", "six>=1.16.0", ] visualization_requires = [ "matplotlib>=3.10.0", "pygraphviz>=1.14", "networkx>=3.4.2", ] cuda_requires = [ "torch>=2.0.0", "accelerate>=1.4.0", ] # Combine all extras for the 'all' option all_requires = dev_requires + web_requires + visualization_requires + cuda_requires setup( name="DocstringGenerator", version="0.1.0", author="Dayu Yang", author_email="dayuyang@meta.com", description="DocAgent for High-quality docstring generation in Large-scale Python projects", long_description=long_description, long_description_content_type="text/markdown", packages=find_packages(where="src"), package_dir={"": "src"}, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", ], python_requires=">=3.8", install_requires=[ # Core dependencies "numpy>=1.23.5", "pyyaml>=6.0", 
"jinja2>=3.1.5", "requests>=2.32.0", "urllib3>=2.3.0", # Code analysis tools "astor>=0.8.1", "code2flow>=2.5.1", "pydeps>=3.0.0", # AI/LLM related dependencies "anthropic>=0.45.0", "openai>=1.60.1", "langchain-anthropic>=0.3.4", "langchain-openai>=0.3.2", "langchain-core>=0.3.31", "langgraph>=0.2.67", "tiktoken>=0.8.0", "transformers>=4.48.0", "huggingface-hub>=0.28.0", "google-generativeai>=0.6.0", # Utility packages "tqdm>=4.67.1", "tabulate>=0.9.0", "colorama>=0.4.6", "termcolor>=2.5.0", "pydantic>=2.10.0", # Web requirements "flask>=3.1.0", "flask-socketio>=5.5.1", "eventlet>=0.39.0", "python-socketio>=5.12.1", "python-engineio>=4.11.2", "bidict>=0.23.0", "dnspython>=2.7.0", "six>=1.16.0", # CUDA requirements "torch>=2.0.0", "accelerate>=1.4.0", ], extras_require={ "dev": dev_requires, "web": web_requires, # Keep for potential compatibility, now included in core "visualization": visualization_requires, "cuda": cuda_requires, # Keep for potential compatibility, now included in core "all": all_requires, } ) ================================================ FILE: src/DocstringGenerator.egg-info/PKG-INFO ================================================ Metadata-Version: 2.2 Name: DocstringGenerator Version: 0.1.0 Summary: DocAgent for High-quality docstring generation in Large-scale Python projects Author: Dayu Yang Author-email: dayuyang@meta.com Classifier: Development Status :: 3 - Alpha Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: MIT License Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Requires-Python: >=3.8 Description-Content-Type: text/markdown Requires-Dist: numpy>=1.23.5 Requires-Dist: pyyaml>=6.0 Requires-Dist: jinja2>=3.1.5 Requires-Dist: requests>=2.32.0 Requires-Dist: urllib3>=2.3.0 Requires-Dist: astor>=0.8.1 Requires-Dist: code2flow>=2.5.1 Requires-Dist: 
pydeps>=3.0.0 Requires-Dist: anthropic>=0.45.0 Requires-Dist: openai>=1.60.1 Requires-Dist: langchain-anthropic>=0.3.4 Requires-Dist: langchain-openai>=0.3.2 Requires-Dist: langchain-core>=0.3.31 Requires-Dist: langgraph>=0.2.67 Requires-Dist: tiktoken>=0.8.0 Requires-Dist: transformers>=4.48.0 Requires-Dist: huggingface-hub>=0.28.0 Requires-Dist: google-generativeai>=0.6.0 Requires-Dist: tqdm>=4.67.1 Requires-Dist: tabulate>=0.9.0 Requires-Dist: colorama>=0.4.6 Requires-Dist: termcolor>=2.5.0 Requires-Dist: pydantic>=2.10.0 Requires-Dist: flask>=3.1.0 Requires-Dist: flask-socketio>=5.5.1 Requires-Dist: eventlet>=0.39.0 Requires-Dist: python-socketio>=5.12.1 Requires-Dist: python-engineio>=4.11.2 Requires-Dist: bidict>=0.23.0 Requires-Dist: dnspython>=2.7.0 Requires-Dist: six>=1.16.0 Requires-Dist: torch>=2.0.0 Requires-Dist: accelerate>=1.4.0 Provides-Extra: dev Requires-Dist: pytest>=8.3.4; extra == "dev" Requires-Dist: pytest-cov>=2.0; extra == "dev" Requires-Dist: black>=22.0; extra == "dev" Requires-Dist: flake8>=3.9; extra == "dev" Provides-Extra: web Requires-Dist: flask>=3.1.0; extra == "web" Requires-Dist: flask-socketio>=5.5.1; extra == "web" Requires-Dist: eventlet>=0.39.0; extra == "web" Requires-Dist: python-socketio>=5.12.1; extra == "web" Requires-Dist: python-engineio>=4.11.2; extra == "web" Requires-Dist: bidict>=0.23.0; extra == "web" Requires-Dist: dnspython>=2.7.0; extra == "web" Requires-Dist: six>=1.16.0; extra == "web" Provides-Extra: visualization Requires-Dist: matplotlib>=3.10.0; extra == "visualization" Requires-Dist: pygraphviz>=1.14; extra == "visualization" Requires-Dist: networkx>=3.4.2; extra == "visualization" Provides-Extra: cuda Requires-Dist: torch>=2.0.0; extra == "cuda" Requires-Dist: accelerate>=1.4.0; extra == "cuda" Provides-Extra: all Requires-Dist: pytest>=8.3.4; extra == "all" Requires-Dist: pytest-cov>=2.0; extra == "all" Requires-Dist: black>=22.0; extra == "all" Requires-Dist: flake8>=3.9; extra == "all" Requires-Dist: 
flask>=3.1.0; extra == "all" Requires-Dist: flask-socketio>=5.5.1; extra == "all" Requires-Dist: eventlet>=0.39.0; extra == "all" Requires-Dist: python-socketio>=5.12.1; extra == "all" Requires-Dist: python-engineio>=4.11.2; extra == "all" Requires-Dist: bidict>=0.23.0; extra == "all" Requires-Dist: dnspython>=2.7.0; extra == "all" Requires-Dist: six>=1.16.0; extra == "all" Requires-Dist: matplotlib>=3.10.0; extra == "all" Requires-Dist: pygraphviz>=1.14; extra == "all" Requires-Dist: networkx>=3.4.2; extra == "all" Requires-Dist: torch>=2.0.0; extra == "all" Requires-Dist: accelerate>=1.4.0; extra == "all" Dynamic: author Dynamic: author-email Dynamic: classifier Dynamic: description Dynamic: description-content-type Dynamic: provides-extra Dynamic: requires-dist Dynamic: requires-python Dynamic: summary # DocAgent: Agentic Hierarchical Docstring Generation System

DocAgent is a system designed to generate high-quality, context-aware docstrings for Python codebases using a multi-agent approach and hierarchical processing. ## Table of Contents - [Motivation](#motivation) - [Methodology](#methodology) - [Installation](#installation) - [Components](#components) - [Usage](#usage) - [Data Handling](#data-handling) - [Baselines](#baselines) - [Development Notes](#development-notes) ## Motivation High-quality docstrings are crucial for code readability, usability, and maintainability, especially in large repositories. They should explain the purpose, parameters, returns, exceptions, and usage within the broader context. Current LLMs often struggle with this, producing superficial or redundant comments and failing to capture essential context or rationale. DocAgent aims to address these limitations by generating informative, concise, and contextually aware docstrings. ## Methodology DocAgent employs two key strategies: 1. **Hierarchical Traversal**: Processes code components by analyzing dependencies, starting with files having fewer dependencies. This builds a documented foundation before tackling more complex code, addressing the challenge of documenting context that itself lacks documentation. 2. **Agentic System**: Utilizes a team of specialized agents (`Reader`, `Searcher`, `Writer`, `Verifier`) coordinated by an `Orchestrator`. This system gathers context (internal and external), drafts docstrings according to standards, and verifies their quality in an iterative process. System Overview For more details on the agentic framework, see the [Agent Component README](./src/agent/README.md). ## Installation Detailed installation instructions using `pip` or `conda`, including optional dependencies and troubleshooting tips, can be found in [INSTALL.md](./INSTALL.md). 
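The hierarchical traversal strategy described under Methodology can be sketched as a topological ordering over a dependency graph. This is a minimal illustration, not the repository's implementation: the function name `documentation_order`, the input shape (a mapping from each component to the set of components it depends on), and the component names in the example are assumptions made here. Every dependency is assumed to appear as a key, and cycles are simply reported as an error.

```python
from collections import deque
from typing import Dict, List, Set


def documentation_order(dependencies: Dict[str, Set[str]]) -> List[str]:
    """Order components so that each one comes after everything it depends on.

    Hypothetical helper: `dependencies` maps a component name to the names it
    depends on, and every dependency must itself be a key of the mapping.
    """
    # Number of not-yet-documented dependencies for each component.
    remaining = {name: len(deps) for name, deps in dependencies.items()}
    # Reverse index: which components depend on a given component.
    dependents: Dict[str, List[str]] = {name: [] for name in dependencies}
    for name, deps in dependencies.items():
        for dep in deps:
            dependents[dep].append(name)

    # Start with components that have no dependencies (sorted for determinism).
    ready = deque(sorted(name for name, count in remaining.items() if count == 0))
    order: List[str] = []
    while ready:
        current = ready.popleft()
        order.append(current)
        for dependent in sorted(dependents[current]):
            remaining[dependent] -= 1
            if remaining[dependent] == 0:
                ready.append(dependent)

    # Anything left over is part of a dependency cycle.
    if len(order) != len(dependencies):
        raise ValueError("circular dependency detected")
    return order


example = {
    "product": set(),
    "inventory_manager": {"product"},
    "payment_processor": set(),
    "vending_machine": {"inventory_manager", "payment_processor", "product"},
}
print(documentation_order(example))
# ['payment_processor', 'product', 'inventory_manager', 'vending_machine']
```

Components with no dependencies come first, so their docstrings are available as context when the components that use them are processed. A real system would also need a policy for circular dependencies rather than raising an error.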
## Components DocAgent is composed of several key parts: - **[Core Agent Framework](./src/agent/README.md)**: Implements the multi-agent system (Reader, Searcher, Writer, Verifier, Orchestrator) responsible for the generation logic. - **[Docstring Evaluator](./src/evaluator/README.md)**: Provides tools for evaluating docstring quality, primarily focusing on completeness based on static code analysis (AST). - **[Generation Web UI](./src/web/README.md)**: A web interface for configuring, running, and monitoring the docstring *generation* process in real-time, visualizing agent activity and repository structure. - **[Evaluation Web UI](./src/web_eval/README.md)**: A separate web interface for configuring and running docstring *evaluations*, assessing completeness and helpfulness (using LLMs). ## Usage The primary ways to interact with DocAgent are: 1. **Generation Web UI**: Recommended for visualizing the generation process. Launch via: ```bash python run_web_ui.py ``` Then access `http://localhost:5000` (or as configured). See the [Generation Web UI README](./src/web/README.md) for details. 2. **Evaluation Web UI**: Recommended for assessing docstring quality. Launch via: ```bash cd src/web_eval ./start_server.sh # Or python app.py ``` Then access `http://localhost:5000` (or as configured). See the [Evaluation Web UI README](./src/web_eval/README.md) for details. 3. **Command Line (Generation)**: Run the generation process directly: ```bash # Example: Run on a test repo, removing existing docstrings first ./tool/remove_docstrings.sh data/raw_test_repo python generate_docstrings.py --repo-path data/raw_test_repo ``` Use `--help` for more options. ## Data Handling Tools are included for managing datasets for evaluation: - **GitHub Repository Downloader (`src/data/parse/downloader.py`)**: Finds and downloads GitHub repositories based on configurable criteria (language, stars, size, etc.). 
- **Repository Selection (`experiments/select_repos.py`)**: Selects a diverse subset of downloaded repositories based on metrics like code size and complexity. (See original README sections for detailed usage if needed). ## Baselines A simple "copy and paste" baseline is implemented (`experiments/generate_docstrings_copy_and_paste.py`) for comparison. It sends isolated code components to an LLM without context. ```bash # Example: Run baseline on a test repo ./tool/remove_docstrings.sh data/raw_test_repo python experiments/generate_docstrings_copy_and_paste.py --repo-path data/raw_test_repo ``` ## Development Notes - Remember to activate your chosen environment (`pip` or `conda`). - Use `pip install -e ".[dev]"` for development dependencies. - Run tests using `pytest`. - See [INSTALL.md](./INSTALL.md) for setting up system dependencies like GraphViz if needed for visualizations. --- *This README provides a high-level overview. Please refer to the linked component READMEs and `INSTALL.md` for specific details.* # Todo - repo-level eval script - argument vs parameter - "Need more information" seems does not work for codellama34B and gemini. - some repo depends on "not-install" package(ask you to install autogen after download the repo) - Query should also search internally - truncated "called by", especially Class (too long) - Overkill issue # For ACL Experiments ## Note class evaluate: - really means eval the init function (if has init) ## Data ### GitHub Repository Downloader The project includes a GitHubRepoDownloader that automates the process of finding and downloading repositories for docstring generation tasks. This tool allows you to specify various criteria to target repositories that match your requirements. #### Features: - **Configurable Search Criteria**: Filter repositories by owner, creation date, language, stars, forks, size, and license. 
- **Python Language Filtering**: Ensures downloaded repositories contain a minimum percentage of Python code (default: 80%). - **Repository Metadata**: Automatically saves metadata about each downloaded repository. - **Rate Limit Handling**: Respects GitHub API rate limits to avoid throttling. - **Logging**: Comprehensive logging of the download process. #### Usage: To download repositories, create a configuration file and run the downloader: ```bash python -m src.data.parse.downloader ``` #### Configuration: Create a YAML configuration file with the following structure: ```yaml # GitHub authentication GITHUB_TOKEN: "your-github-token" # Output directory output_directory: "data/downloaded_repos" # Repository limits max_repos: 10 skip_archived: true skip_forks: true min_python_percentage: 80 # Minimum percentage of Python code required # Search criteria search_criteria: language: "python" stars: min: 100 forks: min: 10 dates: created_after: "2020-01-01" owners: - "username1" - "org_name" ``` The downloader will: 1. Search GitHub repositories matching your criteria 2. Check if each repository meets the Python percentage requirement 3. Clone qualifying repositories to the specified output directory 4. Save repository metadata for further analysis ### Repository Selection After downloading repositories, you may want to select a diverse subset for analysis. The project includes a repository selection tool that helps you choose repositories with varying characteristics: #### Features: - **Diversity-Based Selection**: Select repositories based on code size and structural complexity. - **Code Size Metrics**: Calculates the number of Python files and total lines of code. - **Topological Complexity**: Measures the depth of the repository directory structure. - **Visualization**: Generates scatter plots showing the distribution of selected repositories. 
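The code-size and depth metrics listed above can be sketched as follows. This is an illustrative approximation rather than the selection tool's actual code: the function name `repo_metrics` and the returned keys are assumptions, and depth is measured here only along paths that lead to Python files.

```python
from pathlib import Path
from typing import Dict


def repo_metrics(repo_root: str) -> Dict[str, int]:
    """Count Python files, total lines, and maximum directory depth.

    Hypothetical helper; the metric definitions are assumptions.
    """
    root = Path(repo_root)
    python_files = 0
    total_lines = 0
    max_depth = 0
    for path in root.rglob("*.py"):
        if not path.is_file():
            continue
        python_files += 1
        # Count lines without loading the whole file into memory.
        with path.open(encoding="utf-8", errors="ignore") as handle:
            total_lines += sum(1 for _ in handle)
        # Depth = number of directories between the root and the file.
        depth = len(path.relative_to(root).parts) - 1
        max_depth = max(max_depth, depth)
    return {
        "python_files": python_files,
        "total_lines": total_lines,
        "max_depth": max_depth,
    }
```

Calling `repo_metrics("data/raw_test_repo")` would return one dictionary per repository, which could then be normalized and clustered as described in the selection process.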
#### Usage: To select repositories from your downloaded collection: ```bash python -m experiments.select_repos ``` #### Process: The selection process follows these steps: 1. Analyzes each repository to extract metrics (Python files count, total lines, directory depth) 2. Normalizes the metrics to ensure fair comparison 3. Creates clusters of repositories with similar characteristics 4. Selects representatives from each cluster to ensure diversity 5. Generates a visualization of the selection results This approach ensures that your analysis includes repositories with varying sizes and complexity levels, providing a more comprehensive evaluation of docstring generation techniques. ## Baseline ### Copy and Paste We implemented a simple "copy and paste" baseline system that mimics the approach of users copying code components and pasting them directly to an LLM interface. This baseline: 1. Extracts individual code components (functions, classes, methods) from Python files 2. Sends only the component's source code to an LLM without any surrounding context 3. Asks the LLM to generate a docstring based solely on that isolated component 4. Inserts the generated docstring back into the code This baseline serves as a comparison point to demonstrate the effectiveness of our full agentic hierarchical system, which considers dependency relationships and broader context when generating docstrings. To run the baseline system: ```bash clear ./tool/remove_docstrings.sh data/raw_test_repo python experiments/generate_docstrings_copy_and_paste.py --repo-path data/raw_test_repo ``` The baseline uses the same configuration file (agent_config.yaml) as the main system, so it can work with any supported LLM (Claude, OpenAI, HuggingFace). 
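The extraction step of the copy-and-paste baseline described above can be sketched with the standard `ast` module. This is a simplified illustration under stated assumptions, not the baseline script itself: the function name `extract_components` is hypothetical, and only top-level functions, classes, and their direct methods are collected (nested functions are ignored).

```python
import ast
from typing import List, Tuple


def extract_components(source: str) -> List[Tuple[str, str, str]]:
    """Return (kind, name, source_text) for functions, classes, and methods.

    Hypothetical helper; each returned snippet is the isolated code that
    would be sent to the LLM without any surrounding context.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    components: List[Tuple[str, str, str]] = []
    for node in ast.parse(source).body:
        if isinstance(node, function_types):
            components.append(
                ("function", node.name, ast.get_source_segment(source, node))
            )
        elif isinstance(node, ast.ClassDef):
            components.append(
                ("class", node.name, ast.get_source_segment(source, node))
            )
            # Methods are the function definitions directly inside the class body.
            for child in node.body:
                if isinstance(child, function_types):
                    components.append(
                        (
                            "method",
                            f"{node.name}.{child.name}",
                            ast.get_source_segment(source, child),
                        )
                    )
    return components
```

Because each snippet is cut out of its file, the LLM sees no callers, callees, or imports, which is the limitation the full system is designed to address.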
To run in placeholder mode (no actual LLM calls): ```bash python experiments/generate_docstrings_copy_and_paste.py --repo-path data/raw_test_repo --test-mode placeholder ``` To overwrite existing docstrings: ```bash python experiments/generate_docstrings_copy_and_paste.py --repo-path data/raw_test_repo --overwrite-docstrings ``` ### Main Experiments ## Motivation: In the realm of large-scale software repositories, the presence of high-quality, user-oriented docstrings is crucial for maintaining code readability, usability, and maintainability. A well-crafted docstring should not only provide comprehensive details about parameters, return values, exceptions, and usage examples but also clearly articulate the purpose of the function or class within the broader context of the repository. This includes explaining when and how to use the function or class, as well as its relationship to other components in the codebase. Despite the importance of such documentation, current large language models (LLMs) often fall short in generating docstrings that meet these expectations. Common issues include the production of redundant or superficial commentary, a failure to highlight the underlying rationale behind implementation choices, and the omission of crucial constraints and assumptions. These shortcomings can lead to misunderstandings and inefficiencies for developers who rely on these docstrings for guidance. The challenge, therefore, is to develop methods that enable the generation of high-quality docstrings that are both informative and concise, avoiding redundancy by not reiterating information that can be inferred from the code itself, such as parameter types when type hints are present. Addressing these challenges is essential for enhancing the utility of docstrings in large-scale repositories, ultimately contributing to more efficient and effective software development processes. 
## Challenges and Limitations of Existing Docstring Generation System: The task of generating high-quality docstrings in large-scale repositories presents several significant challenges. One of the primary difficulties lies in the evaluation of docstring quality. There is inherent ambiguity in assessing what constitutes a "good" docstring, as gold-standard data is scarce. Even highly-rated repositories often contain docstrings that are either inadequate or only partially effective, complicating the establishment of reliable benchmarks for quality assessment. Another challenge is the limitation imposed by the context window of large language models (LLMs). It is impractical to include an entire repository in a single prompt, necessitating a focus on selecting and summarizing relevant information. Determining what is "relevant" is crucial for providing the LLM with a comprehensive understanding of the purpose of the focal function or class. This involves discerning which aspects of the codebase should be included to give the LLM a "global sense" of the function's or class's role and significance. Furthermore, there is a "chicken and egg" problem inherent in this task. Generating high-quality docstrings requires a well-rounded understanding of the context in which the focal function or class operates. However, the context itself often lacks sufficient documentation to clearly convey its purpose and interrelations. This lack of existing high-quality docstrings in the surrounding code complicates the process of generating new ones, as the foundational understanding needed to inform the generation process is itself incomplete. Addressing these challenges is essential for advancing the capability of LLMs to produce docstrings that are not only informative and concise but also contextually aware and aligned with the broader objectives of the codebase. 
## Methodology:

To address the challenges of generating high-quality docstrings in large-scale repositories, we propose a hierarchical traversal approach combined with an agentic system composed of specialized roles: reader, searcher, writer, and verifier. This methodology is designed to systematically and efficiently produce comprehensive and contextually aware docstrings.

Hierarchical Traverse

The hierarchical traverse principle is central to our approach. By prioritizing the generation of docstrings for source code files with fewer dependencies, we aim to build a solid foundation of well-documented base classes and utility functions before tackling more complex implementations. This strategy effectively addresses the "chicken and egg" problem by ensuring that the foundational components of the codebase are well-understood and documented first. Unlike existing systems that generate docstrings in a random order, our method provides a structured and logical progression through the codebase.

Agentic System

Our agentic system is designed to facilitate the docstring generation process through a series of coordinated roles:

- Reader: The reader initiates the process by examining the focal code component and identifying any additional internal or external information needed to understand its purpose and context. If further information is required, the reader sends a request to the searcher.
- Searcher: The searcher traverses the dependency graph to gather relevant information, both from within the codebase and from open-internet sources if necessary. This information is then used to update the context state, providing a more comprehensive understanding of the focal component.
- Writer: Once the context is deemed sufficient, the reader passes the focal code component and its context to the writer. The writer drafts the docstring, ensuring it adheres to the specified quality and instructional guidelines.
- Verifier: The verifier conducts a final quality check of the drafted docstring. If formatting issues are detected, the docstring is returned to the writer for revision. If additional context is needed to enhance informativeness, the process returns to the reader for further information gathering. This iterative and collaborative approach ensures that each docstring is not only accurate and informative but also contextually aligned with the broader objectives of the codebase. By leveraging the strengths of each agent, our methodology provides a robust framework for generating high-quality documentation in large-scale repositories. # For Test The easiest way to interact with DocAgent is through Web App. Assuming the Web App is hosted on remote server. ## Docstring Generation System (Agentic + Hierarchical) Without Web UI: ```bash clear ./tool/remove_docstrings.sh data/raw_test_repo python generate_docstrings.py --repo-path data/raw_test_repo --test-mode context_print ``` With Web UI: ```bash clear ./tool/remove_docstrings.sh data/raw_test_repo python run_web_ui.py --host 0.0.0.0 --port 5000 ``` ## Docstring Eval system Without WebUI: Manual run the test files under `test/evaluator`. With WebUI: ```bash python src/web_eval/app.py --host 0.0.0.0 --port 5001 ``` ## Test Hierarchical Generation Only (no LLM call) For test hierarchical generation only (no LLM call), run the following command: (testing on `data/test_repo_vm` and `data/downloaded_repos/AutoSurvey`) ```bash ./tool/remove_docstrings.sh data/downloaded_repos/AutoSurvey clear python generate_docstrings.py --repo-path data/downloaded_repos/AutoSurvey --test-mode bash tool/visualize.sh output/dependency_graphs/dependency_graph.json output/dependency_graphs/dependency_graph_visualization.png ``` ## Depreciated Tests Test Completeness ```bash python test/evaluator/test_completeness.py data/downloaded_repos/AutoSurvey/src/agents/judge.py ``` Test reader-searcher communication. 
```bash
python test/agent/depreciated_test_orchestrator.py --mode reader-searcher --verbose-context
```

Remove all docstrings from a repository.

```bash
./tool/remove_docstrings.sh
```

Test hierarchical generation.

```bash
python generate_docstrings.py --repo-path data/test_repo_vm --test-mode
```

Visualize dependency graph.

```bash
bash tool/visualize.sh
```

# Installation

## Create Config File

Create a config folder `config/` and a config file `agent_config.yaml` under `config/`, e.g. `config/agent_config.yaml`.

The structure of the config is as follows:

```yaml
llm:
  type: "claude"  # Options: openai, claude, huggingface
  api_key: "your-anthropic-api-key-here"  # Replace with your Anthropic API key
  model: "claude-3-5-haiku-latest"
  temperature: 0.1
  max_output_tokens: 4096

# Flow control parameters
flow_control:
  max_reader_search_attempts: 2  # Maximum times reader can call searcher
  max_verifier_rejections: 3  # Maximum times verifier can reject a docstring
  status_sleep_time: 3  # Time to sleep between status updates (seconds)

# Perplexity API configuration
perplexity:
  api_key: "your-perplexity-api-key-here"  # Replace with your Perplexity API key
  model: "sonar"  # Default model
  temperature: 0.1
  max_output_tokens: 4096
```

## Installation

### Basic Installation

To install the basic package with core dependencies:

```bash
pip install -e .
```

### Install with Additional Features

You can install the package with additional optional dependencies:

```bash
# For development tools (pytest, black, flake8)
pip install -e ".[dev]"

# For web UI components
pip install -e ".[web]"

# For visualization tools
pip install -e ".[visualization]"

# For CUDA support
pip install -e ".[cuda]"

# For all optional dependencies
pip install -e ".[all]"
```

You can also combine multiple optional dependencies:

```bash
pip install -e ".[web,visualization]"
```

## Access Web UI from Local (if running on remote server)

## Running Docstring Generation System

On the remote server:

```bash
python run_web_ui.py --host 0.0.0.0 --port 5000
```

This tells Flask to listen on all network interfaces, not just the loopback interface.

On the local machine:

```bash
ssh -L 5000:localhost:5000
```

For example, for devserver, `ssh -L 5000:localhost:5000 dayuyang@devgpu003.rva5.facebook.com`.

This command creates a tunnel from your local port 5000 to port 5000 on the remote server. After running this command, you can open your browser and go to http://localhost:5000 to access the web interface running on the remote server.

Kill any program running on port 5000:

```bash
lsof -i :5000 | awk 'NR>1 {print $2}' | sort -u | xargs -r kill
```

## Running Docstring Eval System

On the remote server:

```bash
python src/web_eval/app.py --host 0.0.0.0 --port 5001
```

On the local machine:

```bash
ssh -L 5001:localhost:5001 dayuyang@devgpu003.rva5.facebook.com
```

To run both systems, run the following on the local machine:

```bash
ssh -L 5000:localhost:5000 -L 5001:localhost:5001 -L 5002:localhost:5002 dayuyang@devgpu003.rva5.facebook.com
# 5002 for backup
```

## Serve local LLM

First install vllm:

```bash
pip install vllm
```

Run `bash serve_local_llm.sh`

# Concerns

- hierarchical generation
- Circular dependencies
- Import from external source?
- assuming "external source" is a well-known library and the LLM should already know about it?

# TO FIX/ADD/IMPROVE

- If a component already has a docstring, skip it.
(or at least give an option to skip)
- Improve the generation instructions (the LLM follows the instructions strictly, which usually makes the generated docstring too long).
- A high-quality docstring also needs to be concise.
- Add a timeout warning (in case the process gets stuck).
- Claude has a rate limit: `50,000 input tokens per minute per organization`
- Add system-wide error handling.
- Add price calculation.
- Simple code does not need this tool.
- Before running the system, use a small LLM or a deterministic check to decide whether the system is needed (a balance between efficiency and effectiveness). Currently the logic is file-level, function-level, method-level, class-level.

# Evaluator

## Completeness

## Helpfulness

For Summary, Description, Arguments, Parameters, and Attributes, each docstring component is evaluated on a 1-5 scale (POOR to EXCELLENT).

For Example, the docstring is evaluated on a binary scale (0 or 1).
- Evaluates whether docstring examples enable users to use the code correctly by comparing predicted usage against ground truth.

# Vulnerability

1. Helpfulness of the description: when a class is too long, it may need to be truncated.
2. When evaluating parameters/arguments/attributes, the input context (class/function signatures) should be reasonably sized to avoid LLM token limits.

# Logic

Control flow (`process` function under orchestrator.py):

Once the searcher is called, the reader's memory is refreshed.

Once the verifier judges that more context is needed, the writer's and verifier's memory is refreshed.

# Note

When evaluating examples, the signature must contain the decorator `@staticmethod` or `@classmethod`.

When writing the docstring for a class, first write the docstring for the `__init__` method, then for the other methods, and finally for the class itself. (Provide the full class code as the code component when writing the docstring for a class.)

Methods are extremely difficult to handle. Currently only `self.method()`, `instance.method()` and `ClassName.method()` are supported. See `get_child_method` under `CallGraphBuilder` for more details.
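To make the three supported call forms concrete, the sketch below classifies method calls found in a piece of source code. It is only an illustration and is not the logic of `get_child_method`: the function name `classify_method_calls` and the `class_names` parameter are assumptions introduced here, and any receiver that is a plain name other than `self` or a known class (including an imported module) is treated as an instance.

```python
import ast
from typing import List, Set, Tuple


def classify_method_calls(source: str, class_names: Set[str]) -> List[Tuple[str, str]]:
    """Return (kind, call) pairs for calls of the form name.method().

    Hypothetical helper: kind is "self", "class", or "instance".
    """
    calls: List[Tuple[str, str]] = []
    for node in ast.walk(ast.parse(source)):
        # Only attribute calls such as receiver.method(...) are of interest.
        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
            continue
        receiver = node.func.value
        if not isinstance(receiver, ast.Name):
            continue  # Chained or computed receivers are not handled here.
        if receiver.id == "self":
            kind = "self"
        elif receiver.id in class_names:
            kind = "class"
        else:
            kind = "instance"
        calls.append((kind, f"{receiver.id}.{node.func.attr}"))
    return calls
```

Calls such as `get_box().size()` are skipped because the receiver is not a plain name, which mirrors why only the three listed forms are straightforward to resolve statically.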
Error handle: ask reader: if "XXX is not accessible", do not ask the same code component again. If unsuccessful, callgraphbuilder will return something like "XXX is not accessible". for LLM generated docstring, no triple quotes (\"\"\") is added originally. For generate_docstrings.py. Add features: Multiple Passes ("Category" Approach): • We split docstring generation for each file into three passes, in this order: (a) Top-level functions (i.e., "function") (b) Methods inside classes (i.e., "method") (c) Classes (i.e., "class") Each pass visits all .py files in the repo. Immediate File Rewrite After Each Code Component: • In each pass, we repeatedly parse a file, gather all code components of the chosen category in ascending line order, pick off the first component, generate a docstring, and immediately rewrite the file. Then we re-parse the updated file before moving on to the next component. • This ensures that each code component is added in the final version of the file before the next code component's docstring is generated. It also meets the request for "refresh the python file after each generation for a single code component." However, this approach is more computationally expensive than generating all docstrings in memory and rewriting once per file, but it achieves the desired incremental rewriting and strict function → method → class ordering. Why the python file is not updated after the docstring is generated for each code component? - because updating file needs re-parsing the file and rebuild the AST, which is expensive. Dependency clarification for # Future Work human in the loop: - human can be the judge and can provide more information to the system. 
================================================ FILE: src/DocstringGenerator.egg-info/SOURCES.txt ================================================ README.md setup.py src/DocstringGenerator.egg-info/PKG-INFO src/DocstringGenerator.egg-info/SOURCES.txt src/DocstringGenerator.egg-info/dependency_links.txt src/DocstringGenerator.egg-info/requires.txt src/DocstringGenerator.egg-info/top_level.txt src/agent/__init__.py src/agent/base.py src/agent/orchestrator.py src/agent/reader.py src/agent/searcher.py src/agent/verifier.py src/agent/workflow.py src/agent/writer.py src/agent/llm/__init__.py src/agent/llm/base.py src/agent/llm/claude_llm.py src/agent/llm/factory.py src/agent/llm/gemini_llm.py src/agent/llm/huggingface_llm.py src/agent/llm/openai_llm.py src/agent/llm/rate_limiter.py src/dependency_analyzer/__init__.py src/dependency_analyzer/ast_parser.py src/dependency_analyzer/topo_sort.py src/evaluator/__init__.py src/evaluator/base.py src/evaluator/completeness.py src/evaluator/evaluation_common.py src/evaluator/helpfulness_attributes.py src/evaluator/helpfulness_description.py src/evaluator/helpfulness_evaluator.py src/evaluator/helpfulness_evaluator_ablation.py src/evaluator/helpfulness_examples.py src/evaluator/helpfulness_parameters.py src/evaluator/helpfulness_summary.py src/evaluator/segment.py src/evaluator/truthfulness.py src/visualizer/__init__.py src/visualizer/progress.py src/visualizer/status.py src/visualizer/web_bridge.py src/web/__init__.py src/web/app.py src/web/config_handler.py src/web/process_handler.py src/web/run.py src/web/visualization_handler.py ================================================ FILE: src/DocstringGenerator.egg-info/dependency_links.txt ================================================ ================================================ FILE: src/DocstringGenerator.egg-info/requires.txt ================================================ numpy>=1.23.5 pyyaml>=6.0 jinja2>=3.1.5 requests>=2.32.0 urllib3>=2.3.0 astor>=0.8.1 
code2flow>=2.5.1 pydeps>=3.0.0 anthropic>=0.45.0 openai>=1.60.1 langchain-anthropic>=0.3.4 langchain-openai>=0.3.2 langchain-core>=0.3.31 langgraph>=0.2.67 tiktoken>=0.8.0 transformers>=4.48.0 huggingface-hub>=0.28.0 google-generativeai>=0.6.0 tqdm>=4.67.1 tabulate>=0.9.0 colorama>=0.4.6 termcolor>=2.5.0 pydantic>=2.10.0 flask>=3.1.0 flask-socketio>=5.5.1 eventlet>=0.39.0 python-socketio>=5.12.1 python-engineio>=4.11.2 bidict>=0.23.0 dnspython>=2.7.0 six>=1.16.0 torch>=2.0.0 accelerate>=1.4.0 [all] pytest>=8.3.4 pytest-cov>=2.0 black>=22.0 flake8>=3.9 flask>=3.1.0 flask-socketio>=5.5.1 eventlet>=0.39.0 python-socketio>=5.12.1 python-engineio>=4.11.2 bidict>=0.23.0 dnspython>=2.7.0 six>=1.16.0 matplotlib>=3.10.0 pygraphviz>=1.14 networkx>=3.4.2 torch>=2.0.0 accelerate>=1.4.0 [cuda] torch>=2.0.0 accelerate>=1.4.0 [dev] pytest>=8.3.4 pytest-cov>=2.0 black>=22.0 flake8>=3.9 [visualization] matplotlib>=3.10.0 pygraphviz>=1.14 networkx>=3.4.2 [web] flask>=3.1.0 flask-socketio>=5.5.1 eventlet>=0.39.0 python-socketio>=5.12.1 python-engineio>=4.11.2 bidict>=0.23.0 dnspython>=2.7.0 six>=1.16.0 ================================================ FILE: src/DocstringGenerator.egg-info/top_level.txt ================================================ agent dependency_analyzer evaluator visualizer web ================================================ FILE: src/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates ================================================ FILE: src/agent/README.md ================================================ # Agent Framework for Docstring Generation This directory contains the core components of the multi-agent system responsible for generating high-quality docstrings for code components. ## Overview The system employs a collaborative workflow involving several specialized agents, managed by an Orchestrator. 
The goal is to analyze code, gather necessary context (both internal and external), generate a docstring, and verify its quality before finalizing. The main workflow is initiated via the `generate_docstring` function in `workflow.py`. ## Agents 1. **`BaseAgent` (`base.py`)** * **Role:** Abstract base class for all agents. * **Functionality:** Provides common infrastructure including LLM initialization (using `LLMFactory`), configuration loading, memory management (storing conversation history), and basic LLM interaction (`generate_response`). Ensures consistency across agents. 2. **`Reader` (`reader.py`)** * **Role:** Contextual Analysis and Information Needs Assessment. * **Functionality:** Analyzes the input code component (`focal_component`) and any existing context. Determines if additional information is required to write a comprehensive docstring. If more information is needed, it generates a structured request specifying whether internal codebase details (e.g., callers, callees) or external web search results are required. 3. **`Searcher` (`searcher.py`)** * **Role:** Information Retrieval. * **Functionality:** Acts upon the requests generated by the `Reader`. It retrieves the specified information by: * Querying the internal codebase using AST analysis (`ASTNodeAnalyzer`) and dependency graphs. * Performing external web searches via APIs (e.g., `PerplexityAPI`). * Returns the gathered context in a structured format. 4. **`Writer` (`writer.py`)** * **Role:** Docstring Generation. * **Functionality:** Takes the original code component and the accumulated context (provided by the `Orchestrator` after `Reader` and `Searcher` steps) as input. Uses its configured LLM and detailed prompts (tailored for classes vs. functions/methods, adhering to Google style guide) to generate the docstring. Outputs the generated docstring within specific XML tags (``). 5. **`Verifier` (`verifier.py`)** * **Role:** Quality Assurance. 
    *   **Functionality:** Evaluates the docstring produced by the `Writer` against the original code and the context used. Checks for clarity, accuracy, completeness, information value (avoiding redundancy), and appropriate level of detail. Determines if the docstring meets quality standards or requires revision. If revision is needed, it specifies whether more context is required or provides direct suggestions for improvement.

6.  **`Orchestrator` (`orchestrator.py`)**
    *   **Role:** Workflow Management.
    *   **Functionality:** Coordinates the entire process. It manages the sequence of agent interactions:
        *   Calls `Reader` to assess context needs.
        *   Calls `Searcher` iteratively if more context is requested (up to a limit).
        *   Calls `Writer` to generate the docstring.
        *   Calls `Verifier` to evaluate the docstring.
        *   Manages revision loops based on `Verifier` feedback, potentially involving further searches or refinement by the `Writer` (up to a limit).
        *   Handles context accumulation, token limit constraints, and status visualization.

## Supporting Files

*   **`workflow.py`:** Provides the primary entry point function `generate_docstring` to initiate the docstring generation process for a given code component.
*   **`__init__.py`:** Makes the `agent` directory a Python package.
*   **`llm/`:** Contains LLM-related code, including the `LLMFactory` and base LLM classes.
*   **`tool/`:** Contains tools used by agents, such as the `ASTNodeAnalyzer` for internal code traversal and the `PerplexityAPI` wrapper for external search.

================================================
FILE: src/agent/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc.
and affiliates # Import only essential components to avoid circular imports from .reader import CodeComponentType # Explicitly list what should be accessible, but don't import until needed # to prevent circular imports __all__ = ['generate_docstring', 'CodeComponentType'] # Lazy load generate_docstring when it's actually needed def __getattr__(name): if name == 'generate_docstring': from .workflow import generate_docstring return generate_docstring raise AttributeError(f"module '{__name__}' has no attribute '{name}'") ================================================ FILE: src/agent/base.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod from typing import Any, Dict, Optional, List import os from pathlib import Path from .llm.factory import LLMFactory from .llm.base import BaseLLM class BaseAgent(ABC): """Base class for all agents in the docstring generation system.""" def __init__(self, name: str, config_path: Optional[str] = None): """Initialize the base agent. Args: name: The name of the agent config_path: Optional path to the configuration file """ self.name = name self._memory: list[Dict[str, Any]] = [] # Initialize LLM and parameters from config self.llm, self.llm_params = self._initialize_llm(name, config_path) def _initialize_llm(self, agent_name: str, config_path: Optional[str] = None) -> tuple[BaseLLM, Dict[str, Any]]: """Initialize the LLM for this agent. 
Args: agent_name: Name of the agent config_path: Optional path to the configuration file Returns: Tuple of (Initialized LLM instance, LLM parameters dictionary) """ # Load configuration if config_path is None: config_path = "config/agent_config.yaml" print(f"Using default config from {config_path}") config = LLMFactory.load_config(config_path) # Check for agent-specific configuration agent_config = config.get("agent_llms", {}).get(agent_name.lower()) # Use agent-specific config if available, otherwise use default llm_config = agent_config if agent_config else config.get("llm", {}) # Verify api_key is provided in config if ("api_key" not in llm_config or not llm_config["api_key"]) and (llm_config["type"] not in ["huggingface", "local"]): raise ValueError("API key must be specified directly in the config file") # Extract LLM parameters llm_params = { "max_output_tokens": llm_config.get("max_output_tokens", 4096), "temperature": llm_config.get("temperature", 0.1), "model": llm_config.get("model") } return LLMFactory.create_llm(llm_config), llm_params def add_to_memory(self, role: str, content: str) -> None: """Add a message to the agent's memory. Args: role: The role of the message sender (e.g., 'system', 'user', 'assistant') content: The content of the message """ assert content is not None and content != "", "Content cannot be empty" self._memory.append(self.llm.format_message(role, content)) def refresh_memory(self, new_memory: list[Dict[str, Any]]) -> None: """Replace the current memory with new memory. Args: new_memory: The new memory to replace the current memory """ self._memory = [ self.llm.format_message(msg["role"], msg["content"]) for msg in new_memory ] def clear_memory(self) -> None: """Clear the agent's memory.""" self._memory = [] @property def memory(self) -> list[Dict[str, Any]]: """Get the agent's memory. 
Returns: The agent's memory as a list of message dictionaries """ return self._memory.copy() def generate_response(self, messages: Optional[List[Dict[str, Any]]] = None) -> str: """Generate a response using the agent's LLM and memory. Args: messages: Optional list of messages to use instead of memory Returns: Generated response text """ return self.llm.generate( messages=messages if messages is not None else self._memory, temperature=self.llm_params["temperature"], max_tokens=self.llm_params["max_output_tokens"] ) @abstractmethod def process(self, *args, **kwargs) -> Any: """Process the input and generate output. This method should be implemented by each specific agent. """ pass ================================================ FILE: src/agent/llm/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from .base import BaseLLM from .openai_llm import OpenAILLM from .claude_llm import ClaudeLLM from .huggingface_llm import HuggingFaceLLM from .gemini_llm import GeminiLLM from .factory import LLMFactory __all__ = [ 'BaseLLM', 'OpenAILLM', 'ClaudeLLM', 'HuggingFaceLLM', 'GeminiLLM', 'LLMFactory' ] ================================================ FILE: src/agent/llm/base.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional class BaseLLM(ABC): """Base class for LLM wrappers.""" @abstractmethod def generate( self, messages: List[Dict[str, str]], temperature: float = 0.7, max_output_tokens: Optional[int] = None ) -> str: """Generate a response from the LLM. 
Args: messages: List of message dictionaries with 'role' and 'content' keys temperature: Sampling temperature (0.0 to 1.0) max_output_tokens: Maximum number of tokens to generate Returns: The generated response text """ pass @abstractmethod def format_message(self, role: str, content: str) -> Dict[str, str]: """Format a message for the specific LLM API. Args: role: The role of the message sender content: The content of the message Returns: Formatted message dictionary """ pass ================================================ FILE: src/agent/llm/claude_llm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import List, Dict, Any, Optional import anthropic from .base import BaseLLM from .rate_limiter import RateLimiter import logging class ClaudeLLM(BaseLLM): """Anthropic Claude API wrapper.""" def __init__( self, api_key: str, model: str, rate_limits: Optional[Dict[str, Any]] = None ): """Initialize Claude LLM. Args: api_key: Anthropic API key model: Model identifier (e.g., "claude-3-sonnet-20240229") rate_limits: Optional dictionary with rate limit settings """ self.client = anthropic.Anthropic(api_key=api_key) self.model = model # Default rate limits for Claude 3.7 Sonnet default_limits = { "requests_per_minute": 50, "input_tokens_per_minute": 20000, "output_tokens_per_minute": 8000, "input_token_price_per_million": 3.0, "output_token_price_per_million": 15.0 } # Use provided rate limits or defaults limits = rate_limits or default_limits # Initialize rate limiter self.rate_limiter = RateLimiter( provider="Claude", requests_per_minute=limits.get("requests_per_minute", default_limits["requests_per_minute"]), input_tokens_per_minute=limits.get("input_tokens_per_minute", default_limits["input_tokens_per_minute"]), output_tokens_per_minute=limits.get("output_tokens_per_minute", default_limits["output_tokens_per_minute"]), input_token_price_per_million=limits.get("input_token_price_per_million", 
default_limits["input_token_price_per_million"]), output_token_price_per_million=limits.get("output_token_price_per_million", default_limits["output_token_price_per_million"]) ) def _count_tokens(self, text: str) -> int: """Count tokens in a string using Claude's tokenizer. Args: text: Text to count tokens for Returns: Token count """ if not text: return 0 try: # Format text as a message for token counting count = self.client.beta.messages.count_tokens( model=self.model, messages=[ {"role": "user", "content": text} ] ) return count.input_tokens except Exception as e: # Log the error but don't fail logging.warning(f"Failed to count tokens with Claude tokenizer: {e}") # Fallback: rough estimate if tokenizer fails return len(text.split()) * 1.3 def _count_messages_tokens(self, messages: List[Dict[str, str]], system_message: Optional[str] = None) -> int: """Count tokens in message list with optional system message. Args: messages: List of message dictionaries system_message: Optional system message Returns: Total token count """ if not messages: return 0 # Convert messages to Claude format claude_messages = [self._convert_to_claude_message(msg) for msg in messages if msg["role"] != "system"] # Format system message if provided system_content = None if system_message: system_content = system_message try: # Use the API to count tokens for all messages at once count = self.client.beta.messages.count_tokens( model=self.model, messages=claude_messages, system=system_content ) return count.input_tokens except Exception as e: # Log the error but don't fail logging.warning(f"Failed to count tokens with Claude tokenizer: {e}") # Fallback: count tokens individually total_tokens = 0 for msg in claude_messages: if "content" in msg and msg["content"]: total_tokens += self._count_tokens(msg["content"]) # Add system message tokens if provided if system_message: total_tokens += self._count_tokens(system_message) # Add overhead for message formatting total_tokens += 10 * 
len(claude_messages) # Add ~10 tokens per message for formatting return total_tokens def generate( self, messages: List[Dict[str, str]], temperature: float, max_tokens: Optional[int] ) -> str: """Generate a response using Claude API with rate limiting. Args: messages: List of message dictionaries temperature: Sampling temperature max_output_tokens: Maximum tokens to generate Returns: Generated response text """ # Extract system message if present system_message = None chat_messages = [] for msg in messages: if msg["role"] == "system": system_message = msg["content"] else: chat_messages.append(self._convert_to_claude_message(msg)) # Count input tokens input_tokens = self._count_messages_tokens(messages, system_message) # Wait if we're approaching rate limits (estimate output tokens as max_output_tokens) self.rate_limiter.wait_if_needed(input_tokens, max_tokens) # Make the API call response = self.client.messages.create( model=self.model, messages=chat_messages, system=system_message, temperature=temperature, max_tokens=max_tokens ) result_text = response.content[0].text # Count output tokens and record request output_tokens = self._count_tokens(result_text) self.rate_limiter.record_request(input_tokens, output_tokens) return result_text def format_message(self, role: str, content: str) -> Dict[str, str]: """Format message for Claude API. Args: role: Message role (system, user, assistant) content: Message content Returns: Formatted message dictionary """ # Store in standard format, conversion happens in generate() return {"role": role, "content": content} def _convert_to_claude_message(self, message: Dict[str, str]) -> Dict[str, str]: """Convert standard message format to Claude's format. 
Args: message: Standard format message Returns: Claude format message """ role_mapping = { "user": "user", "assistant": "assistant" } role = role_mapping[message["role"]] content = message["content"] return {"role": role, "content": content} ================================================ FILE: src/agent/llm/factory.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict, Any, Optional from pathlib import Path import yaml from .base import BaseLLM from .openai_llm import OpenAILLM from .claude_llm import ClaudeLLM from .huggingface_llm import HuggingFaceLLM from .gemini_llm import GeminiLLM class LLMFactory: """Factory class for creating LLM instances.""" @staticmethod def create_llm(config: Dict[str, Any]) -> BaseLLM: """Create an LLM instance based on configuration. Args: config: Configuration dictionary containing LLM settings Returns: An instance of BaseLLM Raises: ValueError: If the LLM type is not supported """ llm_type = config["type"].lower() model = config.get("model") if not model: raise ValueError("Model must be specified in the config file") # Extract rate limit settings from config # First check if there are specific rate limits in the LLM config rate_limits = config.get("rate_limits", {}) # If not, check if there are global rate limits for this provider type global_config = LLMFactory.load_config() if not rate_limits and "rate_limits" in global_config: # Map LLM types to provider names in rate_limits section provider_map = { "openai": "openai", "claude": "claude", "gemini": "gemini" } provider_key = provider_map.get(llm_type, llm_type) provider_limits = global_config.get("rate_limits", {}).get(provider_key, {}) if provider_limits: rate_limits = provider_limits if llm_type == "openai": return OpenAILLM( api_key=config["api_key"], model=model, rate_limits=rate_limits ) elif llm_type == "claude": return ClaudeLLM( api_key=config["api_key"], model=model, rate_limits=rate_limits ) 
elif llm_type == "gemini": return GeminiLLM( api_key=config["api_key"], model=model, rate_limits=rate_limits ) elif llm_type == "huggingface": return HuggingFaceLLM( model_name=model, device=config.get("device", "cuda"), torch_dtype=config.get("torch_dtype", "float16") ) else: raise ValueError(f"Unsupported LLM type: {llm_type}") @staticmethod def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: """Load LLM configuration from file. Args: config_path: Path to the configuration file. If None, uses default path. Returns: Configuration dictionary Raises: FileNotFoundError: If the configuration file doesn't exist """ if config_path is None: config_path = str(Path(__file__).parent.parent.parent.parent / "config" / "agent_config.yaml") if not Path(config_path).exists(): raise FileNotFoundError(f"Configuration file not found: {config_path}") with open(config_path, 'r') as f: config = yaml.safe_load(f) return config ================================================ FILE: src/agent/llm/gemini_llm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import List, Dict, Any, Optional import tiktoken import google.generativeai as genai from .base import BaseLLM from .rate_limiter import RateLimiter class GeminiLLM(BaseLLM): """Google Gemini API wrapper.""" def __init__( self, api_key: str, model: str, rate_limits: Optional[Dict[str, Any]] = None ): """Initialize Gemini LLM. 
Args: api_key: Google API key model: Model identifier (e.g., "gemini-1.5-flash", "gemini-1.5-pro") rate_limits: Optional dictionary with rate limit settings """ genai.configure(api_key=api_key) self.model_name = model self.model = genai.GenerativeModel(model) try: # Initialize tokenizer for token counting # Gemini doesn't have a direct tokenizer in the public API # Using tiktoken cl100k_base as a reasonable approximation self.tokenizer = tiktoken.get_encoding("cl100k_base") except: # Fallback to basic word counting if tokenizer fails self.tokenizer = None # Default rate limits for Gemini (adjust based on actual API limits) default_limits = { "requests_per_minute": 60, "input_tokens_per_minute": 100000, "output_tokens_per_minute": 50000, "input_token_price_per_million": 0.125, # Approximate for gemini-1.5-flash "output_token_price_per_million": 0.375 # Approximate for gemini-1.5-flash } # Use provided rate limits or defaults limits = rate_limits or default_limits # Initialize rate limiter self.rate_limiter = RateLimiter( provider="Gemini", requests_per_minute=limits.get("requests_per_minute", default_limits["requests_per_minute"]), input_tokens_per_minute=limits.get("input_tokens_per_minute", default_limits["input_tokens_per_minute"]), output_tokens_per_minute=limits.get("output_tokens_per_minute", default_limits["output_tokens_per_minute"]), input_token_price_per_million=limits.get("input_token_price_per_million", default_limits["input_token_price_per_million"]), output_token_price_per_million=limits.get("output_token_price_per_million", default_limits["output_token_price_per_million"]) ) def _count_tokens(self, text: str) -> int: """Count tokens in a string using the model's tokenizer. 
        Args:
            text: Text to count tokens for

        Returns:
            Token count
        """
        if not text:
            return 0

        try:
            if self.tokenizer:
                return len(self.tokenizer.encode(text))
            else:
                # Fallback: rough estimate if tokenizer not available
                # (cast to int so the return type matches the annotation)
                return int(len(text.split()) * 1.3)
        except Exception as e:
            # Log the error but don't fail
            import logging
            logging.warning(f"Failed to count tokens for Gemini: {e}")
            # Fallback: rough estimate if tokenizer fails
            return int(len(text.split()) * 1.3)

    def _count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in all messages.

        Args:
            messages: List of message dictionaries

        Returns:
            Total token count
        """
        if not messages:
            return 0

        total_tokens = 0

        # Count tokens in each message
        for message in messages:
            if "content" in message and message["content"]:
                total_tokens += self._count_tokens(message["content"])

        # Add overhead for message formatting (estimated)
        total_tokens += 4 * len(messages)

        return total_tokens

    def _convert_messages_to_gemini_format(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Convert standard message format to Gemini-specific format.
        Args:
            messages: List of message dictionaries with 'role' and 'content' keys

        Returns:
            List of Gemini-formatted messages
        """
        gemini_messages = []

        # Gemini uses "user" and "model" for roles
        role_mapping = {
            "user": "user",
            "assistant": "model",
            "system": "user"  # Gemini doesn't have a system role, handle specifically
        }

        # Split off a leading system message, if any
        system_content = ""
        remaining = messages
        if messages and messages[0].get("role") == "system":
            system_content = messages[0].get("content", "")
            remaining = messages[1:]

        # Convert roles for the rest of the messages (this runs even when the
        # system message is empty, so the conversation is never dropped)
        for message in remaining:
            role = role_mapping.get(message.get("role", "user"), "user")
            content = message.get("content", "")
            gemini_messages.append({"role": role, "parts": content})

        # Carry the system message over as prefixed user content. It is merged
        # into the first user turn when possible to avoid two consecutive user
        # turns; the prefix wording is an arbitrary choice.
        if system_content:
            prefixed = f"System instructions: {system_content}"
            if gemini_messages and gemini_messages[0]["role"] == "user":
                gemini_messages[0]["parts"] = f"{prefixed}\n\n{gemini_messages[0]['parts']}"
            else:
                gemini_messages.insert(0, {"role": "user", "parts": prefixed})

        return gemini_messages

    def generate(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """Generate a response using Gemini API with rate limiting.
        Args:
            messages: List of message dictionaries
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Generated response text
        """
        # Count input tokens
        input_tokens = self._count_messages_tokens(messages)

        # Wait if we're approaching rate limits
        self.rate_limiter.wait_if_needed(input_tokens, max_tokens if max_tokens else 1000)

        # Format messages for Gemini API
        gemini_messages = self._convert_messages_to_gemini_format(messages)

        # Gemini's generation config names the output cap "max_output_tokens";
        # only set it when a limit was requested
        generation_config = {"temperature": temperature}
        if max_tokens:
            generation_config["max_output_tokens"] = max_tokens

        # Check if we need to start a chat or just generate
        if len(gemini_messages) > 1:
            # Start a chat with history
            history = gemini_messages[:-1]  # All but the last message
            last_message = gemini_messages[-1]  # The last message to send

            chat = self.model.start_chat(
                history=history,
            )

            # Send the last message to get a response, applying the same
            # sampling settings as the single-message path
            response = chat.send_message(
                last_message.get("parts", ""),
                generation_config=generation_config
            )
            result_text = response.text
        else:
            # Single message, use generate_content
            content = gemini_messages[0].get("parts", "") if gemini_messages else ""
            response = self.model.generate_content(
                content,
                generation_config=generation_config
            )
            result_text = response.text

        # Estimate output tokens with the local tokenizer
        output_tokens = self._count_tokens(result_text)

        # Record the request
        self.rate_limiter.record_request(input_tokens, output_tokens)

        return result_text

    def format_message(self, role: str, content: str) -> Dict[str, str]:
        """Format message for standard API.

        Args:
            role: Message role (system, user, assistant)
            content: Message content

        Returns:
            Formatted message dictionary
        """
        # Standard format - conversion to Gemini format happens in generate method
        return {"role": role, "content": content}

================================================
FILE: src/agent/llm/huggingface_llm.py
================================================
# Copyright (c) Meta Platforms, Inc.
and affiliates from typing import List, Dict, Any, Optional from openai import OpenAI import torch import tiktoken from .base import BaseLLM class HuggingFaceLLM(BaseLLM): """HuggingFace model wrapper using vLLM's OpenAI-compatible API.""" def __init__( self, model_name: str, api_base: str = "http://localhost:8000/v1", api_key: str = "EMPTY", device: str = None, # Kept for backward compatibility torch_dtype: torch.dtype = None, # Kept for backward compatibility max_input_tokens: int = 10000 # Maximum input tokens allowed ): """Initialize HuggingFace LLM via vLLM API. Args: model_name: Name of the model api_base: Base URL for the vLLM API endpoint api_key: API key (typically "EMPTY" for local vLLM deployments) device: Ignored (handled by vLLM server) torch_dtype: Ignored (handled by vLLM server) max_input_tokens: Maximum number of input tokens allowed """ self.model_name = model_name self.client = OpenAI( api_key=api_key, base_url=api_base, ) self.max_input_tokens = max_input_tokens # Initialize tokenizer based on model try: self.tokenizer = tiktoken.encoding_for_model(model_name) except KeyError: # Fall back to cl100k_base for unknown models (used by GPT-4, GPT-3.5-turbo) self.tokenizer = tiktoken.get_encoding("cl100k_base") def _count_tokens(self, messages: List[Dict[str, str]]) -> int: """Count the number of tokens in a list of messages. Args: messages: List of message dictionaries Returns: Total token count """ token_count = 0 for message in messages: # Count tokens in content token_count += len(self.tokenizer.encode(message["content"])) # Add overhead for message format (role, etc.) token_count += 4 # Approximate tokens for message formatting # Add tokens for the formatting between messages token_count += 2 # Final assistant message tokens return token_count def _truncate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: """Truncate messages to stay within the token limit. 
Args: messages: List of message dictionaries Returns: Truncated list of message dictionaries """ if not messages: return [] system_messages = [m for m in messages if m["role"].lower() == "system"] non_system_messages = [m for m in messages if m["role"].lower() != "system"] # Always keep system messages intact result = system_messages.copy() token_budget = self.max_input_tokens - self._count_tokens(result) # Process non-system messages from newest to oldest for message in reversed(non_system_messages): message_tokens = self._count_tokens([message]) if message_tokens <= token_budget: # We can include the entire message result.insert(len(system_messages), message) token_budget -= message_tokens elif message["role"].lower() == "user" and token_budget > 20: # For user messages, we can truncate content if needed # Keep enough tokens for comprehension (at least some portion) content = message["content"] # Estimate how much content to keep keep_ratio = token_budget / message_tokens # Truncate from beginning to keep most recent content if keep_ratio < 0.5: # If we need to cut more than half, add indicator of truncation truncated_content = f"[...truncated...] 
{content[int(len(content) * (1 - keep_ratio + 0.1)):].strip()}"
                else:
                    truncated_content = content[int(len(content) * (1 - keep_ratio)):].strip()

                truncated_message = {
                    "role": message["role"],
                    "content": truncated_content
                }

                # Verify the truncated message fits
                truncated_tokens = self._count_tokens([truncated_message])
                if truncated_tokens <= token_budget:
                    result.insert(len(system_messages), truncated_message)
                    token_budget -= truncated_tokens

            # If we can't fit any more messages, stop
            if token_budget <= 20:  # Keep some buffer
                break

        # Ensure the messages are in the correct order (system first, then chronological)
        result.sort(key=lambda m: 0 if m["role"].lower() == "system" else 1)

        return result

    def generate(
        self,
        messages: List[Dict[str, str]],
        temperature: float,
        max_tokens: Optional[int]
    ) -> str:
        """Generate a response using the vLLM API.

        Args:
            messages: List of message dictionaries
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Generated response text
        """
        # __init__ never sets self.max_output_tokens, so read it defensively and
        # fall back to None (no explicit limit) instead of raising AttributeError
        max_output_tokens = max_tokens if max_tokens is not None else getattr(self, "max_output_tokens", None)

        # Check token count and truncate if needed
        total_tokens = self._count_tokens(messages)
        if total_tokens > self.max_input_tokens:
            messages = self._truncate_messages(messages)

        # vLLM expects strictly alternating user/assistant roles with an optional system message at the beginning
        # Prepare the messages with the proper format
        formatted_messages = []

        # First, check for a system message to include at the beginning
        system_messages = [m for m in messages if m["role"].lower() == "system"]
        if system_messages:
            # Use the last system message if multiple exist
            formatted_messages.append({
                "role": "system",
                "content": system_messages[-1]["content"]
            })

        # Filter out system messages and process the rest
        user_assistant_messages = [m for m in messages if m["role"].lower() != "system"]

        # Ensure messages alternate between user and assistant
        current_role = "user"  # Start with user message
        for message in
user_assistant_messages: role = message["role"].lower() # Map roles to either user or assistant if role in ["user", "human"]: mapped_role = "user" else: mapped_role = "assistant" # If this message would create consecutive messages with the same role, # skip adding it to avoid the alternating pattern error if formatted_messages and mapped_role == formatted_messages[-1]["role"]: continue # Add the properly mapped message formatted_messages.append({ "role": mapped_role, "content": message["content"] }) # Make sure the last message is from the user, so the model will respond as assistant if not formatted_messages or formatted_messages[-1]["role"] != "user": # If we don't have any messages or the last one isn't from user, we need to add a user message # Use an empty message or the last assistant message as context formatted_messages.append({ "role": "user", "content": "Please continue." if not formatted_messages else f"Based on your last response: '{formatted_messages[-1]['content']}', please continue." }) # Call the API response = self.client.chat.completions.create( model=self.model_name, messages=formatted_messages, temperature=temperature, max_tokens=max_output_tokens ) # Extract the generated text return response.choices[0].message.content def format_message(self, role: str, content: str) -> Dict[str, str]: """Format message for OpenAI API compatible format. Args: role: Message role (system, user, assistant) content: Message content Returns: Formatted message dictionary """ # Map to standard OpenAI roles if needed if role.lower() not in ["system", "user", "assistant"]: if role.lower() in ["human"]: role = "user" elif role.lower() in ["ai", "assistant"]: role = "assistant" else: # Default unexpected roles to user role = "user" return {"role": role, "content": content} def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: """Convert messages to a single prompt string. 
This method is kept for backward compatibility but is not used in the API-based implementation. Args: messages: List of message dictionaries Returns: Formatted prompt string """ prompt_parts = [] for message in messages: role = message["role"] content = message["content"] if role == "system": prompt_parts.append(f"System: {content}") elif role == "user": prompt_parts.append(f"Human: {content}") elif role == "assistant": prompt_parts.append(f"Assistant: {content}") prompt_parts.append("Assistant: ") # Add final prompt for generation return "\n".join(prompt_parts) ================================================ FILE: src/agent/llm/openai_llm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import List, Dict, Any, Optional import openai import tiktoken from .base import BaseLLM from .rate_limiter import RateLimiter class OpenAILLM(BaseLLM): """OpenAI API wrapper.""" def __init__( self, api_key: str, model: str, rate_limits: Optional[Dict[str, Any]] = None ): """Initialize OpenAI LLM. 
Args: api_key: OpenAI API key model: Model identifier (e.g., "gpt-4", "gpt-3.5-turbo") rate_limits: Optional dictionary with rate limit settings """ self.client = openai.OpenAI(api_key=api_key) self.model = model try: # Initialize tokenizer for the model self.tokenizer = tiktoken.encoding_for_model(model) except: # Fallback to cl100k_base for new models self.tokenizer = tiktoken.get_encoding("cl100k_base") # Default rate limits for GPT-4o-mini default_limits = { "requests_per_minute": 500, "input_tokens_per_minute": 200000, "output_tokens_per_minute": 100000, "input_token_price_per_million": 0.15, "output_token_price_per_million": 0.60 } # Use provided rate limits or defaults limits = rate_limits or default_limits # Initialize rate limiter self.rate_limiter = RateLimiter( provider="OpenAI", requests_per_minute=limits.get("requests_per_minute", default_limits["requests_per_minute"]), input_tokens_per_minute=limits.get("input_tokens_per_minute", default_limits["input_tokens_per_minute"]), output_tokens_per_minute=limits.get("output_tokens_per_minute", default_limits["output_tokens_per_minute"]), input_token_price_per_million=limits.get("input_token_price_per_million", default_limits["input_token_price_per_million"]), output_token_price_per_million=limits.get("output_token_price_per_million", default_limits["output_token_price_per_million"]) ) def _count_tokens(self, text: str) -> int: """Count tokens in a string using the model's tokenizer. Args: text: Text to count tokens for Returns: Token count """ if not text: return 0 try: return len(self.tokenizer.encode(text)) except Exception as e: # Log the error but don't fail import logging logging.warning(f"Failed to count tokens with OpenAI tokenizer: {e}") # Fallback: rough estimate if tokenizer fails return len(text.split()) * 1.3 def _count_messages_tokens(self, messages: List[Dict[str, str]]) -> int: """Count tokens in all messages. 
Args: messages: List of message dictionaries Returns: Total token count """ if not messages: return 0 total_tokens = 0 # Count tokens in each message for message in messages: if "content" in message and message["content"]: total_tokens += self._count_tokens(message["content"]) # Add overhead for message formatting (varies by model, but ~4 tokens per message) total_tokens += 4 * len(messages) # Add tokens for model overhead (varies by model) total_tokens += 3 # Every reply is primed with <|start|>assistant<|message|> return total_tokens def generate( self, messages: List[Dict[str, str]], temperature: float, max_tokens: Optional[int] ) -> str: """Generate a response using OpenAI API with rate limiting. Args: messages: List of message dictionaries temperature: Sampling temperature max_output_tokens: Maximum tokens to generate Returns: Generated response text """ # Count input tokens input_tokens = self._count_messages_tokens(messages) # Wait if we're approaching rate limits (estimate output tokens as max_output_tokens) self.rate_limiter.wait_if_needed(input_tokens, max_tokens) # Make the API call response = self.client.chat.completions.create( model=self.model, messages=messages, temperature=temperature, max_tokens=max_tokens if max_tokens else None ) result_text = response.choices[0].message.content # Count output tokens and record request output_tokens = response.usage.completion_tokens if hasattr(response, 'usage') else self._count_tokens(result_text) input_tokens = response.usage.prompt_tokens if hasattr(response, 'usage') else input_tokens self.rate_limiter.record_request(input_tokens, output_tokens) return result_text def format_message(self, role: str, content: str) -> Dict[str, str]: """Format message for OpenAI API. 
Args: role: Message role (system, user, assistant) content: Message content Returns: Formatted message dictionary """ # OpenAI uses standard role names return {"role": role, "content": content} ================================================ FILE: src/agent/llm/rate_limiter.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import time from typing import Dict, List, Optional from collections import deque import threading import logging # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger("RateLimiter") class RateLimiter: """ Rate limiter for LLM API calls. Tracks requests, input tokens, and output tokens per minute. Also tracks cost based on token pricing. """ def __init__( self, provider: str, requests_per_minute: int, input_tokens_per_minute: int, output_tokens_per_minute: int, input_token_price_per_million: float, output_token_price_per_million: float, buffer_percentage: float = 0.1 # Buffer to avoid hitting exact limits ): """ Initialize the rate limiter. 
Args: provider: LLM provider name ("openai" or "claude") requests_per_minute: Maximum requests per minute input_tokens_per_minute: Maximum input tokens per minute output_tokens_per_minute: Maximum output tokens per minute input_token_price_per_million: Price per million input tokens output_token_price_per_million: Price per million output tokens buffer_percentage: Percentage buffer to avoid hitting exact limits """ self.provider = provider self.requests_per_minute = requests_per_minute * (1 - buffer_percentage) self.input_tokens_per_minute = input_tokens_per_minute * (1 - buffer_percentage) self.output_tokens_per_minute = output_tokens_per_minute * (1 - buffer_percentage) # Pricing self.input_token_price = input_token_price_per_million / 1_000_000 self.output_token_price = output_token_price_per_million / 1_000_000 # Track usage within a sliding window (1 minute) self.request_timestamps = deque() self.input_token_usage = deque() # Tuples of (timestamp, token_count) self.output_token_usage = deque() # Tuples of (timestamp, token_count) # Total usage stats self.total_requests = 0 self.total_input_tokens = 0 self.total_output_tokens = 0 self.total_cost = 0.0 # Thread lock for thread safety self.lock = threading.Lock() def _clean_old_entries(self, usage_queue: deque, current_time: float): """Remove entries older than 1 minute from the queue.""" one_minute_ago = current_time - 60 # Handle different queue formats (timestamps vs. 
(timestamp, value) tuples) if usage_queue and isinstance(usage_queue[0], tuple): # For token usage queues that store (timestamp, count) tuples while usage_queue and usage_queue[0][0] < one_minute_ago: usage_queue.popleft() else: # For request_timestamps queue that stores timestamp floats directly while usage_queue and usage_queue[0] < one_minute_ago: usage_queue.popleft() def _get_usage_count(self, usage_queue: deque): """Get the total count from a usage queue.""" return sum(count for _, count in usage_queue) def wait_if_needed(self, input_tokens: int, estimated_output_tokens: Optional[int] = None): """ Check if we're about to exceed rate limits and wait if necessary. This improved version uses a while loop instead of recursion to avoid potential infinite waiting scenarios. Args: input_tokens: Number of input tokens for the upcoming request estimated_output_tokens: Estimated number of output tokens """ with self.lock: if estimated_output_tokens is None: estimated_output_tokens = input_tokens // 2 # Rough fallback estimate # If this single request is bigger than the entire capacity, warn or handle if input_tokens > self.input_tokens_per_minute or estimated_output_tokens > self.output_tokens_per_minute: logger.warning( f"Request uses more tokens ({input_tokens} in / {estimated_output_tokens} out) " f"than the configured per-minute capacity. This request may never succeed." 
) while True: current_time = time.time() # Clean up old entries self._clean_old_entries(self.request_timestamps, current_time) self._clean_old_entries(self.input_token_usage, current_time) self._clean_old_entries(self.output_token_usage, current_time) # Calculate current usage current_requests = len(self.request_timestamps) current_input_tokens = self._get_usage_count(self.input_token_usage) current_output_tokens = self._get_usage_count(self.output_token_usage) # Check if adding this request would exceed limits if ((current_requests + 1) <= self.requests_per_minute and (current_input_tokens + input_tokens) <= self.input_tokens_per_minute and (current_output_tokens + estimated_output_tokens) <= self.output_tokens_per_minute): # We can proceed now break # Otherwise, compute how long to wait wait_time = 0 if self.request_timestamps: wait_time = max(wait_time, 60 - (current_time - self.request_timestamps[0])) if self.input_token_usage: wait_time = max(wait_time, 60 - (current_time - self.input_token_usage[0][0])) if self.output_token_usage: wait_time = max(wait_time, 60 - (current_time - self.output_token_usage[0][0])) # If wait_time is still <= 0, we won't fix usage by waiting if wait_time <= 0: logger.warning( "Waiting cannot reduce usage enough to allow this request; " "request exceeds per-minute capacity or usage remains too high." ) break logger.info(f"Rate limit approaching for {self.provider}. Waiting {wait_time:.2f} seconds...") time.sleep(wait_time) def record_request(self, input_tokens: int, output_tokens: int): """ Record an API request and its token usage. 
Args: input_tokens: Number of input tokens used output_tokens: Number of output tokens generated """ with self.lock: current_time = time.time() # Record request and token usage self.request_timestamps.append(current_time) self.input_token_usage.append((current_time, input_tokens)) self.output_token_usage.append((current_time, output_tokens)) # Update total stats self.total_requests += 1 self.total_input_tokens += input_tokens self.total_output_tokens += output_tokens # Calculate cost input_cost = input_tokens * self.input_token_price output_cost = output_tokens * self.output_token_price total_cost = input_cost + output_cost self.total_cost += total_cost # Log usage and cost logger.info( f"{self.provider} Request: {self.total_requests} | " f"Tokens: {input_tokens}in/{output_tokens}out | " f"Cost: ${total_cost:.6f} | " f"Total Cost: ${self.total_cost:.6f}" ) def print_usage_stats(self): """Print current usage statistics.""" with self.lock: logger.info(f"{self.provider} Usage Statistics:") logger.info(f" Total Requests: {self.total_requests}") logger.info(f" Total Input Tokens: {self.total_input_tokens}") logger.info(f" Total Output Tokens: {self.total_output_tokens}") logger.info(f" Total Cost: ${self.total_cost:.6f}") ================================================ FILE: src/agent/orchestrator.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates from typing import Dict, Any, Optional, List import time from .base import BaseAgent from .reader import Reader from .searcher import Searcher from .writer import Writer from .verifier import Verifier from visualizer import StatusVisualizer import re import yaml import ast import tiktoken # Dummy visualizer class that mimics StatusVisualizer but does nothing class DummyVisualizer: """A no-op visualizer that implements the same interface as StatusVisualizer but does nothing.""" def reset(self): """Do nothing.""" pass def set_current_component(self, component, file_path): """Do nothing.""" pass def update(self, agent_name, status): """Do nothing.""" pass class Orchestrator(BaseAgent): """Agent responsible for managing the workflow between all other agents.""" def __init__(self, repo_path: str, config_path: Optional[str] = None, test_mode: Optional[str] = None): """Initialize the Orchestrator agent and its sub-agents. Args: repo_path: Path to the repository being analyzed config_path: Optional path to the configuration file test_mode: Optional test mode to run only specific components. 
Values: "reader_searcher", "context_print" or None """ super().__init__("Orchestrator") self.repo_path = repo_path self.context = "" self.test_mode = test_mode # Load configuration self.config = {} if config_path: with open(config_path, 'r') as f: self.config = yaml.safe_load(f) # Get flow control parameters with defaults flow_config = self.config.get('flow_control', {}) self.max_reader_search_attempts = flow_config.get('max_reader_search_attempts', 4) self.max_verifier_rejections = flow_config.get('max_verifier_rejections', 3) self.status_sleep_time = flow_config.get('status_sleep_time', 3) # Check model type for context constraints llm_config = self.config.get('llm', {}) self.model_type = llm_config.get('type', 'openai') # Add max_input_tokens to config for context length constraint if 'max_input_tokens' not in self.config: self.config['max_input_tokens'] = llm_config.get('max_input_tokens', 10000) # Initialize visualization - use dummy visualizer for "context_print" test mode if test_mode == "context_print": self.visualizer = DummyVisualizer() else: self.visualizer = StatusVisualizer() # Initialize all sub-agents self.reader = Reader(config_path=config_path) self.searcher = Searcher(repo_path, config_path=config_path) # Only initialize writer and verifier if not in reader_searcher test mode if test_mode != "reader_searcher": self.writer = Writer(config_path=config_path) self.verifier = Verifier(config_path=config_path) def _parse_verifier_response(self, response: str) -> Dict[str, Any]: """Parse the verifier's XML response into a structured format. 
Args: response: The XML response from the verifier Returns: Dictionary containing parsed verification results with structure: { 'needs_revision': bool, 'needs_context': bool, 'suggestion': str, 'context_suggestion': str } """ result = { 'needs_revision': False, 'needs_context': False, 'suggestion': '', 'context_suggestion': '' } # Parse NEED_REVISION need_revision_match = re.search(r'<NEED_REVISION>(.*?)</NEED_REVISION>', response, re.DOTALL) if need_revision_match: result['needs_revision'] = need_revision_match.group(1).strip().lower() == 'true' if result['needs_revision']: # Parse MORE_CONTEXT more_context_match = re.search(r'<MORE_CONTEXT>(.*?)</MORE_CONTEXT>', response, re.DOTALL) if more_context_match: result['needs_context'] = more_context_match.group(1).strip().lower() == 'true' if result['needs_context']: # Extract context suggestion context_suggestion_match = re.search(r'(.*?)', response, re.DOTALL) if context_suggestion_match: result['context_suggestion'] = context_suggestion_match.group(1).strip() else: # Extract improvement suggestion suggestion_match = re.search(r'(.*?)', response, re.DOTALL) if suggestion_match: result['suggestion'] = suggestion_match.group(1).strip() return result def process( self, focal_component: str, file_path: str, ast_node: ast.AST = None, ast_tree: ast.AST = None, dependency_graph: Dict[str, List[str]] = None, focal_node_dependency_path: str = None, token_consume_focal: int = 0 ) -> str: """Process a docstring generation request through the entire agent workflow. Args: focal_component: The code component needing a docstring (full code snippet) file_path: Path to the file containing the component (must be relative to the root of the repository it belongs to)
ast_node: Optional AST node representing the focal component ast_tree: Optional AST tree for the entire file Returns: The generated and verified docstring, or reader response in test mode """ # Reset visualization and set current component self.visualizer.reset() self.visualizer.set_current_component(focal_component, file_path) # context should be reset to empty string self.context = "" # Initialize attempt counters reader_search_attempts = 0 verifier_rejection_count = 0 while True: # Step 1: Reader determines if more context is needed self.visualizer.update('reader', "Analyzing code component...") reader_response = self.reader.process( focal_component, self.context ) # add reader_response to reader's memory (assistant) self.reader.add_to_memory("assistant", reader_response) # Step 2: Check if more information is needed match = re.search(r'(.*?)', reader_response, re.DOTALL) needs_info = match and match.group(1).strip().lower() == 'true' if needs_info and reader_search_attempts < self.max_reader_search_attempts: reader_search_attempts += 1 self.visualizer.update('reader', f"Need more information (attempt {reader_search_attempts}/{self.max_reader_search_attempts}), ask Searcher to search additional context...") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) # Use Searcher to gather more information self.visualizer.update('searcher', "Searching for additional context...") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) search_results = self.searcher.process(reader_response, ast_node, ast_tree, dependency_graph, focal_node_dependency_path) self._update_context(search_results, token_consume_focal) # Refresh reader's memory with new context self.reader.refresh_memory([ {"role": "system", "content": self.reader.system_prompt}, {"role": "user", "content": f"Current context:\n{self.context}"} ]) self.visualizer.update('reader', "Search complete, Context updated, restarting analysis...") if self.test_mode != 
"context_print": time.sleep(self.status_sleep_time) continue elif needs_info: self.visualizer.update('reader', f"Max search attempts ({self.max_reader_search_attempts}) reached, proceeding with current context...") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) self.visualizer.update('reader', "No additional context needed, starting docstring generation...") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) # If in reader_searcher test mode, return after context gathering if self.test_mode == "reader_searcher": return reader_response while True: # Inner loop for writer-verifier cycle # Step 3: When enough context is gathered, use Writer to generate docstring self.visualizer.update('writer', "Generating docstring...") # Print context if in context_print test mode if self.test_mode == "context_print": print("\n=== CONTEXT BEFORE WRITER CALL ===") print(self.context) print("=== END OF CONTEXT ===\n") docstring = self.writer.process( focal_component, self.context ) # assert docstring is not empty # add writer_response to writer's memory (assistant) self.writer.add_to_memory("assistant", docstring) # Step 4: Use Verifier to check the quality self.visualizer.update('verifier', "Verifying docstring quality...") verification_response = self.verifier.process( focal_component, docstring, self.context ) # Step 5: Parse and process verification results verification_result = self._parse_verifier_response(verification_response) if not verification_result['needs_revision'] or verifier_rejection_count >= self.max_verifier_rejections: if verifier_rejection_count >= self.max_verifier_rejections: self.visualizer.update('verifier', f"Max rejection attempts ({self.max_verifier_rejections}) reached, accepting current docstring.") else: self.visualizer.update('verifier', "Docstring generated successfully! 
No need for revision.") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) return docstring # if needs_revision is true, then needs_context is true else: verifier_rejection_count += 1 # clean verifier's memory self.verifier.clear_memory() if verification_result['needs_context'] and reader_search_attempts < self.max_reader_search_attempts: self.visualizer.update('verifier', f"Need more context (rejection {verifier_rejection_count}/{self.max_verifier_rejections}), hands back to reader...") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) # Add context suggestion to reader's memory and break inner loop to get more context self.reader.add_to_memory( "user", f"Additional context needed: {verification_result['context_suggestion']}" ) # clean writer's and verifier's memory self.writer.clear_memory() break # Break inner loop to return to reader-searcher cycle else: self.visualizer.update('verifier', f"Content is not good enough (rejection {verifier_rejection_count}/{self.max_verifier_rejections}), hands back to writer...") if self.test_mode != "context_print": time.sleep(self.status_sleep_time) # Add improvement suggestion to writer's memory and continue inner loop self.writer.add_to_memory( "user", f"Please improve the docstring based on this suggestion: {verification_result['suggestion']}" ) # Continue inner loop to generate new docstring def _update_context(self, search_results: Dict[str, Any], token_consume_focal: int) -> None: """Update the context with new search results by merging content within existing XML tags. Args: search_results: Dictionary containing new context information structured as: { 'internal': { 'calls': { 'class': {'class1': 'content1', ...}, 'function': {'func1': 'content1', ...}, 'method': {'method1': 'content1', ...}, }, 'called_by': ['code snippet1', ...] 
}, 'external': { 'query1': 'result1', 'query2': 'result2' } } """ if not self.context: # Initialize empty context structure if none exists self.context = """ """ if 'internal' in search_results: internal_info = search_results['internal'] # Handle calls (class, function, method) if 'calls' in internal_info: calls = internal_info['calls'] # Helper function to safely update XML content def update_xml_section(tag: str, content_list: list) -> None: if not content_list: return pattern = f'<{tag}>(.*?)</{tag}>' match = re.search(pattern, self.context, re.DOTALL) if not match: # If pattern doesn't exist, something is wrong with context structure return existing_text = match.group(1).strip() new_content = existing_text + "\n" + "\n".join(content_list) if existing_text else "\n".join(content_list) # Escape backslashes in new_content to prevent regex interpretation issues new_content = new_content.replace('\\', '\\\\') self.context = re.sub(pattern, f'<{tag}>\n{new_content}\n</{tag}>', self.context, flags=re.DOTALL) # Update class calls if 'class' in calls: class_content = [f"<{class_name}>{content}</{class_name}>" for class_name, content in calls['class'].items()] update_xml_section('CLASS', class_content) # Update function calls if 'function' in calls: func_content = [f"<{func_name}>{content}</{func_name}>" for func_name, content in calls['function'].items()] update_xml_section('FUNCTION', func_content) # Update method calls if 'method' in calls: method_content = [f"<{method_name}>{content}</{method_name}>" for method_name, content in calls['method'].items()] update_xml_section('METHOD', method_content) # Update called_by if 'called_by' in internal_info: called_by_content = internal_info['called_by'] update_xml_section('CALL_BY', called_by_content) # Update external info if 'external' in search_results: external_content = [] for query, result in search_results['external'].items(): external_content.append(f"{query}") external_content.append(f"{result}") update_xml_section('EXTERNAL_RETRIEVAL_INFO', external_content) # Apply context
length constraint for all models if hasattr(self, 'config') and 'max_input_tokens' in self.config: max_input_tokens = self.config.get('max_input_tokens', 10000) else: max_input_tokens = 10000 # Default fallback self._constrain_context_length(max_input_tokens=max_input_tokens, token_consume_focal=token_consume_focal) def _constrain_context_length(self, max_input_tokens: int = 10000, token_consume_focal: int = 0) -> None: """Constrain context length for models by truncating the longest component. Args: max_input_tokens: Maximum number of tokens allowed in the input context token_consume_focal: Number of tokens consumed by the focal component itself """ try: # Use tiktoken to count tokens encoding = tiktoken.get_encoding("cl100k_base") # Using a common encoding current_tokens = len(encoding.encode(self.context)) # Check if we need to truncate considering both context and focal component tokens if current_tokens + token_consume_focal <= max_input_tokens: return # No need to truncate # Find the XML section with the most tokens to truncate component_tokens = {} components = [ ('CODE_CONTEXT', r'<CODE_CONTEXT>(.*?)</CODE_CONTEXT>'), ('FOCAL_COMPONENT', r'<FOCAL_COMPONENT>(.*?)</FOCAL_COMPONENT>'), ('RELATED_COMPONENTS', r'<RELATED_COMPONENTS>(.*?)</RELATED_COMPONENTS>'), ('FOCAL_DEPENDENCIES', r'<FOCAL_DEPENDENCIES>(.*?)</FOCAL_DEPENDENCIES>'), ('EXTERNAL_RETRIEVAL_INFO', r'<EXTERNAL_RETRIEVAL_INFO>(.*?)</EXTERNAL_RETRIEVAL_INFO>') ] for name, pattern in components: match = re.search(pattern, self.context, re.DOTALL) if match: content = match.group(1) tokens = len(encoding.encode(content)) component_tokens[name] = (content, tokens) # Find the component with the most tokens if not component_tokens: return # No components found longest_component = max(component_tokens.items(), key=lambda x: x[1][1]) component_name = longest_component[0] content = longest_component[1][0] component_token_count = longest_component[1][1] # Calculate tokens to remove, considering focal component tokens_to_remove = current_tokens + token_consume_focal - max_input_tokens if tokens_to_remove <= 0: return # No need to truncate # Print information about truncation print(f"Truncating
{component_name}: removing {tokens_to_remove} tokens from {component_token_count} tokens. Current total: {current_tokens} tokens") if tokens_to_remove >= component_token_count: # If removing the entire component isn't enough, we'll just remove it and deal with the rest later new_content = "" else: # Truncate the content by removing tokens from the end encoded_content = encoding.encode(content) truncated_encoded = encoded_content[:-tokens_to_remove] new_content = encoding.decode(truncated_encoded) # Update the context with truncated content; a callable replacement keeps backslashes in new_content from being read as regex escapes pattern = f'<{component_name}>(.*?)</{component_name}>' self.context = re.sub(pattern, lambda _match: f'<{component_name}>\n{new_content}\n</{component_name}>', self.context, flags=re.DOTALL) except Exception as e: print(f"Error constraining context length: {e}") ================================================ FILE: src/agent/reader.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional, Tuple from .base import BaseAgent class CodeComponentType(Enum): """Enum for different types of code components.""" FUNCTION = "function" METHOD = "method" CLASS = "class" @dataclass class InformationRequest: """Data class for structured information requests.""" internal_requests: List[str] external_requests: List[str] class Reader(BaseAgent): """Agent responsible for determining if more context is needed for docstring generation.""" def __init__(self, config_path: Optional[str] = None): """Initialize the Reader agent. Args: config_path: Optional path to the configuration file """ super().__init__("Reader", config_path) self.system_prompt = """You are a Reader agent responsible for determining if more context is needed to generate a high-quality docstring. You should analyze the code component and current context to make this determination. You have access to two types of information sources: 1.
Internal Codebase Information (from local code repository): For Functions: - Code components called within the function body - Places where this function is called For Methods: - Code components called within the method body - Places where this method is called - The class this method belongs to For Classes: - Code components called in the __init__ method - Places where this class is instantiated - Complete class implementation beyond __init__ 2. External Open Internet retrieval Information: - External Retrieval is extremely expensive. Only request external open internet retrieval information if the component involves a novel, state of the art, recently-proposed algorithms or techniques. (e.g. computing a novel loss function (NDCG Loss, Alignment and Uniformity Loss, etc), certain novel metrics (Cohen's Kappa, etc), specialized novel ideas) - Each query should be a clear, natural language question Your response should: 1. First provide a free text analysis of the current code and context 2. Explain what additional information might be needed (if any) 3. Include an true tag if more information is needed, or false if current context is sufficient 4. If more information is needed, end your response with a structured request in XML format: class1,class2 func1,func2 self.method1,instance.method2,class.method3 true/false query1,query2 Important rules for structured request: 1. For CALLS sections, only include names that are explicitly needed 2. If no items exist for a category, use empty tags (e.g., ) 3. CALL_BY should be "true" only if you need to know what calls/uses a component 4. Each external QUERY should be a concise, clear, natural language search query 5. Use comma-separated values without spaces for multiple items 6. For METHODS, keep dot notation in the same format as the input. 7. Only first-level calls of the focal code component are accessible. Do not request information on code components that are not directly called by the focal component. 8. 
External Open-Internet Retrieval is extremely expensive. Only request external open internet retrieval information if the component involves a novel, state of the art, recently-proposed algorithms or techniques. (e.g. computing a novel loss function (NDCG Loss, Alignment and Uniformity Loss, etc), certain novel metrics (Cohen's Kappa, etc), specialized novel ideas) Important rules: 1. Only request internal codebase information that you think is necessary for docstring generation task. For some components that is simple and obvious, you do not need any other information for docstring generation. 2. External Open-Internet retrieval request is extremely expensive. Only request information that you think is absolutely necessary for docstring generation task. The current code shows a database connection function. To write a comprehensive docstring, we need to understand: 1. Where this function is called - this will reveal the expected input patterns and common use cases 2. What internal database functions it relies on - this will help document any dependencies or prerequisites This additional context is necessary because database connections often have specific setup requirements and usage patterns that should be documented for proper implementation. true execute_query,connect_db self.process_data,data_processor._internal_process true Keep in mind that: 3. You do not need to generate docstring for the component. Just determine if more information is needed. """ self.add_to_memory("system", self.system_prompt) def process(self, focal_component: str, context: str = "") -> str: """Process the input and determine if more context is needed. 
Args: focal_component: The code component needing a docstring (full code snippet) context: Current context information (if any) Returns: A string containing the analysis and tag indicating if more information is needed """ # Add the current task to memory task_description = f""" Current context: {context if context else 'No context provided yet.'} Analyze the following code component: {focal_component} """ self.add_to_memory("user", task_description) # Generate response using LLM response = self.generate_response() return response ================================================ FILE: src/agent/searcher.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict, List, Any, Optional from .base import BaseAgent from .reader import InformationRequest from .tool.internal_traverse import ASTNodeAnalyzer # Updated import to use only ASTNodeAnalyzer from .tool.perplexity_api import PerplexityAPI, PerplexityResponse import re from dataclasses import dataclass, field import xml.etree.ElementTree as ET from io import StringIO import ast # Keep for type annotations @dataclass class ParsedInfoRequest: """Structured format for parsed information requests.
Attributes: internal_requests: Dictionary containing: - call: Dictionary with keys 'class', 'function', 'method', each containing a list of code component names that are called - call_by: Boolean indicating if caller information is needed external_requests: List of query strings for external information search """ internal_requests: Dict[str, Any] = field(default_factory=lambda: { 'call': { 'class': [], 'function': [], 'method': [] }, 'call_by': False }) external_requests: List[str] = field(default_factory=list) class Searcher(BaseAgent): """Agent responsible for gathering requested information from internal and external sources.""" def __init__(self, repo_path: str, config_path: Optional[str] = None): """Initialize the Searcher agent. Args: repo_path: Path to the repository being analyzed config_path: Optional path to the configuration file """ super().__init__("Searcher", config_path=config_path) self.repo_path = repo_path self.ast_analyzer = ASTNodeAnalyzer(repo_path) def process( self, reader_response: str, ast_node: ast.AST, ast_tree: ast.AST, dependency_graph: Dict[str, List[str]], focal_node_dependency_path: str ) -> Dict[str, Any]: """Process the reader's response and gather the requested information. 
Args: reader_response: Response from the Reader agent containing information requests in structured XML format ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_graph: Dictionary mapping component paths to their dependencies focal_node_dependency_path: Dependency path of the focal component Returns: A dictionary containing the gathered information, structured as: { 'internal': { 'calls': { 'class': {'class1': 'content1', 'class2': 'content2', ...}, 'function': {'func1': 'content1', 'func2': 'content2', ...}, 'method': {'method1': 'content1', 'method2': 'content2', ...}, }, 'called_by': ['code snippet1', 'code snippet2', ...], }, 'external': { 'query1': 'result1', 'query2': 'result2' } } """ # Parse the reader's response into structured format parsed_request = self._parse_reader_response(reader_response) # Gather internal information using dependency graph and AST analyzer internal_info = self._gather_internal_info( ast_node, ast_tree, focal_node_dependency_path, dependency_graph, parsed_request ) # Gather external information using Perplexity API external_info = self._gather_external_info(parsed_request.external_requests) return { 'internal': internal_info, 'external': external_info } def _parse_reader_response(self, reader_response: str) -> ParsedInfoRequest: """Parse the reader's structured XML response.
Args: reader_response: Response from Reader agent containing XML Returns: ParsedInfoRequest object containing structured requests """ # Extract the XML content between REQUEST tags xml_match = re.search(r'<REQUEST>(.*?)</REQUEST>', reader_response, re.DOTALL) if not xml_match: # Return empty request if no valid XML found return ParsedInfoRequest() xml_content = f'<REQUEST>{xml_match.group(1)}</REQUEST>' try: # Parse XML root = ET.fromstring(xml_content) # Parse internal requests internal = root.find('INTERNAL') calls = internal.find('CALLS') internal_requests = { 'call': { 'class': self._parse_comma_list(calls.find('CLASS').text), 'function': self._parse_comma_list(calls.find('FUNCTION').text), 'method': self._parse_comma_list(calls.find('METHOD').text) }, 'call_by': internal.find('CALL_BY').text.lower() == 'true' } # Parse external requests external = root.find('RETRIEVAL') external_requests = self._parse_comma_list(external.find('QUERY').text) return ParsedInfoRequest( internal_requests=internal_requests, external_requests=external_requests ) except (ET.ParseError, AttributeError) as e: print(f"Error parsing XML: {e}") # Return empty request if XML parsing fails return ParsedInfoRequest() def _parse_comma_list(self, text: Optional[str]) -> List[str]: """Parse comma-separated text into list of strings. Args: text: Comma-separated text or None Returns: List of non-empty strings """ if not text: return [] return [item.strip() for item in text.split(',') if item.strip()] def _gather_internal_info( self, ast_node: ast.AST, ast_tree: ast.AST, focal_dependency_path: str, dependency_graph: Dict[str, List[str]], parsed_request: ParsedInfoRequest ) -> Dict[str, Any]: """Gather internal information using the dependency graph and AST analyzer.
        Args:
            ast_node: AST node representing the focal component
            ast_tree: AST tree for the entire file
            focal_dependency_path: Dependency path of the focal component
            dependency_graph: Dictionary mapping component paths to their dependencies
            parsed_request: Structured format of information requests

        Returns:
            Dictionary containing gathered internal information structured as:
            {
                'calls': {
                    'class': {'class_name': 'code_content', ...},
                    'function': {'function_name': 'code_content', ...},
                    'method': {'method_name': 'code_content', ...}
                },
                'called_by': ['code_snippet1', 'code_snippet2', ...]
            }
        """
        result = {
            'calls': {
                'class': {},
                'function': {},
                'method': {}
            },
            'called_by': []
        }

        # Get dependencies of the focal component from the dependency graph
        component_dependencies = dependency_graph.get(focal_dependency_path, [])

        # Process class dependencies
        if parsed_request.internal_requests['call']['class']:
            requested_classes = parsed_request.internal_requests['call']['class']
            for dependency_path in component_dependencies:
                # Check if this is a class dependency by looking at capitalization of the last part
                path_parts = dependency_path.split('.')
                if path_parts and path_parts[-1][0].isupper():
                    # This looks like a class dependency
                    class_name = path_parts[-1]
                    # Check if this class is in the requested classes
                    # Use flexible matching for partial class names or with prefixes
                    for requested_class in requested_classes:
                        # Match by exact name, or as part of a path
                        if (requested_class == class_name or
                                requested_class in dependency_path or
                                class_name.endswith(requested_class)):
                            # Get the class initialization code
                            class_code = self.ast_analyzer.get_component_by_path(
                                ast_node,
                                ast_tree,
                                dependency_path
                            )
                            if class_code:
                                result['calls']['class'][requested_class] = class_code
                                break

        # Process function dependencies
        if parsed_request.internal_requests['call']['function']:
            requested_functions = parsed_request.internal_requests['call']['function']
            for dependency_path in component_dependencies:
                # Check if this is likely a function (last part starts with lowercase)
                path_parts = dependency_path.split('.')
                if path_parts and path_parts[-1][0].islower():
                    # This looks like a function or method, differentiate by checking if it's in a class
                    # If the second-to-last part starts with uppercase, it's likely a method
                    if len(path_parts) >= 2 and path_parts[-2][0].isupper():
                        # This is likely a method, skip for now
                        continue

                    function_name = path_parts[-1]
                    # Check if this function is in the requested functions
                    for requested_function in requested_functions:
                        # Match by exact name, or as part of a path
                        if (requested_function == function_name or
                                requested_function in dependency_path or
                                function_name.endswith(requested_function)):
                            # Get the function code
                            function_code = self.ast_analyzer.get_component_by_path(
                                ast_node,
                                ast_tree,
                                dependency_path
                            )
                            if function_code:
                                result['calls']['function'][requested_function] = function_code
                                break

        # Process method dependencies
        if parsed_request.internal_requests['call']['method']:
            requested_methods = parsed_request.internal_requests['call']['method']
            for dependency_path in component_dependencies:
                # Check if this is likely a method (part after a part that starts with uppercase)
                path_parts = dependency_path.split('.')
                if len(path_parts) >= 2 and path_parts[-1][0].islower() and path_parts[-2][0].isupper():
                    method_name = path_parts[-1]
                    class_name = path_parts[-2]
                    full_method_name = f"{class_name}.{method_name}"
                    # Check if this method is in the requested methods
                    for requested_method in requested_methods:
                        # Match by exact name, class.method, or just method name
                        if (requested_method == full_method_name or
                                requested_method == method_name or
                                requested_method in dependency_path or
                                method_name.endswith(requested_method)):
                            # Get the method code
                            method_code = self.ast_analyzer.get_component_by_path(
                                ast_node,
                                ast_tree,
                                dependency_path
                            )
                            if method_code:
                                result['calls']['method'][requested_method] = method_code
                                break

        # Handle call_by (what calls this component)
        if parsed_request.internal_requests['call_by']:
            parent_components = self.ast_analyzer.get_parent_components(
                ast_node,
                ast_tree,
                focal_dependency_path,
                dependency_graph
            )
            if parent_components:
                result['called_by'].extend(parent_components)
            else:
                result['called_by'].append("This component is never called by any other component.")

        return result

    def _gather_external_info(self, queries: List[str]) -> Dict[str, str]:
        """Gather external information using Perplexity API.

        Args:
            queries: List of search queries

        Returns:
            Dictionary mapping queries to their responses
        """
        if not queries:
            return {}

        try:
            perplexity = PerplexityAPI()
            responses = perplexity.batch_query(
                questions=queries,
                system_prompt="You are a helpful assistant providing concise and accurate information about programming concepts and code. Focus on technical accuracy and clarity.",
                temperature=0.1
            )

            # Create mapping of queries to responses
            results = {}
            for query, response in zip(queries, responses):
                if response is not None:
                    results[query] = response.content
                else:
                    results[query] = "Error: Failed to get response from Perplexity API"

            return results

        except Exception as e:
            print(f"Error using Perplexity API: {str(e)}")
            return {query: f"Error: {str(e)}" for query in queries}

================================================
FILE: src/agent/tool/README.md
================================================
# AST Call Graph Analysis Tool

This tool provides functionality to analyze Python codebases by building and querying call graphs using Abstract Syntax Tree (AST) parsing. It helps in understanding code relationships and dependencies between functions, methods, and classes.
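
As a rough illustration of the idea (not the tool's own implementation), the sketch below uses only Python's standard `ast` module to map each function in a snippet to the plain names it calls. The `collect_calls` helper and the `SOURCE` sample are hypothetical names introduced for this example, and attribute calls such as `obj.method()` are deliberately ignored.

```python
import ast
from typing import Dict, List

# Hypothetical sample source used only for this illustration
SOURCE = '''
def helper():
    return 1

def main():
    value = helper()
    print(value)
    return value
'''


def collect_calls(source: str) -> Dict[str, List[str]]:
    """Map each function name to the sorted plain names it calls."""
    tree = ast.parse(source)
    calls: Dict[str, List[str]] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            called = set()
            for inner in ast.walk(node):
                # Only direct calls such as helper(); obj.method() is skipped
                if isinstance(inner, ast.Call) and isinstance(inner.func, ast.Name):
                    called.add(inner.func.id)
            calls[node.name] = sorted(called)
    return calls


call_map = collect_calls(SOURCE)
print(call_map)
```

A real call graph also has to resolve methods, instances, and cross-file definitions, which a name-only pass like this one does not attempt.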
## Features

### Call Graph Building
- Automatically builds a complete call graph for a Python repository
- Tracks relationships between functions, methods, and classes
- Handles cross-file dependencies
- Caches AST parsing results for better performance

### Code Component Analysis
The tool provides six main functionalities for analyzing code relationships:

1. **Child Function Analysis** (`get_child_function`)
   - Input: Component signature, file path, and child function name
   - Output: Full code of the function being called
   - Use case: Finding implementation of functions called within your code

2. **Child Method Analysis** (`get_child_method`)
   - Input: Component signature, file path, and child method name
   - Output: Full code of the method being called
   - Use case: Finding implementation of methods called on objects

3. **Child Class Analysis** (`get_child_class`)
   - Input: Component signature, file path, and child class name
   - Output: Class signature and initialization code
   - Use case: Finding class definitions for instantiated objects

4. **Parent Function Analysis** (`get_parent_function`)
   - Input: Component signature, file path, and parent function name
   - Output: Full code of the function that calls the component
   - Use case: Finding where a function is being used

5. **Parent Method Analysis** (`get_parent_method`)
   - Input: Component signature, file path, and parent method name
   - Output: Full code of the method that calls the component
   - Use case: Finding where a method is being called

6. **Parent Class Analysis** (`get_parent_class`)
   - Input: Component signature, file path, and parent class name
   - Output: Full code of the class that uses the component
   - Use case: Finding classes that depend on other classes

## Usage Example

```python
from agent.tool.ast import CallGraphBuilder

# Initialize the builder with repository path
builder = CallGraphBuilder("/path/to/repo")

# Find where a function is called
parent_code = builder.get_parent_function(
    "def process_data(self):",
    "src/data/processor.py",
    "main_function"
)

# Find what methods a class uses
child_code = builder.get_child_method(
    "class DataProcessor:",
    "src/data/processor.py",
    "transform_data"
)
```

## Implementation Details

- Uses Python's built-in `ast` module for code parsing
- Maintains parent-child relationships in AST nodes
- Handles various Python constructs:
  - Function definitions and calls
  - Class definitions and instantiations
  - Method calls (both direct and through objects)
  - Static methods
  - Internal methods
  - Cross-file dependencies

## Limitations

- Currently only supports Python files
- Requires valid Python syntax in source files
- Does not handle dynamic code execution (eval, exec)
- Method resolution is name-based (doesn't handle complex inheritance)
- Doesn't track calls through variables or complex expressions

================================================
FILE: src/agent/tool/ast.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates
import ast
from typing import Dict, List, Optional, Set, Tuple, Union
from pathlib import Path
import os
from abc import ABC, abstractmethod


class ASTUtility(ABC):
    """Abstract base class for AST utilities."""

    @abstractmethod
    def _get_component_name_from_code(self, code_snippet: str) -> Optional[str]:
        """Extract component name from a code snippet.
        Args:
            code_snippet (str): The full code snippet of a function/method/class

        Returns:
            Optional[str]: The name of the component if found, None otherwise

        Example:
            >>> builder = CallGraphBuilder("repo_path")
            >>> builder._get_component_name_from_code("def process_data(self):\\n    return data")
            'process_data'
            >>> builder._get_component_name_from_code("class DataProcessor:\\n    def __init__(self):")
            'DataProcessor'
        """
        pass

    def _is_code_similar(self, code1: str, code2: str, threshold: float = 0.9) -> bool:
        """Check if two code snippets are similar using fuzzy matching.

        Args:
            code1 (str): First code snippet
            code2 (str): Second code snippet
            threshold (float): Similarity threshold (0.0 to 1.0). Default is 0.9

        Returns:
            bool: True if the similarity score is at least the threshold
        """
        # Special handling for class components
        if code1.lstrip().startswith('class ') and code2.lstrip().startswith('class '):
            # For classes, just compare the class names
            class1_name = self._get_component_name_from_code(code1)
            class2_name = self._get_component_name_from_code(code2)
            return class1_name == class2_name

        # Normalize whitespace and remove empty lines
        def normalize(code: str) -> str:
            return '\n'.join(line.strip() for line in code.split('\n') if line.strip())

        code1_norm = normalize(code1)
        code2_norm = normalize(code2)

        # Guard against division by zero when both snippets normalize to empty text
        longest = max(len(code1_norm), len(code2_norm))
        if longest == 0:
            return True

        # Simple length-based early check
        if abs(len(code1_norm) - len(code2_norm)) / longest > (1 - threshold):
            return False

        # Position-wise character similarity score (not an edit distance)
        matches = sum(a == b for a, b in zip(code1_norm, code2_norm))
        similarity = matches / longest
        return similarity >= threshold


def _get_component_name_from_code(code_snippet: str) -> Optional[str]:
    """Extract component name from a code snippet.
Args: code_snippet (str): The full code snippet of a function/method/class Returns: Optional[str]: The name of the component if found, None otherwise Example: >>> _get_component_name_from_code("def process_data(self):\\n return data") 'process_data' >>> _get_component_name_from_code("class DataProcessor:\\n def __init__(self):") 'DataProcessor' """ # Remove leading whitespace and get first line first_line = code_snippet.lstrip().split('\n')[0] # Check if it's a class if first_line.startswith('class '): # Find the class name - it's between 'class ' and either '(' or ':' class_decl = first_line[6:].strip() # Remove 'class ' prefix class_name = class_decl.split('(')[0].split(':')[0].strip() return class_name # Check if it's a function/method elif first_line.startswith('def '): # Find the function name - it's between 'def ' and '(' func_decl = first_line[4:].strip() # Remove 'def ' prefix func_name = func_decl.split('(')[0].strip() return func_name return None class ParentNodeTransformer(ast.NodeTransformer): """AST transformer that adds parent references to each node.""" def visit(self, node): for child in ast.iter_child_nodes(node): child.parent = node return super().visit(node) class CallGraphBuilder(ASTUtility): """A class to build and analyze call graphs for Python code. This class helps analyze function calls, method calls, and class relationships within a Python repository. """ def __init__(self, repo_path: str): """Initialize the CallGraphBuilder with a repository path. Args: repo_path (str): Path to the Python repository to analyze """ self.repo_path = Path(repo_path) self.call_graph = {} self.class_info = {} self.method_info = {} self.function_info = {} self.file_asts = {} self._build_call_graph() def _parse_file(self, file_path: str) -> ast.AST: """Parse a Python file and return its AST. 
Args: file_path (str): Path to the file relative to repo_path """ if file_path in self.file_asts: return self.file_asts[file_path] # Construct absolute path by joining repo_path with file_path abs_path = self.repo_path / file_path with open(abs_path) as f: content = f.read() tree = ast.parse(content) # Add parent references transformer = ParentNodeTransformer() tree = transformer.visit(tree) self.file_asts[file_path] = tree return tree def _get_signature_from_code(self, code: str, is_class: bool = False) -> str: """Extract signature from code. For functions/methods: signature ends with first ':' after first matching ')' For classes: signature is the class definition line ending with ':'""" lines = code.split('\n') first_line = lines[0].strip() if is_class: return first_line # For functions/methods # Find the closing parenthesis paren_count = 0 end_paren_idx = -1 for i, char in enumerate(first_line): if char == '(': paren_count += 1 elif char == ')': paren_count -= 1 if paren_count == 0: end_paren_idx = i break if end_paren_idx == -1: return first_line # Find the first : after the closing parenthesis colon_idx = first_line.find(':', end_paren_idx) if colon_idx == -1: return first_line return first_line[:colon_idx+1] def _get_node_code(self, file_path: str, node: ast.AST) -> str: """Get the source code for a node. 
Args: file_path (str): Path to the file relative to repo_path node (ast.AST): The AST node to get code for """ abs_path = self.repo_path / file_path with open(abs_path) as f: content = f.readlines() return ''.join(content[node.lineno-1:node.end_lineno]) def _is_method(self, node: ast.FunctionDef) -> bool: """Check if a function definition is a method.""" parent = getattr(node, 'parent', None) while parent is not None: if isinstance(parent, ast.ClassDef): return True parent = getattr(parent, 'parent', None) return False def _build_call_graph(self): """Build the complete call graph for the repository.""" for root, _, files in os.walk(self.repo_path): for file in files: if not file.endswith('.py'): continue abs_file_path = Path(root) / file # Convert absolute path to relative path rel_file_path = str(abs_file_path.relative_to(self.repo_path)) tree = self._parse_file(rel_file_path) for node in ast.walk(tree): if isinstance(node, ast.ClassDef): # Store class info class_code = self._get_node_code(rel_file_path, node) self.class_info[(rel_file_path, class_code)] = node # Store method info for item in node.body: if isinstance(item, ast.FunctionDef): method_code = self._get_node_code(rel_file_path, item) self.method_info[(rel_file_path, method_code)] = item elif isinstance(node, ast.FunctionDef): if not self._is_method(node): # Store function info func_code = self._get_node_code(rel_file_path, node) self.function_info[(rel_file_path, func_code)] = node def _get_component_name_from_code(self, code_snippet: str) -> Optional[str]: """Extract component name from a code snippet. Args: code_snippet (str): The full code snippet of a function/method/class Returns: Optional[str]: The name of the component if found, None otherwise """ return _get_component_name_from_code(code_snippet) def get_child_function(self, code_component: str, file_path: str, child_function: str) -> Optional[str]: """Get the code of a child function that is called by the component. 
Args: code_component (str): The full code snippet of the calling component. This is used to uniquely identify the component in case of name collisions. file_path (str): Path to the file containing the component child_function (str): Name of the function being called Returns: Optional[str]: The code of the child function if found, None otherwise Example: >>> builder = CallGraphBuilder("repo_path") >>> builder.get_child_function( ... "def main_function():\\n result = utility_function()\\n return result", ... "main.py", ... "utility_function" ... ) 'def utility_function():\\n return "utility"' """ tree = self._parse_file(file_path) target_node = None component_name = self._get_component_name_from_code(code_component) if not component_name: return None # Find the target node for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and node.name == component_name: # Get the code of this node and verify it matches using fuzzy matching node_code = self._get_node_code(file_path, node) if self._is_code_similar(node_code, code_component): target_node = node break if not target_node: return None # Look for calls to the child function for node in ast.walk(target_node): if isinstance(node, ast.Call): if isinstance(node.func, ast.Name) and node.func.id == child_function: # Find the function definition for func_file, func_code in self.function_info: func_node = self.function_info[(func_file, func_code)] if func_node.name == child_function: return func_code return None def _resolve_instance_type(self, node: ast.AST, instance_name: str) -> Optional[str]: """Resolve the class type of an instance variable by looking at assignments. 
Args: node: The AST node to start searching from (usually a function/method) instance_name: The name of the instance variable to resolve Returns: Optional[str]: The name of the class if found, None otherwise """ # First check local assignments in the current function/method for n in ast.walk(node): if isinstance(n, ast.Assign): for target in n.targets: if isinstance(target, ast.Name) and target.id == instance_name: if isinstance(n.value, ast.Call) and isinstance(n.value.func, ast.Name): return n.value.func.id # If not found locally and we're in a method, check class __init__ if isinstance(node, ast.FunctionDef): class_node = self._get_class_node(node) if class_node: for method in class_node.body: if isinstance(method, ast.FunctionDef) and method.name == '__init__': for n in ast.walk(method): if isinstance(n, ast.Assign): for target in n.targets: if isinstance(target, ast.Attribute) and \ isinstance(target.value, ast.Name) and \ target.value.id == 'self' and \ target.attr == instance_name and \ isinstance(n.value, ast.Call) and \ isinstance(n.value.func, ast.Name): return n.value.func.id return None def _get_class_node(self, method_node: ast.FunctionDef) -> Optional[ast.ClassDef]: """Get the ClassDef node that contains this method.""" parent = getattr(method_node, 'parent', None) while parent is not None: if isinstance(parent, ast.ClassDef): return parent parent = getattr(parent, 'parent', None) return None def get_child_method(self, code_component: str, file_path: str, method_name: str, prefix: Optional[str] = None, find_all: bool = False) -> Union[Optional[str], Dict[str, str]]: """Get the code of a child method that is called by the component. Args: code_component (str): The full code snippet of the calling component. This is used to uniquely identify the component in case of name collisions. 
file_path (str): Path to the file containing the component method_name (str): Name of the method being called prefix (Optional[str]): Optional prefix before method name (e.g., 'self', instance name, or class name) find_all (bool): Whether to find all methods with this name across classes Returns: If find_all=False: Optional[str]: The code of the child method if found, None otherwise If find_all=True: Dict[str, str]: Dictionary mapping class names to method code for all matching methods Note: This method handles three types of method calls: 1. self.method() - method in same class 2. ClassName.method() - direct class method call 3. instance.method() - method call through instance variable If prefix is provided: - If prefix is 'self': looks for method in the same class - If prefix starts with uppercase: treats it as a class name - If prefix starts with lowercase: treats it as an instance variable """ tree = self._parse_file(file_path) target_node = None component_name = self._get_component_name_from_code(code_component) if not component_name: return {} if find_all else None # Find the target node for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and node.name == component_name: # Get the code of this node and verify it matches using fuzzy matching node_code = self._get_node_code(file_path, node) if self._is_code_similar(node_code, code_component): target_node = node break if not target_node: return {} if find_all else None if find_all: # Find all methods with this name across all classes results = {} for method_file, method_code in self.method_info: method_node = self.method_info[(method_file, method_code)] if method_node.name == method_name: class_node = self._get_class_node(method_node) if class_node: results[class_node.name] = method_code return results # If prefix is provided, use it to narrow down the search if prefix is not None: target_class = None if prefix == 'self': # Case 1: self.method() target_class = 
self._get_class_of_method(target_node) elif prefix[0].isupper(): # Case 2: ClassName.method() target_class = prefix else: # Case 3: instance.method() target_class = self._resolve_instance_type(target_node, prefix) if target_class: for method_file, method_code in self.method_info: method_node = self.method_info[(method_file, method_code)] if method_node.name == method_name: # Verify this method belongs to the target class method_class = self._get_class_of_method(method_node) if method_class == target_class: return method_code return None # If no prefix or target class not found, fall back to original behavior # Look for method calls for node in ast.walk(target_node): if isinstance(node, ast.Call): if isinstance(node.func, ast.Attribute) and node.func.attr == method_name: target_class = None if isinstance(node.func.value, ast.Name): if node.func.value.id == 'self': # Case 1: self.method() target_class = self._get_class_of_method(target_node) else: # Case 2: ClassName.method() or Case 3: instance.method() # Try as class name first for class_file, class_code in self.class_info: class_node = self.class_info[(class_file, class_code)] if class_node.name == node.func.value.id: target_class = class_node.name break # If not found as class name, try as instance variable if not target_class: target_class = self._resolve_instance_type(target_node, node.func.value.id) elif isinstance(node.func.value, ast.Attribute): # Handle nested attributes like self.processor.process() if isinstance(node.func.value.value, ast.Name): if node.func.value.value.id == 'self': # Get type of self.processor instance_var = node.func.value.attr target_class = self._resolve_instance_type(target_node, instance_var) # If we found the target class, find the method if target_class: for method_file, method_code in self.method_info: method_node = self.method_info[(method_file, method_code)] if method_node.name == method_name: # Verify this method belongs to the target class method_class = 
self._get_class_of_method(method_node) if method_class == target_class: return method_code return None def get_child_class(self, code_component: str, file_path: str, child_class: str) -> Optional[str]: """Get the class signature and init function of a child class used by the component. Args: code_component (str): The full code snippet of the calling component. This is used to uniquely identify the component in case of name collisions. file_path (str): Path to the file containing the calling component child_class (str): Name of the class being used Returns: Optional[str]: The code of the child class and its __init__ if found, None otherwise Example: >>> builder = CallGraphBuilder("repo_path") >>> builder.get_child_class( ... "def main_function():\\n helper = HelperClass()\\n return helper.data", ... "main.py", ... "HelperClass" ... ) 'class HelperClass:\\n def __init__(self):\\n self.data = []' """ tree = self._parse_file(file_path) target_node = None component_name = self._get_component_name_from_code(code_component) if not component_name: return None # Find the target node for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and node.name == component_name: # Get the code of this node and verify it matches using fuzzy matching node_code = self._get_node_code(file_path, node) if self._is_code_similar(node_code, code_component): target_node = node break if not target_node: return None # Look for class usage for node in ast.walk(target_node): if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): if node.func.id == child_class: # Find the class definition for class_file, class_code in self.class_info: class_node = self.class_info[(class_file, class_code)] if class_node.name == child_class: # Get class signature and __init__ init_method = None for item in class_node.body: if isinstance(item, ast.FunctionDef) and item.name == '__init__': init_method = self._get_node_code(class_file, item) break if init_method: return 
f"{class_code}\n{init_method}" return class_code return None def get_child_class_init(self, code_component: str, file_path: str, child_class: str) -> Optional[str]: """Get the class signature and init function of a child class used by the component. Similar to get_child_class but only returns up to the end of __init__ if it exists. Args: code_component (str): The full code snippet of the calling component. This is used to uniquely identify the component in case of name collisions. file_path (str): Path to the file containing the calling component child_class (str): Name of the class being used Returns: Optional[str]: The code of the child class up to the end of __init__ if found, or the full class code if __init__ doesn't exist, None if class not found Example: >>> builder = CallGraphBuilder("repo_path") >>> builder.get_child_class_init( ... "def main_function():\\n helper = HelperClass()\\n return helper.data", ... "main.py", ... "HelperClass" ... ) 'class HelperClass:\\n def __init__(self):\\n self.data = []' """ # Get the full class code first using existing method full_code = self.get_child_class(code_component, file_path, child_class) if not full_code: return None # Split into lines for analysis lines = full_code.split('\n') # Find the __init__ method init_start = -1 for i, line in enumerate(lines): if line.strip().startswith('def __init__'): init_start = i break # If no __init__, return full code if init_start == -1: return full_code # Find the next method definition after __init__ next_method_start = -1 for i, line in enumerate(lines[init_start + 1:], start=init_start + 1): if line.strip().startswith('def '): next_method_start = i break # If no next method found, return up to the end if next_method_start == -1: return full_code # Return code up to the start of next method return '\n'.join(lines[:next_method_start]) def _get_class_of_method(self, method_node: ast.FunctionDef) -> Optional[str]: """Get the name of the class that contains this method.""" parent 
= getattr(method_node, 'parent', None) while parent is not None: if isinstance(parent, ast.ClassDef): return parent.name parent = getattr(parent, 'parent', None) return None def get_parent(self, code_component: str, file_path: str, class_name: Optional[str] = None) -> List[str]: """Get the code of any components that use the focal component. Args: code_component: String representation of the component file_path: Path to the file containing the component class_name: If the component is a method, specify its class name to avoid false matches with methods of same name in other classes Returns: List[str]: List of code blocks of parent components that use this component """ results = [] component_name = self._get_component_name_from_code(code_component) if not component_name: return [] tree = self._parse_file(file_path) found_target = False for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and node.name == component_name: node_code = self._get_node_code(file_path, node) if self._is_code_similar(node_code, code_component): found_target = True break if not found_target: return [] # Check functions for func_file, func_code in self.function_info: func_node = self.function_info[(func_file, func_code)] # Check if this function calls our component for node in ast.walk(func_node): if isinstance(node, ast.Call): if isinstance(node.func, ast.Name) and node.func.id == component_name: results.append(func_code) break # Found usage in this function, move to next # Check methods for method_file, method_code in self.method_info: method_node = self.method_info[(method_file, method_code)] # Skip __init__ methods if method_node.name == '__init__': continue # Check if this method calls our component for node in ast.walk(method_node): if isinstance(node, ast.Call): if isinstance(node.func, ast.Attribute) and node.func.attr == component_name: # If class_name is specified, verify the method belongs to that class if class_name: # Get the class of the target 
method target_class = None if isinstance(node.func.value, ast.Name): # For self.method() calls if node.func.value.id == 'self': target_class = self._get_class_of_method(method_node) # For ClassName.method() calls else: target_class = node.func.value.id # For instance.method() calls through instance variables elif isinstance(node.func.value, ast.Attribute): # Try to find the instance variable in __init__ method_class = self._get_class_of_method(method_node) if method_class: # Look up the class definition for class_file, class_code in self.class_info: class_node = self.class_info[(class_file, class_code)] if class_node.name == method_class: # Find __init__ method for init_node in class_node.body: if isinstance(init_node, ast.FunctionDef) and init_node.name == '__init__': # Look for assignments to this instance variable instance_var = node.func.value.value.id # e.g., 'self' from self.data_processor var_name = node.func.value.attr # e.g., 'data_processor' from self.data_processor if instance_var == 'self': for n in ast.walk(init_node): if isinstance(n, ast.Assign): for target in n.targets: if isinstance(target, ast.Attribute) and \ isinstance(target.value, ast.Name) and \ target.value.id == 'self' and \ target.attr == var_name and \ isinstance(n.value, ast.Call): # Found the initialization if isinstance(n.value.func, ast.Name): target_class = n.value.func.id break if target_class == class_name: results.append(method_code) else: results.append(method_code) break # Found usage in this method, move to next elif isinstance(node.func, ast.Name) and node.func.id == component_name: results.append(method_code) break # Found usage in this method, move to next # Check class __init__ methods for class_file, class_code in self.class_info: class_node = self.class_info[(class_file, class_code)] # Look for __init__ method for node in class_node.body: if isinstance(node, ast.FunctionDef) and node.name == '__init__': # Check if __init__ uses our component for call_node in 
ast.walk(node): if isinstance(call_node, ast.Call): if isinstance(call_node.func, ast.Name) and call_node.func.id == component_name: # Get class signature and init method class_sig = self._get_node_code(class_file, class_node).split('\n')[0] init_code = self._get_node_code(class_file, node) results.append(f"{class_sig}\n{init_code}") break # Found usage in this class, move to next return results # Add this new class after the CallGraphBuilder class class ASTNodeAnalyzer: """A class to analyze AST nodes directly without string matching. This class works directly with AST nodes to analyze function calls, method calls, and class relationships within a Python repository, avoiding the need to re-parse files that have already been parsed. """ def __init__(self, repo_path: str): """Initialize the ASTNodeAnalyzer with a repository path. Args: repo_path (str): Path to the Python repository to analyze """ self.repo_path = Path(repo_path) # Reference to an existing CallGraphBuilder to reuse the pre-built info self.call_graph_builder = CallGraphBuilder(repo_path) def get_child_function(self, focal_node: ast.AST, file_tree: ast.AST, file_path: str, child_function: str) -> Optional[str]: """Get the code of a child function that is called by the component. 
Args: focal_node: The AST node representing the focal component file_tree: The AST tree for the entire file file_path: Path to the file containing the component child_function: Name of the function being called Returns: Optional[str]: The code of the child function if found, None otherwise """ # Look for calls to the child function in the focal node for node in ast.walk(focal_node): if isinstance(node, ast.Call): if isinstance(node.func, ast.Name) and node.func.id == child_function: # Find the function definition in the function_info dictionary for func_file, func_code in self.call_graph_builder.function_info: func_node = self.call_graph_builder.function_info[(func_file, func_code)] if func_node.name == child_function: return func_code return None def get_child_method(self, focal_node: ast.AST, file_tree: ast.AST, file_path: str, method_name: str, prefix: Optional[str] = None, find_all: bool = False) -> Union[Optional[str], Dict[str, str]]: """Get the code of a child method that is called by the component. 
Args: focal_node: The AST node representing the focal component file_tree: The AST tree for the entire file file_path: Path to the file containing the component method_name: Name of the method being called prefix: Optional prefix before method name (e.g., 'self', instance name, or class name) find_all: Whether to find all methods with this name across classes Returns: If find_all=False: Optional[str]: The code of the child method if found, None otherwise If find_all=True: Dict[str, str]: Dictionary mapping class names to method code for all matching methods """ if find_all: # Find all methods with this name across all classes results = {} for method_file, method_code in self.call_graph_builder.method_info: method_node = self.call_graph_builder.method_info[(method_file, method_code)] if method_node.name == method_name: class_node = self.call_graph_builder._get_class_node(method_node) if class_node: results[class_node.name] = method_code return results # If prefix is provided, use it to narrow down the search if prefix is not None: target_class = None if prefix == 'self': # Case 1: self.method() target_class = self.call_graph_builder._get_class_of_method(focal_node) elif prefix[0].isupper(): # Case 2: ClassName.method() target_class = prefix else: # Case 3: instance.method() target_class = self.call_graph_builder._resolve_instance_type(focal_node, prefix) if target_class: for method_file, method_code in self.call_graph_builder.method_info: method_node = self.call_graph_builder.method_info[(method_file, method_code)] if method_node.name == method_name: # Verify this method belongs to the target class method_class = self.call_graph_builder._get_class_of_method(method_node) if method_class == target_class: return method_code return None # If no prefix or target class not found, fall back to searching in the AST for node in ast.walk(focal_node): if isinstance(node, ast.Call): if isinstance(node.func, ast.Attribute) and node.func.attr == method_name: target_class = None 
if isinstance(node.func.value, ast.Name): if node.func.value.id == 'self': # Case 1: self.method() target_class = self.call_graph_builder._get_class_of_method(focal_node) else: # Case 2: ClassName.method() or Case 3: instance.method() # Try as class name first for class_file, class_code in self.call_graph_builder.class_info: class_node = self.call_graph_builder.class_info[(class_file, class_code)] if class_node.name == node.func.value.id: target_class = class_node.name break # If not found as class name, try as instance variable if not target_class: target_class = self.call_graph_builder._resolve_instance_type(focal_node, node.func.value.id) elif isinstance(node.func.value, ast.Attribute): # Handle nested attributes like self.processor.process() if isinstance(node.func.value.value, ast.Name): if node.func.value.value.id == 'self': # Get type of self.processor instance_var = node.func.value.attr target_class = self.call_graph_builder._resolve_instance_type(focal_node, instance_var) # If we found the target class, find the method if target_class: for method_file, method_code in self.call_graph_builder.method_info: method_node = self.call_graph_builder.method_info[(method_file, method_code)] if method_node.name == method_name: # Verify this method belongs to the target class method_class = self.call_graph_builder._get_class_of_method(method_node) if method_class == target_class: return method_code return None def get_child_class_init(self, focal_node: ast.AST, file_tree: ast.AST, file_path: str, child_class: str) -> Optional[str]: """Get the class signature and init function of a child class used by the component. 
Args: focal_node: The AST node representing the focal component file_tree: The AST tree for the entire file file_path: Path to the file containing the component child_class: Name of the class being used Returns: Optional[str]: The code of the child class up to the end of __init__ if found, or the full class code if __init__ doesn't exist, None if class not found """ # Look for calls to the child class in the focal node for node in ast.walk(focal_node): if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): if node.func.id == child_class: # Find the class definition for class_file, class_code in self.call_graph_builder.class_info: class_node = self.call_graph_builder.class_info[(class_file, class_code)] if class_node.name == child_class: # Get class signature and __init__ init_method = None for item in class_node.body: if isinstance(item, ast.FunctionDef) and item.name == '__init__': init_method = self.call_graph_builder._get_node_code(class_file, item) break if init_method: return f"{class_code}\n{init_method}" return class_code return None def get_parent_components(self, focal_node: ast.AST, file_tree: ast.AST, file_path: str, class_name: Optional[str] = None) -> List[str]: """Get the code of any components that use the focal component. 
Args: focal_node: The AST node representing the focal component file_tree: The AST tree for the entire file file_path: Path to the file containing the component class_name: If the component is a method, specify its class name to avoid false matches with methods of same name in other classes Returns: List[str]: List of code blocks of parent components that use the focal component """ # Check what type of node this is component_name = None if isinstance(focal_node, ast.FunctionDef): component_name = focal_node.name elif isinstance(focal_node, ast.ClassDef): component_name = focal_node.name else: return [] # Get the source code of the focal node focal_code = self.call_graph_builder._get_node_code(file_path, focal_node) # Now use the existing implementation from CallGraphBuilder return self.call_graph_builder.get_parent(focal_code, file_path, class_name) ================================================ FILE: src/agent/tool/internal_traverse.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import ast import os from typing import List, Optional, Dict, Any, Tuple class ASTNodeAnalyzer: """ Tool for analyzing AST nodes to find relationships between components in code. Used to identify calls (child components) and called_by (parent components). """ def __init__(self, repo_path: str): """ Initialize the AST Node Analyzer. Args: repo_path: Path to the repository being analyzed """ self.repo_path = repo_path def get_component_by_path( self, ast_node: ast.AST, ast_tree: ast.AST, dependency_path: str ) -> Optional[str]: """ Universal function to get any code component (class, function, method) by its dependency path. 
Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the dependency in format: folder1.folder2.file.component_name or: folder1.folder2.file.class_name.method_name Returns: The code of the component if found, None otherwise """ path_parts = dependency_path.split('.') if len(path_parts) < 2: return None # Determine the component type based on the path structure if len(path_parts) >= 3 and path_parts[-2] != 'self': # This could be a method: folder1.folder2.file.class_name.method_name last_part = path_parts[-1] second_last_part = path_parts[-2] # Check if this is likely a method if last_part[0].islower() and second_last_part[0].isupper(): # This looks like a method return self._get_method_component(ast_node, ast_tree, dependency_path) # Check if this is a class (typically starts with uppercase) if path_parts[-1][0].isupper(): # This looks like a class return self._get_class_component(ast_node, ast_tree, dependency_path) # Default to function (or could be a module) return self._get_function_component(ast_node, ast_tree, dependency_path) def _get_class_component(self, ast_node: ast.AST, ast_tree: ast.AST, dependency_path: str) -> Optional[str]: """ Get a class component by its dependency path. 
Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the dependency in format: folder1.folder2.file.ClassName Returns: The code of the class if found, None otherwise """ path_parts = dependency_path.split('.') class_name = path_parts[-1] file_name = path_parts[-2] + '.py' folder_path = os.path.join(*path_parts[:-2]) if len(path_parts) > 2 else '' # Special case for 'self' which refers to the current component if class_name == 'self': if isinstance(ast_node, ast.ClassDef): return self._get_node_source(file_path=os.path.relpath(ast_tree.file_path, self.repo_path) if hasattr(ast_tree, 'file_path') else "", node=ast_node) return None # First check if the class is used in the current file local_class_info = self._find_class_init_in_node(ast_node, class_name) if local_class_info: return local_class_info # Try to find the file in the repository target_file_path = os.path.join(folder_path, file_name) full_file_path = os.path.join(self.repo_path, target_file_path) # If file doesn't exist, return None if not os.path.exists(full_file_path): return None # Parse the target file and find the class try: with open(full_file_path, 'r') as f: file_content = f.read() target_ast = ast.parse(file_content) # Find the class in the target file for node in ast.walk(target_ast): if isinstance(node, ast.ClassDef) and node.name == class_name: return self._get_node_source(target_file_path, node) except Exception as e: return f"Error retrieving class {class_name}: {e}" return None def _get_function_component(self, ast_node: ast.AST, ast_tree: ast.AST, dependency_path: str) -> Optional[str]: """ Get a function component by its dependency path. 
Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the dependency in format: folder1.folder2.file.function_name Returns: The code of the function if found, None otherwise """ path_parts = dependency_path.split('.') function_name = path_parts[-1] file_name = path_parts[-2] + '.py' folder_path = os.path.join(*path_parts[:-2]) if len(path_parts) > 2 else '' # Special case for 'self' which refers to the current component if function_name == 'self': if isinstance(ast_node, ast.FunctionDef): return self._get_node_source(file_path=os.path.relpath(ast_tree.file_path, self.repo_path) if hasattr(ast_tree, 'file_path') else "", node=ast_node) return None # Try to find the file in the repository target_file_path = os.path.join(folder_path, file_name) full_file_path = os.path.join(self.repo_path, target_file_path) # If file doesn't exist, check the current file if not os.path.exists(full_file_path): # Look for the function in the current file for node in ast.walk(ast_tree): if isinstance(node, ast.FunctionDef) and node.name == function_name: return self._get_node_source(file_path=os.path.relpath(ast_tree.file_path, self.repo_path) if hasattr(ast_tree, 'file_path') else "", node=node) return None # Parse the target file and find the function try: with open(full_file_path, 'r') as f: file_content = f.read() target_ast = ast.parse(file_content) # Find the function in the target file for node in ast.walk(target_ast): if isinstance(node, ast.FunctionDef) and node.name == function_name: return self._get_node_source(target_file_path, node) except Exception as e: return f"Error retrieving function {function_name}: {e}" return None def _get_method_component(self, ast_node: ast.AST, ast_tree: ast.AST, dependency_path: str) -> Optional[str]: """ Get a method component by its dependency path. 
Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the dependency in format: folder1.folder2.file.ClassName.method_name Returns: The code of the method if found, None otherwise """ path_parts = dependency_path.split('.') if len(path_parts) < 3: # Need at least file.class.method return None method_name = path_parts[-1] class_name = path_parts[-2] file_name = path_parts[-3] + '.py' folder_path = os.path.join(*path_parts[:-3]) if len(path_parts) > 3 else '' # Special case for 'self' which refers to the current component if class_name == 'self': # Find the method in the current node if it's a class if isinstance(ast_node, ast.ClassDef): for item in ast_node.body: if isinstance(item, ast.FunctionDef) and item.name == method_name: return self._get_node_source(file_path=os.path.relpath(ast_tree.file_path, self.repo_path) if hasattr(ast_tree, 'file_path') else "", node=item) return None # Try to find the file in the repository target_file_path = os.path.join(folder_path, file_name) full_file_path = os.path.join(self.repo_path, target_file_path) # If file doesn't exist, check the current file if not os.path.exists(full_file_path): # Look for the class and method in the current file for node in ast.walk(ast_tree): if isinstance(node, ast.ClassDef) and node.name == class_name: for item in node.body: if isinstance(item, ast.FunctionDef) and item.name == method_name: return self._get_node_source(file_path=os.path.relpath(ast_tree.file_path, self.repo_path) if hasattr(ast_tree, 'file_path') else "", node=item) return None # Parse the target file and find the class and method try: with open(full_file_path, 'r') as f: file_content = f.read() target_ast = ast.parse(file_content) # Find the class in the target file for node in ast.walk(target_ast): if isinstance(node, ast.ClassDef) and node.name == class_name: # Find the method in the class for item in node.body: if isinstance(item, ast.FunctionDef) and item.name 
== method_name:
                            return self._get_node_source(target_file_path, item)
        except Exception as e:
            return f"Error retrieving method {class_name}.{method_name}: {e}"

        return None

    def get_child_class_init(
        self,
        ast_node: ast.AST,
        ast_tree: ast.AST,
        dependency_path: str
    ) -> Optional[str]:
        """
        Get the class signature and init function of a child class used by the component.
        Returns up to the end of __init__ if it exists (to save tokens).

        Args:
            ast_node: AST node representing the focal component
            ast_tree: AST tree for the entire file
            dependency_path: Path to the dependency in format: folder1.folder2.file.ClassName

        Returns:
            The code of the class initialization if found, None otherwise
        """
        class_code = self.get_component_by_path(ast_node, ast_tree, dependency_path)
        if not class_code:
            return None

        # Parse the class code to find the __init__ method if it exists
        try:
            class_ast = ast.parse(class_code)
            for node in ast.walk(class_ast):
                if isinstance(node, ast.ClassDef):
                    # Look for the __init__ method
                    init_method = None
                    for item in node.body:
                        if isinstance(item, ast.FunctionDef) and item.name == "__init__":
                            init_method = item
                            break

                    if init_method:
                        # Get the class signature and everything up to the end of __init__
                        class_lines = class_code.split('\n')
                        init_end_line = init_method.end_lineno - node.lineno + 1

                        # Ensure init_end_line doesn't exceed the total lines
                        init_end_line = min(init_end_line, len(class_lines))

                        # Return class signature through the end of __init__
                        return '\n'.join(class_lines[:init_end_line])
        except Exception:
            # If we can't parse the class code, just return it as is.
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt and SystemExit still propagate.
            pass

        return class_code

    def get_child_function(
        self,
        ast_node: ast.AST,
        ast_tree: ast.AST,
        dependency_path: str
    ) -> Optional[str]:
        """
        Find a function that is called by the focal component.
Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the dependency in format: folder1.folder2.file.function_name Returns: The code of the function if found, None otherwise """ return self.get_component_by_path(ast_node, ast_tree, dependency_path) def get_child_method( self, ast_node: ast.AST, ast_tree: ast.AST, dependency_path: str ) -> Optional[str]: """ Find a method that is called by the focal component. Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the dependency in format: folder1.folder2.file.ClassName.method_name Returns: The code of the method if found, None otherwise """ return self.get_component_by_path(ast_node, ast_tree, dependency_path) def get_parent_components( self, ast_node: ast.AST, ast_tree: ast.AST, dependency_path: str, dependency_graph: Optional[Dict[str, List[str]]] = None ) -> List[str]: """ Find components that call/depend on the focal component by looking at the dependency graph. Args: ast_node: AST node representing the focal component ast_tree: AST tree for the entire file dependency_path: Path to the focal component in format: folder1.folder2.file.component_name dependency_graph: Optional dictionary mapping component ids to their dependencies. If not provided, will only check the current file. 
Returns: List of code snippets of components that call/depend on the focal component """ parent_components = [] # If no dependency graph provided, fall back to checking just the current file if not dependency_graph: component_name = self._get_component_name(ast_node) if not component_name: return parent_components # Parse the dependency path to get the file path for the current file path_parts = dependency_path.split('.') if len(path_parts) < 2: return parent_components file_name = path_parts[-2] + '.py' folder_path = os.path.join(*path_parts[:-2]) if len(path_parts) > 2 else '' target_file_path = os.path.join(folder_path, file_name) # Check for calls in the current file for node in ast.walk(ast_tree): # Skip the component itself if node == ast_node: continue # Check if this is a function, async function, or class definition if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): if self._contains_call_to(node, component_name): parent_components.append(self._get_node_source(target_file_path, node)) return parent_components # With dependency graph, we can find all components that depend on this component parent_ids = [] for component_id, dependencies in dependency_graph.items(): if dependency_path in dependencies: parent_ids.append(component_id) # Now retrieve the source code for each parent component for parent_id in parent_ids: parent_code = self.get_component_by_path(ast_node, ast_tree, parent_id) if parent_code: parent_components.append(parent_code) return parent_components def _find_class_init_in_node(self, ast_node: ast.AST, class_name: str) -> Optional[str]: """ Find class instantiation in the given node. 
Args: ast_node: AST node to search in class_name: Name of the class to find Returns: The code of the class instantiation if found, None otherwise """ for node in ast.walk(ast_node): if isinstance(node, ast.Call) and self._get_call_name(node) == class_name: return self._format_call_node(node) return None def _find_function_call_in_node(self, ast_node: ast.AST, function_name: str) -> bool: """ Check if a function is called in the given node. Args: ast_node: AST node to search in function_name: Name of the function to find Returns: True if the function is called, False otherwise """ for node in ast.walk(ast_node): if isinstance(node, ast.Call): call_name = self._get_call_name(node) if call_name == function_name: return True return False def _find_method_call_in_node( self, ast_node: ast.AST, method_name: str, prefix: Optional[str] = None ) -> bool: """ Check if a method is called in the given node. Args: ast_node: AST node to search in method_name: Name of the method to find prefix: Optional prefix (object name) of the method Returns: True if the method is called, False otherwise """ for node in ast.walk(ast_node): if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): if node.func.attr == method_name: if prefix is None or ( isinstance(node.func.value, ast.Name) and node.func.value.id == prefix ): return True return False def _find_class_for_prefix(self, ast_tree: ast.AST, prefix: Optional[str]) -> Optional[str]: """ Try to determine the class name for a given object prefix. 
This is a naive approach that checks for: prefix = ClassName() or prefix: ClassName Args: ast_tree: AST tree for the entire file prefix: The object name to find the class for Returns: Name of the class if found, None otherwise """ if not prefix: return None # Look for prefix = ClassName() for node in ast.walk(ast_tree): if isinstance(node, ast.Assign): for target in node.targets: if isinstance(target, ast.Name) and target.id == prefix: if ( isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) ): return node.value.func.id # Look for prefix: ClassName for node in ast.walk(ast_tree): if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): if node.target.id == prefix and isinstance(node.annotation, ast.Name): return node.annotation.id return None def _get_component_name(self, ast_node: ast.AST) -> Optional[str]: """ Get the name of a component (function, async function, or class). Args: ast_node: AST node representing the component Returns: Name of the component if present, None otherwise """ if isinstance(ast_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): return ast_node.name return None def _contains_call_to(self, ast_node: ast.AST, component_name: str) -> bool: """ Check if ast_node contains a call to the specified component name. Args: ast_node: AST node to check component_name: Name of the component to look for Returns: True if the node contains a call to the component, False otherwise """ for node in ast.walk(ast_node): if isinstance(node, ast.Call): call_name = self._get_call_name(node) if call_name == component_name: return True return False def _get_call_name(self, call_node: ast.Call) -> Optional[str]: """ Get the name being called in a Call node. 
        Args:
            call_node: AST Call node

        Returns:
            Name being called, or None if it cannot be determined
        """
        if isinstance(call_node.func, ast.Name):
            return call_node.func.id
        elif isinstance(call_node.func, ast.Attribute):
            return call_node.func.attr
        return None

    def _format_call_node(self, call_node: ast.Call) -> str:
        """
        Format a call node as a string for demonstration.

        Args:
            call_node: AST Call node

        Returns:
            String representation of the call
        """
        call_name = self._get_call_name(call_node)
        return f"{call_name}(...)"

    def _get_node_source(self, file_path: str, node: ast.AST) -> str:
        """
        Get the source code for an AST node from the original file.

        Args:
            file_path: Path to the file containing the node
            node: AST node to get the source for

        Returns:
            Source code for the node, or an error message
        """
        try:
            full_path = os.path.join(self.repo_path, file_path)
            with open(full_path, 'r') as f:
                file_content = f.read()

            start_line = node.lineno
            end_line = self._get_end_line(node, file_content)

            lines = file_content.split('\n')

            # Check for docstring if this is a function or class definition
            if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
                # The docstring would be the first element in the body if it exists.
                # ast.Str is a deprecated alias of ast.Constant (since Python 3.8),
                # so test for a string-valued Constant instead.
                if (node.body and isinstance(node.body[0], ast.Expr) and
                        isinstance(node.body[0].value, ast.Constant) and
                        isinstance(node.body[0].value.value, str)):
                    # Docstring is already included in the range from lineno to end_lineno
                    pass

            # Safeguard: ensure end_line does not exceed total line count
            end_line = min(end_line, len(lines))
            return '\n'.join(lines[start_line - 1:end_line])
        except Exception as e:
            return f"Error retrieving source for {type(node).__name__}: {e}"

    def _get_end_line(self, node: ast.AST, file_content: str) -> int:
        """
        Get the end line number for an AST node, using end_lineno if present.
Args: node: AST node file_content: Content of the file Returns: End line number of the node """ if hasattr(node, 'end_lineno') and node.end_lineno: return node.end_lineno if hasattr(node, 'body') and node.body: last_subnode = node.body[-1] return self._get_end_line(last_subnode, file_content) return node.lineno ================================================ FILE: src/agent/tool/perplexity_api.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import os import requests from typing import List, Dict, Any from dataclasses import dataclass import yaml @dataclass class PerplexityResponse: """Structured response from Perplexity API""" content: str raw_response: Dict[str, Any] class PerplexityAPI: """Wrapper for Perplexity API interactions""" def __init__(self, api_key: str | None = None, config_path: str = "config/agent_config.yaml"): """Initialize the API wrapper. Args: api_key: Perplexity API key. If None, will try to get from config. config_path: Path to the configuration file """ self.config = self._load_config(config_path) self.api_key = api_key or self.config.get('api_key') if not self.api_key: raise ValueError("Perplexity API key not provided and not found in config") self.base_url = "https://api.perplexity.ai/chat/completions" self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } def _load_config(self, config_path: str) -> Dict[str, Any]: """Load configuration from yaml file.""" try: with open(config_path, 'r') as f: config = yaml.safe_load(f) return config.get('perplexity', {}) except Exception as e: print(f"Warning: Could not load config file: {e}") return {} def query(self, question: str, system_prompt: str = "Be precise and concise.", temperature: float | None = None, model: str | None = None, max_output_tokens: int | None = 4096) -> PerplexityResponse: """Send a single query to Perplexity API. 
        Args:
            question: The question to ask
            system_prompt: System prompt to guide the response
            temperature: Temperature for response generation (0.0-1.0)
            model: Model to use for generation
            max_output_tokens: Maximum tokens in response

        Returns:
            PerplexityResponse containing the response content and raw API response

        Raises:
            requests.exceptions.RequestException: If API request fails
            ValueError: If API response is invalid
        """
        payload = {
            "model": model or self.config.get('model', 'sonar'),
            "messages": [
                {
                    "role": "system",
                    "content": system_prompt
                },
                {
                    "role": "user",
                    "content": question
                }
            ],
            # Compare against None explicitly: `temperature or ...` would discard
            # an explicit temperature of 0.0 because 0.0 is falsy.
            "temperature": temperature if temperature is not None else self.config.get('temperature', 0.1),
            "max_tokens": max_output_tokens or self.config.get('max_output_tokens', 200),
            "top_p": 0.9,
            "return_images": False,
            "return_related_questions": False
        }

        response = requests.post(self.base_url, json=payload, headers=self.headers)
        response.raise_for_status()

        response_data = response.json()
        if "choices" not in response_data or not response_data["choices"]:
            raise ValueError("Invalid API response: missing choices")

        content = response_data["choices"][0].get("message", {}).get("content", "")
        if not content:
            raise ValueError("Invalid API response: missing content")

        return PerplexityResponse(content=content, raw_response=response_data)

    def batch_query(self,
                    questions: List[str],
                    system_prompt: str = "Be precise and concise.",
                    temperature: float | None = None,
                    model: str | None = None,
                    max_output_tokens: int | None = None) -> List[PerplexityResponse | None]:
        """Send multiple queries to Perplexity API.
Args: questions: List of questions to ask system_prompt: System prompt to guide the responses temperature: Temperature for response generation (0.0-1.0) model: Model to use for generation max_output_tokens: Maximum tokens in response Returns: List of PerplexityResponse objects """ responses = [] for question in questions: try: response = self.query( question=question, system_prompt=system_prompt, temperature=temperature, model=model, max_output_tokens=max_output_tokens ) responses.append(response) except Exception as e: # If a query fails, add None to maintain order with input questions print(f"Error querying Perplexity API: {str(e)}") responses.append(None) return responses ================================================ FILE: src/agent/verifier.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Optional, List from .base import BaseAgent class Verifier(BaseAgent): """Agent responsible for verifying the quality of generated docstrings.""" def __init__(self, config_path: Optional[str] = None): """Initialize the Verifier agent. Args: config_path: Optional path to the configuration file """ super().__init__("Verifier", config_path=config_path) self.system_prompt = """You are a Verifier agent responsible for ensuring the quality of generated docstrings. Your role is to evaluate docstrings from the perspective of a first-time user encountering the code component. Analysis Process: 1. First read the code component as if you're seeing it for the first time 2. Read the docstring and analyze how well it helps you understand the code 3. Evaluate if the docstring provides the right level of abstraction and information Verification Criteria: 1. Information Value: - Identify parts that merely repeat the code without adding value - Flag docstrings that state the obvious without providing insights - Check if explanations actually help understand the purpose and usage 2. 
Appropriate Detail Level: - Flag overly detailed technical explanations of implementation - Ensure focus is on usage and purpose, not line-by-line explanation - Check if internal implementation details are unnecessarily exposed 3. Completeness Check: - Verify all required sections are present (summary, args, returns, etc.) - Check if each section provides meaningful information - Ensure critical usage information is not missing Output Format: Your analysis must include: 1. true/false - Indicates if docstring needs improvement 2. If revision needed: true/false - Indicates if additional context is required for improvement - Keep in mind that collecting context is very expensive and may fail, so only use it when absolutely necessary 3. Based on MORE_CONTEXT, provide suggestions at the end of your response: If true: explain why and what specific context is needed If false: specific improvement suggestions Do not generate other things after or . """ self.add_to_memory("system", self.system_prompt) def process( self, focal_component: str, docstring: str, context: str = "" ) -> str: """Verify the quality of a generated docstring. 
        Args:
            focal_component: The code component with the docstring
            docstring: The generated docstring to verify
            context: The context used to generate the docstring

        Returns:
            The Verifier's full response text
        """
        task_description = f"""
        Context Used:
        {context if context else 'No context was used.'}

        Verify the quality of the following docstring for the following Code Component:

        Code Component:
        {focal_component}

        Generated Docstring:
        {docstring}
        """
        self.add_to_memory("user", task_description)

        full_response = self.generate_response()
        return full_response


================================================
FILE: src/agent/workflow.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Optional
from pathlib import Path
from .orchestrator import Orchestrator
from .reader import CodeComponentType

def generate_docstring(
    repo_path: str,
    file_path: str,
    focal_component: str,
    component_type: CodeComponentType,
    instruction: Optional[str] = None
) -> str:
    """Generate a high-quality docstring for a code component using the multi-agent system.
Args: repo_path: Path to the repository containing the code file_path: Path to the file containing the component focal_component: The code component needing a docstring component_type: The type of the code component (function, method, or class) instruction: Optional specific instructions for docstring generation Returns: The generated and verified docstring Raises: FileNotFoundError: If the repository or file path doesn't exist ValueError: If the component type is invalid """ # Validate inputs repo_path = str(Path(repo_path).resolve()) file_path = str(Path(file_path).resolve()) if not Path(repo_path).exists(): raise FileNotFoundError(f"Repository path does not exist: {repo_path}") if not Path(file_path).exists(): raise FileNotFoundError(f"File path does not exist: {file_path}") # Use default instruction if none provided if instruction is None: instruction = """Generate a comprehensive and helpful docstring that includes: 1. A clear description of what the component does 2. All parameters and their types 3. Return value and type 4. Any exceptions that may be raised 5. Usage examples where appropriate The docstring should follow PEP 257 style guidelines.""" # Create orchestrator and generate docstring orchestrator = Orchestrator(repo_path) return orchestrator.process( instruction=instruction, focal_component=focal_component, component_type=component_type, file_path=file_path ) ================================================ FILE: src/agent/writer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict, Any, Optional from abc import abstractmethod from .base import BaseAgent from .reader import CodeComponentType class Writer(BaseAgent): """Agent responsible for generating high-quality docstrings based on the code and context.""" def __init__(self, config_path: Optional[str] = None): """Initialize the Writer agent. 
Args: config_path: Optional path to the configuration file """ super().__init__("Writer", config_path=config_path) # Base prompt that applies to all documentation self.base_prompt = """You are a Writer agent responsible for generating high-quality docstrings that are both complete and helpful. Accessible context is provided to you for generating the docstring. General Guidelines: 1. Make docstrings actionable and specific: - Focus on practical usage - Highlight important considerations - Include warnings or gotchas 2. Use clear, concise language: - Avoid jargon unless necessary - Use active voice - Be direct and specific 3. Type Information: - Include precise type hints - Note any type constraints - Document generic type parameters 4. Context and Integration: - Explain component relationships - Note any dependencies - Describe side effects 5. Follow Google docstring format: - Use consistent indentation - Maintain clear section separation - Keep related information grouped""" self.add_to_memory("system", self.base_prompt) # Class-specific prompt self.class_prompt = """You are documenting a CLASS. Focus on describing the object it represents and its role in the system. Required sections: 1. Summary: - One-line description focusing on WHAT the class represents - Avoid repeating the class name or obvious terms - Focus on the core purpose or responsibility 2. Description: - WHY: Explain the motivation and purpose behind this class - WHEN: Describe scenarios or conditions where this class should be used - WHERE: Explain how it fits into the larger system architecture - HOW: Provide a high-level overview of how it achieves its purpose 3. Example: - Show a practical, real-world usage scenario - Include initialization and common method calls - Demonstrate typical workflow Conditional sections: 1. 
Parameters (if class's __init__ has parameters): - Focus on explaining the significance of each parameter - Include valid value ranges or constraints - Explain parameter relationships if they exist 2. Attributes: - Explain the purpose and significance of each attribute - Include type information and valid values - Note any dependencies between attributes""" # Function/Method-specific prompt self.function_prompt = """You are documenting a FUNCTION or METHOD. Focus on describing the action it performs and its effects. Required sections: 1. Summary: - One-line description focusing on WHAT the function does - Avoid repeating the function name - Emphasize the outcome or effect 2. Description: - WHY: Explain the purpose and use cases - WHEN: Describe when to use this function - WHERE: Explain how it fits into the workflow - HOW: Provide high-level implementation approach Conditional sections: 1. Args (if present): - Explain the significance of each parameter - Include valid value ranges or constraints - Note any parameter interdependencies 2. Returns: - Explain what the return value represents - Include possible return values or ranges - Note any conditions affecting the return value 3. Raises: - List specific conditions triggering each exception - Explain how to prevent or handle exceptions 4. Examples (if public and not abstract): - Show practical usage scenarios - Include common parameter combinations - Demonstrate error handling if relevant""" def is_class_component(code: str) -> bool: """Determine if the given code component is a class definition. Args: code: The code component to analyze Returns: bool: True if the component is a class definition, False otherwise """ return "class " in code.split('\n')[0] def get_custom_prompt(self, code: str) -> str: """Get the appropriate system prompt based on the component type. 
Args: code: The code component to analyze Returns: str: The appropriate system prompt for the component type """ is_class = Writer.is_class_component(code) specific_prompt = self.class_prompt if is_class else self.function_prompt return specific_prompt def extract_docstring(self, response: str) -> str: """Extract the docstring from the LLM response. Args: response: The full response from the LLM containing the docstring between XML tags Returns: str: The extracted docstring, or the full response unchanged if no DOCSTRING tags are found """ start_tag = "<DOCSTRING>" end_tag = "</DOCSTRING>" try: start_idx = response.index(start_tag) + len(start_tag) end_idx = response.index(end_tag) return response[start_idx:end_idx].strip() except ValueError: import logging logger = logging.getLogger(__name__) logger.warning("\033[93mError parsing, no DOCSTRING XML tags found in response, directly return the response as docstring\033[0m") return response def process( self, focal_component: str, context: Dict[str, Any], ) -> str: """Generate a docstring for the given code component. Args: focal_component: The code component needing a docstring context: Dictionary containing gathered context information Returns: str: The generated docstring following the specified format """ task_description = f""" Available context: {context} {self.get_custom_prompt(focal_component)} Now, generate a high-quality docstring for the following Code Component based on the Available context: {focal_component} Keep in mind: 1. Generate docstring between XML tag: <DOCSTRING> and </DOCSTRING> 2. First analyze the code component and then generate the docstring at the end based on the context. 3. Do not add triple quotes (\"\"\") to your generated docstring. 4. Always double check if the generated docstring is within the XML tags: <DOCSTRING> and </DOCSTRING>. This is critical for parsing the docstring.
""" self.add_to_memory("user", task_description) # Generate response using LLM full_response = self.generate_response() # Extract and return just the docstring part return self.extract_docstring(full_response) ================================================ FILE: src/analyze_helpfulness_significance.py ================================================ #!/usr/bin/env python # Copyright (c) Meta Platforms, Inc. and affiliates """ Script to analyze statistical significance between docstring helpfulness scores of different systems. Usage: conda activate docstringgen python src/analyze_significance.py """ import json import os import argparse import numpy as np from scipy import stats import pandas as pd from typing import Dict, List, Tuple, Any def load_results(filepath: str) -> Dict[str, Any]: """Load the helpfulness evaluation results from JSON file.""" with open(filepath, 'r') as f: return json.load(f) def get_system_scores(results: Dict[str, Any], system: str) -> Dict[str, List[int]]: """ Extract scores for a specific system, organized by aspect. Returns: Dictionary mapping aspect to list of scores """ system_results = [r for r in results["results"] if r["system"] == system] scores_by_aspect = {} for result in system_results: aspect = result["aspect"] score = result["score"] if aspect not in scores_by_aspect: scores_by_aspect[aspect] = [] scores_by_aspect[aspect].append(score) return scores_by_aspect def get_paired_scores(results: Dict[str, Any], system1: str, system2: str) -> Dict[str, Tuple[List[int], List[int]]]: """ Extract paired scores for two systems, organized by aspect. Only includes components that have scores for both systems. 
Returns: Dictionary mapping aspect to tuple of (system1_scores, system2_scores) """ # Get all component IDs evaluated by both systems system1_results = [r for r in results["results"] if r["system"] == system1] system2_results = [r for r in results["results"] if r["system"] == system2] system1_components = {(r["component_id"], r["aspect"]): r for r in system1_results} system2_components = {(r["component_id"], r["aspect"]): r for r in system2_results} # Find common component-aspect pairs common_pairs = set(system1_components.keys()).intersection(system2_components.keys()) # Organize paired scores by aspect paired_scores = {} for component_id, aspect in common_pairs: if aspect not in paired_scores: paired_scores[aspect] = ([], []) paired_scores[aspect][0].append(system1_components[(component_id, aspect)]["score"]) paired_scores[aspect][1].append(system2_components[(component_id, aspect)]["score"]) return paired_scores def run_significance_tests(results: Dict[str, Any]) -> Dict[str, Any]: """ Run statistical significance tests between specified system pairs. 
Returns: Dictionary with test results """ system_pairs = [ ("copy_paste_codellama34b", "docassist-codellama34b"), ("copy_paste_gpt4o_mini", "docassist-gpt4o_mini"), ("fim-codellama13b", "docassist-codellama34b") ] significance_results = {} for system1, system2 in system_pairs: pair_key = f"{system1} vs {system2}" significance_results[pair_key] = {} # Get paired scores for the two systems paired_scores = get_paired_scores(results, system1, system2) # Calculate overall paired scores across all aspects all_scores_sys1 = [] all_scores_sys2 = [] for aspect, (scores1, scores2) in paired_scores.items(): all_scores_sys1.extend(scores1) all_scores_sys2.extend(scores2) # Run tests for each aspect if len(scores1) >= 5: # Only run tests if we have enough samples # Perform Wilcoxon signed-rank test (non-parametric paired test) try: w_stat, p_value = stats.wilcoxon(scores1, scores2) is_significant = p_value < 0.05 better_system = system2 if np.mean(scores2) > np.mean(scores1) else system1 significance_results[pair_key][aspect] = { "mean_1": np.mean(scores1), "mean_2": np.mean(scores2), "p_value": p_value, "is_significant": is_significant, "better_system": better_system if is_significant else "No significant difference", "n_samples": len(scores1) } except ValueError as e: # This can happen if the differences are all zero significance_results[pair_key][aspect] = { "mean_1": np.mean(scores1), "mean_2": np.mean(scores2), "p_value": 1.0, "is_significant": False, "better_system": "No significant difference", "n_samples": len(scores1), "note": "Test could not be performed: " + str(e) } # Run test for overall scores if len(all_scores_sys1) >= 5: try: w_stat, p_value = stats.wilcoxon(all_scores_sys1, all_scores_sys2) is_significant = p_value < 0.05 better_system = system2 if np.mean(all_scores_sys2) > np.mean(all_scores_sys1) else system1 significance_results[pair_key]["overall"] = { "mean_1": np.mean(all_scores_sys1), "mean_2": np.mean(all_scores_sys2), "p_value": p_value, 
"is_significant": is_significant, "better_system": better_system if is_significant else "No significant difference", "n_samples": len(all_scores_sys1) } except ValueError as e: significance_results[pair_key]["overall"] = { "mean_1": np.mean(all_scores_sys1), "mean_2": np.mean(all_scores_sys2), "p_value": 1.0, "is_significant": False, "better_system": "No significant difference", "n_samples": len(all_scores_sys1), "note": "Test could not be performed: " + str(e) } return significance_results def format_significance_markdown(significance_results: Dict[str, Any]) -> str: """Format significance test results as markdown.""" md = "## Statistical Significance Tests\n\n" md += "Statistical significance was assessed using the Wilcoxon signed-rank test with a significance level of p < 0.05.\n\n" for pair_key, pair_results in significance_results.items(): md += f"### {pair_key}\n\n" # Create a table for this pair md += "| Aspect | System 1 Mean | System 2 Mean | p-value | Significant? | Better System | n |\n" md += "| ------ | ------------ | ------------ | ------- | ------------ | ------------- | --- |\n" # Add overall results first if "overall" in pair_results: overall = pair_results["overall"] md += f"| Overall | {overall['mean_1']:.2f} | {overall['mean_2']:.2f} | {overall['p_value']:.4f} | {overall['is_significant']} | {overall['better_system']} | {overall['n_samples']} |\n" # Add results for each aspect for aspect, results in pair_results.items(): if aspect != "overall": md += f"| {aspect.capitalize()} | {results['mean_1']:.2f} | {results['mean_2']:.2f} | {results['p_value']:.4f} | {results['is_significant']} | {results['better_system']} | {results['n_samples']} |\n" md += "\n" return md def update_markdown_report(stats_path: str, significance_md: str): """Update the markdown report to include significance test results.""" with open(stats_path, 'r') as f: content = f.read() # Append significance test results updated_content = content + "\n" + significance_md with 
open(stats_path, 'w') as f: f.write(updated_content) def main(): parser = argparse.ArgumentParser(description="Analyze statistical significance of docstring helpfulness") parser.add_argument("--results-path", type=str, default="experiments/eval/results/helpfulness/helpfulness_evaluation_results.json", help="Path to the helpfulness evaluation results JSON") parser.add_argument("--stats-path", type=str, default="experiments/eval/results/helpfulness/helpfulness_evaluation_stats.md", help="Path to the helpfulness evaluation stats markdown file") parser.add_argument("--output-dir", type=str, default="experiments/eval/results/helpfulness", help="Directory to store significance test results") args = parser.parse_args() # Check if result file exists if not os.path.exists(args.results_path): print(f"Error: Results file not found at {args.results_path}") return # Load results results = load_results(args.results_path) # Run significance tests significance_results = run_significance_tests(results) # Format results as markdown significance_md = format_significance_markdown(significance_results) # Save significance test results as separate file significance_path = os.path.join(args.output_dir, "significance_tests.md") with open(significance_path, 'w') as f: f.write(significance_md) # Update the stats markdown file if os.path.exists(args.stats_path): update_markdown_report(args.stats_path, significance_md) print(f"Significance test results saved to {significance_path}") if os.path.exists(args.stats_path): print(f"Updated stats report with significance tests at {args.stats_path}") if __name__ == "__main__": main() ================================================ FILE: src/data/parse/data_process.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates import os import ast import json from tqdm import tqdm import argparse import re from langdetect import detect def is_english(text): """Check if text is ASCII-only English, using langdetect for language detection.""" try: return detect(text) == 'en' and text.isascii() except Exception: return False def is_high_quality_file_docstring(docstring): """Heuristic for file-level docstrings: - At least one meaningful sentence and length ≥ 10 chars.""" if not docstring or len(docstring.strip()) < 10: return False else: return True # Check if it seems like a sentence def is_high_quality_class_docstring(docstring): """Heuristic for class docstrings: - At least 2 lines - Possibly mentions common docstring sections (Attributes, Args, Returns)""" if not docstring: return False lines = docstring.strip().split('\n') if len(lines) < 2: return False keywords = ["Attributes", "Args", "Returns", "Example", "Methods", "Param", "arguments", "Parameters"] if any(kw in docstring for kw in keywords): return True # If at least moderately long, consider it acceptable if len(docstring.strip()) > 30: return True return False def is_high_quality_function_docstring(docstring): """Heuristic for function or class method docstrings: - At least 3 lines - Mention parameters, args, or returns """ if not docstring: return False lines = docstring.strip().split('\n') if len(lines) < 3: return False keywords = ["Parameters", "Args", "Returns", "Param", "arguments"] if any(kw.lower() in docstring.lower() for kw in keywords): return True # If reasonably long (>30 chars), consider it good if len(docstring.strip()) > 30: return True return False def is_high_quality_docstring(docstring, doc_type): """Check if docstring meets quality criteria and is in English.""" if not docstring: return False # First check if it's English if not is_english(docstring): return False # Then apply other quality checks if doc_type == "file": return is_high_quality_file_docstring(docstring) elif doc_type == "class": return
is_high_quality_class_docstring(docstring) elif doc_type in ("function", "class_method"): return is_high_quality_function_docstring(docstring) return False def get_repo_name_from_path(path): """Extract repo name from path like: data/downloaded_repos/USERNAME/REPO_NAME""" parts = path.split(os.sep) try: # Find the index where the username starts (after downloaded_repos) for i, part in enumerate(parts): if part == "downloaded_repos": # Return username/repo_name format return f"{parts[i+1]}/{parts[i+2]}" except IndexError: pass return None def extract_docstrings_from_file(file_path): """ Parse a single Python file with AST and extract: - file-level docstring - class-level docstrings - function-level docstrings (including class methods) """ with open(file_path, "r", encoding="utf-8", errors='replace') as f: source = f.read() try: tree = ast.parse(source) # Attach .parent to every node; the top-level function check below relies on it add_parent_references(tree) except SyntaxError: return [] docstrings_info = [] repo_name = get_repo_name_from_path(file_path) # File-level docstring module_docstring = ast.get_docstring(tree) if is_high_quality_docstring(module_docstring, "file"): signature = f"File: {os.path.basename(file_path)}" docstrings_info.append({ "type": "file", "location": file_path, "repo_name": repo_name, "content": module_docstring.strip(), "signature": signature }) # Classes and functions for node in ast.walk(tree): if isinstance(node, ast.ClassDef): class_docstring = ast.get_docstring(node) if hasattr(ast, "unparse"): bases = [ast.unparse(base) for base in node.bases] else: # fallback for older python versions: just get the name of base classes if simple bases = [] for base in node.bases: if isinstance(base, ast.Name): bases.append(base.id) else: # If complex base, just ignore bases.append("Base") class_signature = f"class {node.name}" if bases: class_signature += f"({', '.join(bases)})" if is_high_quality_docstring(class_docstring, "class"): docstrings_info.append({ "type": "class", "location": file_path, "repo_name": repo_name, "content": class_docstring.strip(),
"signature": class_signature }) # Class methods for body_item in node.body: if isinstance(body_item, ast.FunctionDef): func_docstring = ast.get_docstring(body_item) args_list = [arg.arg for arg in body_item.args.args] func_signature = f"def {body_item.name}({', '.join(args_list)})" if is_high_quality_docstring(func_docstring, "class_method"): docstrings_info.append({ "type": "class_method", "location": file_path, "repo_name": repo_name, "content": func_docstring.strip(), "signature": func_signature }) elif isinstance(node, ast.FunctionDef): # Top-level functions if isinstance(node.parent, ast.Module): # We'll add a small hack to set parents func_docstring = ast.get_docstring(node) args_list = [arg.arg for arg in node.args.args] func_signature = f"def {node.name}({', '.join(args_list)})" if is_high_quality_docstring(func_docstring, "function"): docstrings_info.append({ "type": "function", "location": file_path, "repo_name": repo_name, "content": func_docstring.strip(), "signature": func_signature }) return docstrings_info def add_parent_references(tree): """Add parent references to nodes, so we can distinguish top-level functions from class methods easily.""" for node in ast.walk(tree): for child in ast.iter_child_nodes(node): child.parent = node def gather_python_files(top_dir): py_files = [] for root, dirs, files in os.walk(top_dir): for file in files: if file.endswith(".py"): py_files.append(os.path.join(root, file)) return py_files def process_all_repos(top_dir, output_file): """Process all repositories and extract docstrings. Args: top_dir (str): Path to directory containing downloaded repos output_file (str): Path where to save the output JSONL file """ py_files = gather_python_files(top_dir) # Setup output file # We'll write each docstring object as a single JSON line. # This allows incremental updates without invalidating JSON format. 
with open(output_file, "w", encoding="utf-8") as out_f: # Using tqdm to show progress over Python files for file_path in tqdm(py_files, desc="Processing files"): # Parse the file and extract docstrings with open(file_path, "r", encoding="utf-8", errors='replace') as f: source = f.read() try: tree = ast.parse(source) add_parent_references(tree) except SyntaxError: # Skip files that have syntax errors continue docstrings = [] # File-level docstring repo_name = get_repo_name_from_path(file_path) module_docstring = ast.get_docstring(tree) if is_high_quality_docstring(module_docstring, "file"): docstrings.append({ "type": "file", "location": file_path, "repo_name": repo_name, "content": module_docstring.strip(), "signature": f"File: {os.path.basename(file_path)}" }) for node in ast.walk(tree): if isinstance(node, ast.ClassDef): class_docstring = ast.get_docstring(node) if hasattr(ast, "unparse"): bases = [ast.unparse(base) for base in node.bases] else: bases = [] for base in node.bases: if isinstance(base, ast.Name): bases.append(base.id) else: bases.append("Base") class_signature = f"class {node.name}" if bases: class_signature += f"({', '.join(bases)})" if is_high_quality_docstring(class_docstring, "class"): docstrings.append({ "type": "class", "location": file_path, "repo_name": repo_name, "content": class_docstring.strip(), "signature": class_signature }) # Class methods for body_item in node.body: if isinstance(body_item, ast.FunctionDef): func_docstring = ast.get_docstring(body_item) args_list = [arg.arg for arg in body_item.args.args] func_signature = f"def {body_item.name}({', '.join(args_list)})" if is_high_quality_docstring(func_docstring, "class_method"): docstrings.append({ "type": "class_method", "location": file_path, "repo_name": repo_name, "content": func_docstring.strip(), "signature": func_signature }) elif isinstance(node, ast.FunctionDef): # Check if top-level (parent is module) if isinstance(node.parent, ast.Module): func_docstring = 
ast.get_docstring(node) args_list = [arg.arg for arg in node.args.args] func_signature = f"def {node.name}({', '.join(args_list)})" if is_high_quality_docstring(func_docstring, "function"): docstrings.append({ "type": "function", "location": file_path, "repo_name": repo_name, "content": func_docstring.strip(), "signature": func_signature }) # Write each docstring as a separate JSON line immediately for d in docstrings: out_f.write(json.dumps(d, ensure_ascii=False) + "\n") out_f.flush() def main(): parser = argparse.ArgumentParser(description='Process Python files for docstrings') parser.add_argument('--input-dir', type=str, default="data/downloaded_repos", help='Input directory containing downloaded repos') parser.add_argument('--output-file', type=str, default="data/parsed_downloaded_repos/docstrings.jsonl", help='Output JSONL file path') args = parser.parse_args() process_all_repos(top_dir=args.input_dir, output_file=args.output_file) if __name__ == "__main__": main() ================================================ FILE: src/data/parse/downloader.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates import yaml import os import logging from github import Github from pathlib import Path import git from typing import Dict, Any, List import time from datetime import datetime from tqdm import tqdm import json class GitHubRepoDownloader: def __init__(self, config_path: str): self.config = self._load_config(config_path) self.token = self.config.get('GITHUB_TOKEN') if not self.token: raise ValueError("GITHUB_TOKEN not found in config file") self.gh = Github(self.token) self.setup_logging() def _load_config(self, config_path: str) -> Dict[str, Any]: try: with open(config_path, 'r') as f: config = yaml.safe_load(f) or {} if 'search_criteria' not in config: config['search_criteria'] = {} return config except yaml.YAMLError as e: logging.error(f"Error parsing YAML file: {e}") raise except FileNotFoundError: logging.error(f"Config file not found: {config_path}") raise def setup_logging(self): log_filename = f"github_downloader_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_filename), logging.StreamHandler() ] ) def build_query(self) -> str: """Build GitHub search query from config.""" criteria = self.config.get('search_criteria', {}) query_parts = [] # Handle owners/users if owners := criteria.get('owners'): if isinstance(owners, list): query_parts.extend(f"user:{owner}" for owner in owners) else: query_parts.append(f"user:{owners}") # Handle dates - ensure proper date format and use created: qualifier dates = criteria.get('dates', {}) if created_after := dates.get('created_after'): # GitHub's search API requires YYYY-MM-DD format if isinstance(created_after, datetime): created_after = created_after.strftime('%Y-%m-%d') query_parts.append(f"created:>{created_after}") if created_before := dates.get('created_before'): if isinstance(created_before, datetime): created_before = created_before.strftime('%Y-%m-%d') 
query_parts.append(f"created:<{created_before}") # Handle language if language := criteria.get('language'): if isinstance(language, list): query_parts.append(f"language:{language[0]}") # GitHub API limitation: one language at a time else: query_parts.append(f"language:{language}") # Handle stars if stars := criteria.get('stars'): if isinstance(stars, dict): if min_stars := stars.get('min'): query_parts.append(f"stars:>{min_stars}") if max_stars := stars.get('max'): query_parts.append(f"stars:<{max_stars}") else: query_parts.append(f"stars:>{stars}") # Handle forks if forks := criteria.get('forks'): if isinstance(forks, dict): if min_forks := forks.get('min'): query_parts.append(f"forks:>{min_forks}") if max_forks := forks.get('max'): query_parts.append(f"forks:<{max_forks}") else: query_parts.append(f"forks:>{forks}") # Handle size if size := criteria.get('size'): if isinstance(size, dict): if min_size := size.get('min'): query_parts.append(f"size:>{min_size}") if max_size := size.get('max'): query_parts.append(f"size:<{max_size}") else: query_parts.append(f"size:>{size}") # Handle license if license_type := criteria.get('license'): if isinstance(license_type, list): query_parts.append(f"license:{license_type[0]}") # GitHub API limitation: one license at a time else: query_parts.append(f"license:{license_type}") query = ' '.join(query_parts) if query_parts else "is:public" logging.info(f"Search query: {query}") return query def clone_repository(self, repo, output_dir: Path) -> bool: """Clone a repository using GitPython.""" repo_dir = output_dir / repo.full_name if repo_dir.exists(): logging.info(f"Repository directory already exists: {repo_dir}") return False try: # Create clone URL with token clone_url = f"https://{self.token}@github.com/{repo.full_name}.git" # Clone the repository git.Repo.clone_from(clone_url, str(repo_dir)) # Save repository metadata metadata = { 'name': repo.name, 'full_name': repo.full_name, 'description': repo.description, 'stars': 
repo.stargazers_count, 'forks': repo.forks_count, 'language': repo.language, 'license': repo.license.name if repo.license else None, 'created_at': repo.created_at.isoformat() if repo.created_at else None, 'updated_at': repo.updated_at.isoformat() if repo.updated_at else None, 'topics': repo.get_topics(), 'size': repo.size, 'clone_time': datetime.now().isoformat(), } with open(repo_dir / 'repo_metadata.yaml', 'w') as f: yaml.dump(metadata, f) logging.info(f"Successfully cloned: {repo.full_name}") return True except Exception as e: logging.error(f"Error cloning repository {repo.full_name}: {e}") return False def run(self): output_dir = Path(self.config.get('output_directory', 'downloaded_repos')) output_dir.mkdir(parents=True, exist_ok=True) # Initialize or load existing metadata file meta_file = output_dir / 'repositories_metadata.json' if meta_file.exists(): with open(meta_file, 'r') as f: all_metadata = json.load(f) else: all_metadata = { 'download_session': datetime.now().isoformat(), 'search_query': self.build_query(), 'repositories': {} } max_repos = self.config.get('max_repos', 5) skip_archived = self.config.get('skip_archived', True) skip_forks = self.config.get('skip_forks', True) min_python_percentage = self.config.get('min_python_percentage', 80) # Default to 80% if not specified # Get date filters from config dates = self.config.get('search_criteria', {}).get('dates', {}) created_after = dates.get('created_after') if isinstance(created_after, str): created_after = datetime.fromisoformat(created_after.replace('Z', '+00:00')) created_before = dates.get('created_before') if isinstance(created_before, str): created_before = datetime.fromisoformat(created_before.replace('Z', '+00:00')) query = self.build_query() logging.info(f"Starting repository search with query: {query}") try: repos = self.gh.search_repositories( query=query, sort=self.config.get('search_criteria', {}).get('sort', 'stars'), order=self.config.get('search_criteria', {}).get('order', 'desc') ) 
total_count = repos.totalCount logging.info(f"Found {total_count} repositories matching the search criteria") downloaded = 0 pbar = tqdm(total=max_repos, desc="Downloading repositories") for repo in repos: if downloaded >= max_repos: break if skip_archived and repo.archived: logging.info(f"Skipping archived repository: {repo.full_name}") continue if skip_forks and repo.fork: logging.info(f"Skipping forked repository: {repo.full_name}") continue # Check Python language percentage try: languages = repo.get_languages() total_bytes = sum(languages.values()) python_bytes = languages.get('Python', 0) if total_bytes > 0: python_percentage = (python_bytes / total_bytes) * 100 if python_percentage < min_python_percentage: logging.info(f"Skipping repository {repo.full_name}: Python code is only {python_percentage:.2f}% (required: {min_python_percentage}%)") continue logging.info(f"Repository {repo.full_name} has {python_percentage:.2f}% Python code") elif min_python_percentage > 0: logging.info(f"Skipping repository {repo.full_name}: No language data available") continue except Exception as e: logging.warning(f"Couldn't check language stats for {repo.full_name}: {e}") # Continue even if we can't check language stats, to avoid missing potentially valid repositories if self.clone_repository(repo, output_dir): # Add repository metadata to the collective metadata metadata = { 'name': repo.name, 'full_name': repo.full_name, 'description': repo.description, 'stars': repo.stargazers_count, 'forks': repo.forks_count, 'language': repo.language, 'license': repo.license.name if repo.license else None, 'created_at': repo.created_at.isoformat() if repo.created_at else None, 'updated_at': repo.updated_at.isoformat() if repo.updated_at else None, 'topics': repo.get_topics(), 'size': repo.size, 'clone_time': datetime.now().isoformat(), 'local_path': str(output_dir / repo.full_name) } all_metadata['repositories'][repo.full_name] = metadata # Update the metadata file after each successful 
download with open(meta_file, 'w') as f: json.dump(all_metadata, f, indent=2) downloaded += 1 pbar.update(1) # Respect GitHub API rate limits time.sleep(1) pbar.close() logging.info(f"Successfully downloaded {downloaded} repositories") logging.info(f"Metadata file created at: {meta_file}") except Exception as e: logging.error(f"Error during repository download process: {e}") raise if __name__ == "__main__": try: downloader = GitHubRepoDownloader("config/download_repo_config.yaml") downloader.run() except Exception as e: logging.error(f"Fatal error: {e}") raise ================================================ FILE: src/data/parse/repo_tree.py ================================================ #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates import os import argparse from pathlib import Path import json from typing import Dict, List, Optional class ProjectStructureGenerator: def __init__(self, ignore_patterns: List[str] = None): self.ignore_patterns = ignore_patterns or [ '.git', '__pycache__', '.pytest_cache', '.env', 'venv', 'node_modules', '.DS_Store', '*.pyc', '*.pyo', '*.pyd', '.Python', '*.so' ] def should_ignore(self, path: str) -> bool: """Check if the path should be ignored based on patterns.""" path_obj = Path(path) return any( path_obj.match(pattern) or any(parent.match(pattern) for parent in path_obj.parents) for pattern in self.ignore_patterns ) def generate_structure(self, root_path: str, max_depth: Optional[int] = None) -> Dict: """Generate a hierarchical structure of the project.""" root_path = os.path.abspath(root_path) root_name = os.path.basename(root_path) def explore_directory(current_path: str, current_depth: int = 0) -> Dict: if max_depth is not None and current_depth > max_depth: return {"type": "directory", "name": os.path.basename(current_path), "truncated": True} structure = { "type": "directory", "name": os.path.basename(current_path), "contents": [] } try: for item in sorted(os.listdir(current_path)): item_path = 
os.path.join(current_path, item) if self.should_ignore(item_path): continue if os.path.isfile(item_path): file_info = { "type": "file", "name": item, "extension": os.path.splitext(item)[1][1:] or "none" } structure["contents"].append(file_info) elif os.path.isdir(item_path): subdir = explore_directory(item_path, current_depth + 1) if subdir.get("contents") or not subdir.get("truncated"): structure["contents"].append(subdir) except PermissionError: structure["error"] = "Permission denied" return structure return explore_directory(root_path) def format_structure(self, structure: Dict, indent: int = 0) -> str: """Format the structure in a hierarchical text format.""" output = [] prefix = "│ " * (indent - 1) + "├── " if indent > 0 else "" if structure.get("truncated"): output.append(f"{prefix}{structure['name']} [...]") return "\n".join(output) output.append(f"{prefix}{structure['name']}/") if "contents" in structure: for i, item in enumerate(structure["contents"]): is_last = i == len(structure["contents"]) - 1 if item["type"] == "file": item_prefix = "│ " * indent + ("└── " if is_last else "├── ") output.append(f"{item_prefix}{item['name']}") else: output.append(self.format_structure(item, indent + 1)) return "\n".join(output) def main(): parser = argparse.ArgumentParser( description="Generate a project structure in LLM-friendly format" ) parser.add_argument( "path", nargs="?", default=".", help="Path to the project directory (default: current directory)" ) parser.add_argument( "--max-depth", type=int, help="Maximum depth to traverse (default: no limit)" ) parser.add_argument( "--output", choices=["text", "json"], default="text", help="Output format (default: text)" ) parser.add_argument( "--ignore", nargs="+", help="Additional patterns to ignore" ) args = parser.parse_args() generator = ProjectStructureGenerator() if args.ignore: generator.ignore_patterns.extend(args.ignore) structure = generator.generate_structure(args.path, args.max_depth) if args.output == "json": 
print(json.dumps(structure, indent=2)) else: print(generator.format_structure(structure)) if __name__ == "__main__": main() ================================================ FILE: src/dependency_analyzer/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Dependency analyzer module for building and processing import dependency graphs between Python code components. """ from .ast_parser import CodeComponent, DependencyParser from .topo_sort import topological_sort, resolve_cycles, build_graph_from_components, dependency_first_dfs __all__ = [ 'CodeComponent', 'DependencyParser', 'topological_sort', 'resolve_cycles', 'build_graph_from_components', 'dependency_first_dfs' ] ================================================ FILE: src/dependency_analyzer/ast_parser.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ AST-based Python code parser that extracts dependency information between code components. This module identifies imports and references between Python code components (functions, classes, methods) and builds a dependency graph for topological sorting. 
""" import ast import os import json import logging import builtins from dataclasses import dataclass, field from typing import Dict, List, Set, Tuple, Optional, Any, Union from pathlib import Path logger = logging.getLogger(__name__) # Built-in Python types and modules that should be excluded from dependencies BUILTIN_TYPES = {name for name in dir(builtins)} STANDARD_MODULES = { 'abc', 'argparse', 'array', 'asyncio', 'base64', 'collections', 'copy', 'csv', 'datetime', 'enum', 'functools', 'glob', 'io', 'itertools', 'json', 'logging', 'math', 'os', 'pathlib', 'random', 're', 'shutil', 'string', 'sys', 'time', 'typing', 'uuid', 'warnings', 'xml' } EXCLUDED_NAMES = {'self', 'cls'} @dataclass class CodeComponent: """ Represents a single code component (function, class, or method) in a Python codebase. Stores the component's identifier, AST node, dependencies, and other metadata. """ # Unique identifier for the component, format: module_path.ClassName.method_name id: str # AST node representing this component node: ast.AST # Type of component: 'class', 'function', or 'method' component_type: str # Full path to the file containing this component file_path: str # Relative path within the repo relative_path: str # Set of component IDs this component depends on depends_on: Set[str] = field(default_factory=set) # Original source code of the component source_code: Optional[str] = None # Line numbers in the file (1-indexed) start_line: int = 0 end_line: int = 0 # Whether the component already has a docstring has_docstring: bool = False # Content of the docstring if it exists, empty string otherwise docstring: str = "" def to_dict(self) -> Dict[str, Any]: """Convert this component to a dictionary representation for JSON serialization.""" return { 'id': self.id, 'component_type': self.component_type, 'file_path': self.file_path, 'relative_path': self.relative_path, 'depends_on': list(self.depends_on), 'start_line': self.start_line, 'end_line': self.end_line, 'has_docstring': 
self.has_docstring, 'docstring': self.docstring } @staticmethod def from_dict(data: Dict[str, Any]) -> 'CodeComponent': """Create a CodeComponent from a dictionary representation.""" component = CodeComponent( id=data['id'], node=None, # AST node is not serialized component_type=data['component_type'], file_path=data['file_path'], relative_path=data['relative_path'], depends_on=set(data.get('depends_on', [])), start_line=data.get('start_line', 0), end_line=data.get('end_line', 0), has_docstring=data.get('has_docstring', False), docstring=data.get('docstring', "") ) return component class ImportCollector(ast.NodeVisitor): """Collects import statements from Python code.""" def __init__(self): self.imports = set() self.from_imports = {} # module -> [names] def visit_Import(self, node: ast.Import): """Process 'import x' statements.""" for name in node.names: self.imports.add(name.name) self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom): """Process 'from x import y' statements.""" if node.module is not None: module = node.module if module not in self.from_imports: self.from_imports[module] = [] for name in node.names: if name.name != '*': self.from_imports[module].append(name.name) self.generic_visit(node) class MethodDependencyCollector(ast.NodeVisitor): """ Special dependency collector for methods that also tracks 'self.XXX' references as potential dependencies. 
""" def __init__(self, class_id: str, method_id: str, class_methods: Dict[str, str]): self.class_id = class_id self.method_id = method_id self.class_methods = class_methods # method_name -> full_method_id self.self_attr_refs = set() # Set of attributes accessed via self.XXX def visit_Attribute(self, node: ast.Attribute): """Process attribute access, specifically looking for self.XXX references.""" if (isinstance(node.value, ast.Name) and node.value.id == 'self' and isinstance(node.ctx, ast.Load)): # Found a self.XXX reference attr_name = node.attr self.self_attr_refs.add(attr_name) self.generic_visit(node) def get_method_dependencies(self) -> Set[str]: """ Get the set of methods that this method depends on based on self.XXX references. Returns: A set of method IDs that this method depends on """ dependencies = set() # Check if any self.attr references match method names for attr in self.self_attr_refs: if attr in self.class_methods: # This is a reference to another method in the class dependencies.add(self.class_methods[attr]) return dependencies class DependencyCollector(ast.NodeVisitor): """ Collects dependencies between code components by analyzing attribute access, function calls, and class references. 
""" def __init__(self, imports, from_imports, current_module, repo_modules): self.imports = imports self.from_imports = from_imports self.current_module = current_module self.repo_modules = repo_modules self.dependencies = set() self._current_class = None # Track local variables defined in the current context self.local_variables = set() def visit_ClassDef(self, node: ast.ClassDef): """Process class definitions.""" old_class = self._current_class self._current_class = node.name # Check for base classes dependencies for base in node.bases: if isinstance(base, ast.Name): # Simple name reference, could be an imported class self._add_dependency(base.id) elif isinstance(base, ast.Attribute): # Module.Class reference self._process_attribute(base) self.generic_visit(node) self._current_class = old_class def visit_Assign(self, node: ast.Assign): """Track local variable assignments.""" for target in node.targets: if isinstance(target, ast.Name): # Add to local variables self.local_variables.add(target.id) self.generic_visit(node) def visit_Call(self, node: ast.Call): """Process function calls.""" if isinstance(node.func, ast.Name): # Direct function call self._add_dependency(node.func.id) elif isinstance(node.func, ast.Attribute): # Method call or module.function call self._process_attribute(node.func) self.generic_visit(node) def visit_Name(self, node: ast.Name): """Process name references.""" if isinstance(node.ctx, ast.Load): self._add_dependency(node.id) self.generic_visit(node) def visit_Attribute(self, node: ast.Attribute): """Process attribute access.""" self._process_attribute(node) self.generic_visit(node) def _process_attribute(self, node: ast.Attribute): """Process an attribute node to extract potential dependencies.""" parts = [] current = node # Traverse the attribute chain (e.g., module.submodule.Class.method) while isinstance(current, ast.Attribute): parts.insert(0, current.attr) current = current.value if isinstance(current, ast.Name): parts.insert(0, 
current.id) # Skip if the first part is a local variable if parts[0] in self.local_variables: return # Skip if the first part is in our excluded names if parts[0] in EXCLUDED_NAMES: return # Check if the first part is an imported module if parts[0] in self.imports: module_path = parts[0] # Skip standard library modules if module_path in STANDARD_MODULES: return # If it's a repo module, add as dependency if module_path in self.repo_modules: if len(parts) > 1: # Example: module.Class or module.function self.dependencies.add(f"{module_path}.{parts[1]}") # Check from imports elif parts[0] in self.from_imports.keys(): # Skip standard library modules if parts[0] in STANDARD_MODULES: return # Check if the name is in the imported names if len(parts) > 1 and parts[1] in self.from_imports[parts[0]]: self.dependencies.add(f"{parts[0]}.{parts[1]}") def _add_dependency(self, name): """Add a potential dependency based on a name reference.""" # Skip built-in types if name in BUILTIN_TYPES: return # Skip excluded names if name in EXCLUDED_NAMES: return # Skip local variables if name in self.local_variables: return # Check if name is directly imported from a module for module, imported_names in self.from_imports.items(): # Skip standard library modules if module in STANDARD_MODULES: continue if name in imported_names and module in self.repo_modules: self.dependencies.add(f"{module}.{name}") return # Check if name refers to a component in the current module local_component_id = f"{self.current_module}.{name}" self.dependencies.add(local_component_id) def add_parent_to_nodes(tree: ast.AST) -> None: """ Add a 'parent' attribute to each node in the AST. Args: tree: The AST to process """ for node in ast.walk(tree): for child in ast.iter_child_nodes(node): child.parent = node class DependencyParser: """ Parses Python code to build a dependency graph between code components. 
""" def __init__(self, repo_path: str): self.repo_path = os.path.abspath(repo_path) self.components: Dict[str, CodeComponent] = {} self.dependency_graph: Dict[str, List[str]] = {} self.modules: Set[str] = set() def parse_repository(self): """ Parse all Python files in the repository to build the dependency graph. """ logger.info(f"Parsing repository at {self.repo_path}") # First pass: collect all modules and code components for root, _, files in os.walk(self.repo_path): for file in files: if not file.endswith(".py"): continue file_path = os.path.join(root, file) relative_path = os.path.relpath(file_path, self.repo_path) # Convert file path to module path module_path = self._file_to_module_path(relative_path) self.modules.add(module_path) # Parse the file to collect components self._parse_file(file_path, relative_path, module_path) # Second pass: resolve dependencies self._resolve_dependencies() # Third pass: add class dependencies on methods self._add_class_method_dependencies() logger.info(f"Found {len(self.components)} code components") return self.components def _file_to_module_path(self, file_path: str) -> str: """Convert a file path to a Python module path.""" # Remove .py extension and convert / to . 
path = file_path[:-3] if file_path.endswith(".py") else file_path return path.replace(os.path.sep, ".") def _parse_file(self, file_path: str, relative_path: str, module_path: str): """Parse a single Python file to collect code components.""" try: with open(file_path, "r", encoding="utf-8") as f: source = f.read() tree = ast.parse(source) # Add parent field to AST nodes for easier traversal add_parent_to_nodes(tree) # Collect imports import_collector = ImportCollector() import_collector.visit(tree) # Collect code components self._collect_components(tree, file_path, relative_path, module_path, source) except (SyntaxError, UnicodeDecodeError) as e: logger.warning(f"Error parsing {file_path}: {e}") def _collect_components(self, tree: ast.AST, file_path: str, relative_path: str, module_path: str, source: str): """Collect all code components (functions, classes, methods) from an AST.""" for node in ast.walk(tree): if isinstance(node, ast.ClassDef): # Class definition class_id = f"{module_path}.{node.name}" # Check if the class has a docstring has_docstring = ( len(node.body) > 0 and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str) ) # Extract docstring if it exists docstring = self._get_docstring(source, node) if has_docstring else "" component = CodeComponent( id=class_id, node=node, component_type="class", file_path=file_path, relative_path=relative_path, source_code=self._get_source_segment(source, node), start_line=node.lineno, end_line=getattr(node, "end_lineno", node.lineno), has_docstring=has_docstring, docstring=docstring ) self.components[class_id] = component # Collect methods within the class for item in node.body: if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): method_id = f"{class_id}.{item.name}" # Check if the method has a docstring method_has_docstring = ( len(item.body) > 0 and isinstance(item.body[0], ast.Expr) and isinstance(item.body[0].value, ast.Constant) 
and isinstance(item.body[0].value.value, str) ) # Extract docstring if it exists method_docstring = self._get_docstring(source, item) if method_has_docstring else "" method_component = CodeComponent( id=method_id, node=item, component_type="method", file_path=file_path, relative_path=relative_path, source_code=self._get_source_segment(source, item), start_line=item.lineno, end_line=getattr(item, "end_lineno", item.lineno), has_docstring=method_has_docstring, docstring=method_docstring ) self.components[method_id] = method_component elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): # Only collect top-level functions if hasattr(node, 'parent') and isinstance(node.parent, ast.Module): func_id = f"{module_path}.{node.name}" # Check if the function has a docstring has_docstring = ( len(node.body) > 0 and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str) ) # Extract docstring if it exists docstring = self._get_docstring(source, node) if has_docstring else "" component = CodeComponent( id=func_id, node=node, component_type="function", file_path=file_path, relative_path=relative_path, source_code=self._get_source_segment(source, node), start_line=node.lineno, end_line=getattr(node, "end_lineno", node.lineno), has_docstring=has_docstring, docstring=docstring ) self.components[func_id] = component def _resolve_dependencies(self): """ Second pass to resolve dependencies between components. 
""" for component_id, component in self.components.items(): file_path = component.file_path try: with open(file_path, "r", encoding="utf-8") as f: source = f.read() # Parse file to get imports tree = ast.parse(source) # Add parent field to AST nodes for easier traversal add_parent_to_nodes(tree) # Collect imports import_collector = ImportCollector() import_collector.visit(tree) # Find the component node in the tree component_node = None module_path = self._file_to_module_path(component.relative_path) if component.component_type == "function": # Find top-level function for node in ast.iter_child_nodes(tree): if (isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == component.id.split(".")[-1]): component_node = node break elif component.component_type == "class": # Find class for node in ast.iter_child_nodes(tree): if isinstance(node, ast.ClassDef) and node.name == component.id.split(".")[-1]: component_node = node break elif component.component_type == "method": # Find method inside class class_name, method_name = component.id.split(".")[-2:] class_node = None for node in ast.iter_child_nodes(tree): if isinstance(node, ast.ClassDef) and node.name == class_name: class_node = node for item in node.body: if (isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name): component_node = item break break if component_node: # Collect dependencies for this specific component dependency_collector = DependencyCollector( import_collector.imports, import_collector.from_imports, module_path, self.modules ) # For functions and methods, collect variables defined in the function if isinstance(component_node, (ast.FunctionDef, ast.AsyncFunctionDef)): # Add function parameters to local variables for arg in component_node.args.args: dependency_collector.local_variables.add(arg.arg) dependency_collector.visit(component_node) # Add dependencies to the component component.depends_on.update(dependency_collector.dependencies) # Filter out 
non-existent dependencies component.depends_on = { dep for dep in component.depends_on if dep in self.components or dep.split(".", 1)[0] in self.modules } except (SyntaxError, UnicodeDecodeError) as e: logger.warning(f"Error analyzing dependencies in {file_path}: {e}") def _add_class_method_dependencies(self): """ Third pass to make classes dependent on their methods (except __init__). """ # Group components by class class_methods = {} # Collect all methods for each class for component_id, component in self.components.items(): if component.component_type == "method": parts = component_id.split(".") if len(parts) >= 2: method_name = parts[-1] class_id = ".".join(parts[:-1]) if class_id not in class_methods: class_methods[class_id] = [] # Don't include __init__ methods as dependencies of the class if method_name != "__init__": class_methods[class_id].append(component_id) # Add method dependencies to their classes for class_id, method_ids in class_methods.items(): if class_id in self.components: class_component = self.components[class_id] for method_id in method_ids: class_component.depends_on.add(method_id) def _get_source_segment(self, source: str, node: ast.AST) -> str: """Get source code segment for an AST node.""" try: if hasattr(ast, "get_source_segment"): segment = ast.get_source_segment(source, node) if segment is not None: return segment # Fallback to manual extraction lines = source.split("\n") start_line = node.lineno - 1 end_line = getattr(node, "end_lineno", node.lineno) - 1 return "\n".join(lines[start_line:end_line + 1]) except Exception as e: logger.warning(f"Error getting source segment: {e}") return "" def _get_docstring(self, source: str, node: ast.AST) -> str: """Get the docstring for a given AST node.""" try: if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): for item in node.body: if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant): if isinstance(item.value.value, str): return item.value.value elif 
isinstance(node, ast.ClassDef): for item in node.body: if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant): if isinstance(item.value.value, str): return item.value.value return "" except Exception as e: logger.warning(f"Error getting docstring: {e}") return "" def save_dependency_graph(self, output_path: str): """Save the dependency graph to a JSON file.""" # Convert to serializable format serializable_components = { comp_id: component.to_dict() for comp_id, component in self.components.items() } # Create directories if they don't exist os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: json.dump(serializable_components, f, indent=2) logger.info(f"Saved dependency graph to {output_path}") def load_dependency_graph(self, input_path: str): """Load the dependency graph from a JSON file.""" with open(input_path, "r", encoding="utf-8") as f: serialized_components = json.load(f) # Convert back to CodeComponent objects self.components = { comp_id: CodeComponent.from_dict(comp_data) for comp_id, comp_data in serialized_components.items() } logger.info(f"Loaded {len(self.components)} components from {input_path}") return self.components ================================================ FILE: src/dependency_analyzer/topo_sort.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Topological sorting utilities for dependency graphs with cycle handling. This module provides functions to perform topological sorting on a dependency graph, including detection and resolution of dependency cycles. """ import logging from typing import Dict, List, Set, Tuple, Any, Optional from collections import defaultdict, deque logger = logging.getLogger(__name__) def detect_cycles(graph: Dict[str, Set[str]]) -> List[List[str]]: """ Detect cycles in a dependency graph using Tarjan's algorithm to find strongly connected components. 
Args: graph: A dependency graph represented as adjacency lists (node -> set of dependencies) Returns: A list of lists, where each inner list contains the nodes in a cycle """ # Implementation of Tarjan's algorithm index_counter = [0] index = {} # node -> index lowlink = {} # node -> lowlink value onstack = set() # nodes currently on the stack stack = [] # stack of nodes result = [] # list of cycles (strongly connected components) def strongconnect(node): # Set the depth index for node index[node] = index_counter[0] lowlink[node] = index_counter[0] index_counter[0] += 1 stack.append(node) onstack.add(node) # Consider successors for successor in graph.get(node, set()): if successor not in index: # Successor has not yet been visited; recurse on it strongconnect(successor) lowlink[node] = min(lowlink[node], lowlink[successor]) elif successor in onstack: # Successor is on the stack and hence in the current SCC lowlink[node] = min(lowlink[node], index[successor]) # If node is a root node, pop the stack and generate an SCC if lowlink[node] == index[node]: # Start a new strongly connected component scc = [] while True: successor = stack.pop() onstack.remove(successor) scc.append(successor) if successor == node: break # Only include SCCs with more than one node (actual cycles) if len(scc) > 1: result.append(scc) # Visit each node for node in graph: if node not in index: strongconnect(node) return result def resolve_cycles(graph: Dict[str, Set[str]]) -> Dict[str, Set[str]]: """ Resolve cycles in a dependency graph by identifying strongly connected components and breaking cycles. 
    Args:
        graph: A dependency graph represented as adjacency lists
              (node -> set of dependencies)

    Returns:
        A new acyclic graph with the same nodes but with cycles broken
    """
    # Detect cycles (SCCs)
    cycles = detect_cycles(graph)

    if not cycles:
        logger.info("No cycles detected in the dependency graph")
        return graph

    logger.info(f"Detected {len(cycles)} cycles in the dependency graph")

    # Create a copy of the graph to modify
    new_graph = {node: deps.copy() for node, deps in graph.items()}

    # Removing a single edge may leave other cycles inside the same SCC, so
    # re-run detection until no SCC with more than one node remains. Every
    # pass removes at least one edge per SCC, so the loop terminates.
    while cycles:
        for i, cycle in enumerate(cycles):
            logger.info(f"Cycle {i+1}: {' -> '.join(cycle)}")
            members = set(cycle)

            # Strategy: break the cycle by removing one dependency that stays
            # inside the SCC. The choice is arbitrary (first member with an
            # internal edge, smallest target id). A heuristic could instead
            # prefer edges between different modules over edges within the
            # same module.
            for current in cycle:
                internal = sorted(
                    dep for dep in new_graph.get(current, set()) if dep in members
                )
                if internal:
                    next_node = internal[0]
                    logger.info(f"Breaking cycle by removing dependency: {current} -> {next_node}")
                    new_graph[current].remove(next_node)
                    break

        cycles = detect_cycles(new_graph)

    return new_graph


def topological_sort(graph: Dict[str, Set[str]]) -> List[str]:
    """
    Perform a topological sort on a dependency graph.
    Args:
        graph: A dependency graph represented as adjacency lists
              (node -> set of dependencies)

    Returns:
        A list of nodes in topological order (dependencies first)
    """
    # First, check for and resolve cycles
    acyclic_graph = resolve_cycles(graph)

    # Initialize in-degree counter for all nodes
    in_degree = {node: 0 for node in acyclic_graph}

    # Count in-degrees: an edge node -> dep is an incoming edge of dep
    for node, dependencies in acyclic_graph.items():
        for dep in dependencies:
            if dep in in_degree:
                in_degree[dep] += 1

    # Queue of nodes that nothing depends on (in-degree of 0)
    queue = deque([node for node, degree in in_degree.items() if degree == 0])

    # Result list to store the topological order
    result = []

    # Process nodes in topological order (dependents before dependencies)
    while queue:
        node = queue.popleft()
        result.append(node)

        # Reduce the in-degree of each dependency of the current node
        for dep in acyclic_graph[node]:
            if dep in in_degree:
                in_degree[dep] -= 1
                if in_degree[dep] == 0:
                    queue.append(dep)

    # Check if the sort was successful (all nodes included)
    if len(result) != len(acyclic_graph):
        logger.warning("Topological sort failed: graph has cycles that weren't resolved")
        # Return all nodes in some order to avoid breaking the process
        return list(acyclic_graph.keys())

    # Reverse the result to get dependencies first
    return result[::-1]


def dependency_first_dfs(graph: Dict[str, Set[str]]) -> List[str]:
    """
    Perform a depth-first traversal of the dependency graph, starting from root
    nodes that nothing else depends on.
    The graph uses natural dependency direction:
    - If A depends on B, the graph has an edge A → B
    - This means an edge from X to Y represents "X depends on Y"
    - Root nodes (nodes with no incoming edges, i.e. nodes that nothing else
      depends on) are the starting points of the traversal; each node is
      added to the result only after all of its dependencies

    Args:
        graph: A dependency graph with natural direction (A→B if A depends on B)

    Returns:
        A list of nodes in an order where dependencies come before their dependents
    """
    # First, resolve cycles to ensure we have a DAG
    acyclic_graph = resolve_cycles(graph)

    # Find root nodes (nodes that no other node depends on)
    root_nodes = []

    # Mark every node that appears as a dependency of some other node
    has_incoming_edge = {node: False for node in acyclic_graph}
    for node, deps in acyclic_graph.items():
        for dep in deps:
            has_incoming_edge[dep] = True

    # Nodes with no incoming edges are root nodes
    for node in acyclic_graph:
        if not has_incoming_edge[node]:
            root_nodes.append(node)

    if not root_nodes:
        logger.warning("No root nodes found in the graph, using arbitrary starting point")
        root_nodes = list(acyclic_graph.keys())[:1]  # Use the first node as starting point

    # Track visited nodes
    visited = set()
    result = []

    # DFS function that processes dependencies first
    def dfs(node):
        if node in visited:
            return
        visited.add(node)

        # Visit all dependencies first
        for dep in sorted(acyclic_graph.get(node, set())):
            dfs(dep)

        # Add this node to the result after all its dependencies
        result.append(node)

    # Start DFS from each root node
    for root in sorted(root_nodes):
        dfs(root)

    # Check if all nodes were visited
    if len(result) != len(acyclic_graph):
        # Some nodes weren't visited - try to visit remaining nodes
        for node in sorted(acyclic_graph.keys()):
            if node not in visited:
                dfs(node)

    return result


def build_graph_from_components(components: Dict[str, Any]) -> Dict[str, Set[str]]:
    """
    Build a dependency graph from a collection of code components.
The graph uses the natural dependency direction: - If A depends on B, we create an edge A → B - This means an edge from node X to node Y represents "X depends on Y" - Root nodes (nodes with no dependencies) are components that don't depend on anything Args: components: A dictionary of code components, where each component has a 'depends_on' attribute Returns: A dependency graph with natural dependency direction """ graph = {} for comp_id, component in components.items(): # Initialize the node's adjacency list if comp_id not in graph: graph[comp_id] = set() # Add dependencies for dep_id in component.depends_on: # Only include dependencies that are actual components in our repository if dep_id in components: graph[comp_id].add(dep_id) return graph ================================================ FILE: src/evaluate_helpfulness.py ================================================ #!/usr/bin/env python # Copyright (c) Meta Platforms, Inc. and affiliates """ Script to evaluate the helpfulness of docstrings generated by different systems. 
Usage: conda activate docstringgen python src/evaluate_helpfulness.py """ import os import yaml import argparse import sys from pathlib import Path # Add the src directory to the path so we can import modules src_dir = Path(__file__).parent.parent sys.path.insert(0, str(src_dir)) from src.evaluator.helpfulness_evaluator import DocstringHelpfulnessEvaluator def main(): parser = argparse.ArgumentParser(description="Evaluate docstring helpfulness") parser.add_argument("--data-path", type=str, default="experiments/eval/results/completeness_evaluation_cleaned.json", help="Path to the completeness evaluation data") parser.add_argument("--output-dir", type=str, default="experiments/eval/results/helpfulness", help="Directory to store evaluation results") parser.add_argument("--n-samples", type=int, default=50, help="Number of components to sample") parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") parser.add_argument("--model", type=str, default=None, help="LLM model to use (defaults to model in config)") args = parser.parse_args() # Create output directory if it doesn't exist os.makedirs(args.output_dir, exist_ok=True) # Get configuration config_path = "config/agent_config.yaml" with open(config_path, 'r') as f: config = yaml.safe_load(f) # Get API key and model from config api_key = config["llm"]["api_key"] model = args.model or config["llm"]["model"] print(f"Using model: {model}") print(f"Sampling {args.n_samples} components with seed {args.seed}") # Initialize evaluator evaluator = DocstringHelpfulnessEvaluator( data_path=args.data_path, output_dir=args.output_dir, api_key=api_key, model=model ) # Run evaluation results = evaluator.run_evaluation( n_samples=args.n_samples, seed=args.seed ) # Print summary print("\n=== Evaluation Complete ===") print(f"Results saved to {args.output_dir}") print(f"Total evaluations: {len(results['results'])}") # Calculate average score scores = [r["score"] for r in results["results"]] avg_score = 
sum(scores) / len(scores) if scores else 0 print(f"Overall average score: {avg_score:.2f}") # Calculate average by system systems = evaluator.SYSTEMS for system in systems: system_scores = [r["score"] for r in results["results"] if r["system"] == system] if system_scores: avg = sum(system_scores) / len(system_scores) print(f"{system}: {avg:.2f} (n={len(system_scores)})") if __name__ == "__main__": main() ================================================ FILE: src/evaluator/README.md ================================================ # Docstring Quality Evaluator provides a robust framework for evaluating the quality of Python docstrings. It uses static analysis through the Abstract Syntax Tree (AST) to examine docstrings in Python code and assess their completeness based on established documentation standards. ## Architecture Overview The project follows a hierarchical design with clear separation of concerns: ### Base Evaluator The foundation of the evaluation system is the `BaseEvaluator` abstract class. This class establishes the core interface that all evaluators must implement: ```python class BaseEvaluator(ABC): def __init__(self, name: str, description: str): self._score: float = 0.0 self._name = name self._description = description ``` Every evaluator derives from this base class, ensuring consistent scoring behavior and interface across the system. The base evaluator enforces score validation (must be between 0 and 1) and provides the abstract `evaluate` method that all concrete evaluators must implement. ### Completeness Evaluation The completeness evaluation system is structured in three layers: 1. `CompletenessEvaluator`: The base class for completeness evaluation 2. `ClassCompletenessEvaluator`: Specializes in evaluating class docstrings 3. `FunctionCompletenessEvaluator`: Specializes in evaluating function/method docstrings #### Class Docstring Evaluation The `ClassCompletenessEvaluator` examines five essential elements of class documentation: 1.
**Summary** (required) - A one-line description at the start of the docstring - Must be the first non-empty line - Should provide a quick overview of the class's purpose 2. **Description** (required) - Detailed explanation following the summary - Multiple lines describing the class's functionality - Appears before any special sections (Attributes, Examples, etc.) 3. **Attributes** (required if class has attributes) - Documentation of class attributes - Must start with "Attributes:" section - Lists each attribute with type information and description - Required if class has class variables, instance variables in __init__, or enum values 4. **Parameters** (required if class has __init__ parameters) - Documentation of constructor parameters - Must start with "Parameters:" section - Lists each parameter with type information and description - Required if __init__ has parameters beyond self 5. **Examples** (required for public classes) - Usage examples showing how to use the class - Must start with "Example:" or "Examples:" section - Should include executable code snippets - Only required for classes not starting with underscore (_) Each element is evaluated independently through dedicated methods: ```python @staticmethod def evaluate_summary(docstring: str) -> float: """Evaluates if a proper one-liner summary exists.""" @staticmethod def evaluate_description(docstring: str) -> float: """Evaluates if a proper description section exists.""" @staticmethod def evaluate_attributes(docstring: str) -> float: """Evaluates if attribute documentation exists.""" @staticmethod def evaluate_examples(docstring: str) -> float: """Evaluates if usage examples exist.""" ``` #### Function Docstring Evaluation The `FunctionCompletenessEvaluator` examines up to six elements, with required elements determined dynamically based on the function's characteristics: 1. **Summary** (required for all functions) - One-line description at the start - Concise explanation of function's purpose 2. 
**Description** (required for all functions) - Detailed explanation of functionality - Implementation details and usage notes 3. **Arguments** (required if function has parameters) - Documentation for each parameter - Must start with "Args:" or "Arguments:" - Includes type information and description 4. **Returns** (required if function has return statement) - Documentation of return value - Must start with "Returns:" - Includes type information and description 5. **Raises** (required if function has raise statements) - Documentation of exceptions - Must start with "Raises:" - Lists each exception type and trigger condition 6. **Examples** (required for public functions) - Usage examples - Must start with "Example:" or "Examples:" - Not required for private methods (starting with underscore) The evaluator automatically determines required sections through AST analysis: ```python def _get_required_sections(self, node: ast.FunctionDef) -> List[str]: """Determines which sections are required based on function characteristics.""" ``` ### Scoring System Both evaluators use a normalized scoring system: 1. Each required element contributes equally to the final score 2. Scores are always between 0.0 and 1.0 3. Individual element scores are stored in `element_scores` dictionary 4. Final score is the average of all required element scores For example, if a class docstring has all elements except examples: ```python element_scores = { 'summary': 1.0, 'description': 1.0, 'attributes': 1.0, 'examples': 0.0 } final_score = 0.75 # (1.0 + 1.0 + 1.0 + 0.0) / 4 ``` ## Usage Examples ### Evaluating a Class Docstring ```python from docstring_evaluator import ClassCompletenessEvaluator import ast # Create evaluator evaluator = ClassCompletenessEvaluator() # Define class with docstring class_code = ''' class MyClass: """ A demonstration class. This class shows proper docstring formatting. Attributes: name (str): The class name. 
Example: >>> obj = MyClass() """ pass ''' # Parse and evaluate node = ast.parse(class_code).body[0] score = evaluator.evaluate(node) print(f"Overall score: {score}") print("Element scores:", evaluator.element_scores) ``` ### Evaluating a Function Docstring ```python from docstring_evaluator import FunctionCompletenessEvaluator import ast # Create evaluator evaluator = FunctionCompletenessEvaluator() # Define function with docstring function_code = ''' def process_data(data: List[str]) -> Dict[str, int]: """ Process a list of strings and return word frequencies. This function takes a list of strings and returns a dictionary containing the frequency of each word. Args: data (List[str]): List of strings to process. Returns: Dict[str, int]: Dictionary mapping words to their frequencies. Raises: ValueError: If input list is empty. Example: >>> process_data(["hello", "world", "hello"]) {'hello': 2, 'world': 1} """ if not data: raise ValueError("Empty input list") return Counter(data) ''' # Parse and evaluate node = ast.parse(function_code).body[0] score = evaluator.evaluate(node) print(f"Overall score: {score}") print("Element scores:", evaluator.element_scores) ``` ### Exception Handling Guidelines The evaluator checks for uncaught exceptions in two ways: 1. Direct raise statements: - Walks through all raise statements in the function - Checks if each raise is inside a try-except block - If a raise is not caught by any except handler, it's considered to bubble up 2. Function calls: - Walks through all function call nodes - Assumes any uncaught function call could potentially raise - Checks if the call is inside a try-except block - If not caught, considers it as a potential exception source The evaluator uses AST traversal to track parent-child relationships and determine if exceptions are properly handled within the function scope. ### Function Analysis Limitations - Nested functions (functions defined inside other functions) are not evaluated by the tool. 
These inner functions are skipped during analysis. ### Other Notes - `__init__` is not evaluated on its own. (It is considered during the evaluation of the class.) ## Best Practices for Documentation To achieve high scores, follow these guidelines: 1. Always start with a clear, one-line summary 2. Provide detailed description in subsequent paragraphs 3. Document all attributes for classes 4. Include practical usage examples 5. For functions: - Document all parameters under "Args:" - Specify return type and value under "Returns:" - List all possible exceptions under "Raises:" - Provide examples for public functions ## Development ### Adding New Evaluators To create new evaluators: 1. Inherit from `BaseEvaluator` 2. Implement the `evaluate` method 3. Define specific evaluation criteria 4. Add unit tests Example: ```python class StyleEvaluator(BaseEvaluator): """Evaluates docstring style consistency.""" def evaluate(self, node: ast.AST) -> float: # Implementation here pass ``` # Limitations - Each element must start with one of the recognized labels (see the evaluator definitions). Otherwise, the evaluator will not be able to detect the element. - The exceptions are summary and description, which are detected by position: the summary is the first non-empty line, and the description is the block that follows the first empty line. - Elements must be separated by at least one empty line. ================================================ FILE: src/evaluator/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from .base import BaseEvaluator from .completeness import ( # Remove 'evaluators.' from the path CompletenessEvaluator, ClassCompletenessEvaluator, FunctionCompletenessEvaluator ) __all__ = [ 'BaseEvaluator', 'CompletenessEvaluator', 'ClassCompletenessEvaluator', 'FunctionCompletenessEvaluator' ] ================================================ FILE: src/evaluator/base.py ================================================ # Copyright (c) Meta Platforms, Inc.
and affiliates from abc import ABC, abstractmethod import ast from typing import Optional, Dict, Any class BaseEvaluator(ABC): """ Base class for all docstring evaluators. This class provides the foundation for implementing various docstring quality evaluators. Each evaluator should focus on a specific aspect of docstring quality such as completeness, helpfulness, or redundancy. Attributes: score (float): The evaluation score, ranging from 0 to 1. name (str): The name of the evaluator. description (str): A description of what this evaluator checks. """ def __init__(self, name: str, description: str): self._score: float = 0.0 self._name = name self._description = description @property def score(self) -> float: """ Returns the current evaluation score. Returns: float: A score between 0 and 1 indicating the quality measure. """ return self._score @score.setter def score(self, value: float) -> None: """ Sets the evaluation score. Args: value (float): The score to set, must be between 0 and 1. Raises: ValueError: If the score is not between 0 and 1. """ if not 0 <= value <= 1: raise ValueError("Score must be between 0 and 1") self._score = value @abstractmethod def evaluate(self, node: ast.AST) -> float: """ Evaluates the quality of a docstring based on specific criteria. Args: node (ast.AST): The AST node containing the docstring to evaluate. Returns: float: The evaluation score between 0 and 1. """ pass ================================================ FILE: src/evaluator/completeness.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import ast import re from typing import Dict, List, Optional from evaluator.base import BaseEvaluator class CompletenessEvaluator(BaseEvaluator): """ Base class for evaluating docstring completeness. This evaluator examines whether a docstring contains all necessary elements according to common documentation standards. Attributes: score (float): The completeness score from 0 to 1. 
element_scores (Dict[str, bool]): Individual scores for each docstring element. element_required (Dict[str, bool]): Whether each element is required. weights (List[float]): Weights for each element in scoring. """ def __init__(self, name: str, description: str): super().__init__(name=name, description=description) self.element_scores: Dict[str, bool] = {} self.element_required: Dict[str, bool] = {} self.weights: List[float] = [] def evaluate(self, node: ast.AST) -> float: """ Evaluates the completeness of a docstring. This method determines which specific evaluator to use based on the AST node type and delegates the evaluation accordingly. Args: node (ast.AST): The AST node to evaluate. Returns: float: The completeness score between 0 and 1. Raises: ValueError: If the node type is not supported. """ if isinstance(node, ast.ClassDef): evaluator = ClassCompletenessEvaluator() self.score = evaluator.evaluate(node) elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): evaluator = FunctionCompletenessEvaluator() self.score = evaluator.evaluate(node) else: raise ValueError(f"Unsupported node type: {type(node)}") return self.score class ClassCompletenessEvaluator(CompletenessEvaluator): """ Evaluator for class docstring completeness. This evaluator checks for the presence of required elements in class docstrings including summary, description, attributes, parameters, and examples. Attributes: score (float): The overall completeness score from 0 to 1. element_scores (Dict[str, bool]): Individual scores for each docstring element. element_required (Dict[str, bool]): Whether each element is required. weights (List[float]): Weights for each element in scoring. required_sections (List[str]): List of required sections for the current class. 
""" # Valid section labels (case-insensitive) ATTRIBUTE_LABELS = { "attributes:", "members:", "member variables:", "instance variables:", "properties:", } EXAMPLE_LABELS = { "example:", "examples:", "usage:", "usage example:", "usage examples:", } PARAMETER_LABELS = {"parameters:", "params:", "args:", "arguments:"} def __init__(self): super().__init__( name="Class Completeness Evaluator", description="Evaluates the completeness of class docstrings", ) # Initialize element scores and requirements elements = ["summary", "description", "parameters", "attributes", "examples"] self.element_scores = {el: False for el in elements} self.element_required = { el: False for el in elements } # Will be set during evaluation self.weights = [0.2] * len(elements) # Equal weights by default # Verify dictionaries have same keys in same order assert list(self.element_scores.keys()) == list(self.element_required.keys()) assert len(self.element_scores) == len(self.weights) self.required_sections: List[str] = [] @staticmethod def evaluate_summary(docstring: str) -> bool: """ Evaluates if the docstring has a proper one-liner summary. Args: docstring (str): The docstring to evaluate. Returns: bool: True if summary exists, False otherwise. """ lines = docstring.strip().split("\n") return bool(lines and lines[0].strip()) @staticmethod def evaluate_description(docstring: str) -> bool: """ Evaluates if the docstring has a proper description. Args: docstring (str): The docstring to evaluate. Returns: bool: True if description exists, False otherwise. 
""" # Split docstring into chunks by empty lines chunks = [] current_chunk = [] for line in docstring.strip().split("\n"): if not line.strip(): if current_chunk: chunks.append(current_chunk) current_chunk = [] else: current_chunk.append(line.strip()) if current_chunk: chunks.append(current_chunk) # Need at least 2 chunks (summary and description) if len(chunks) < 2: return False # Check if second chunk starts with any other section label description_chunk = chunks[1] if not description_chunk: return False first_line = description_chunk[0].lower() for labels in [ ClassCompletenessEvaluator.ATTRIBUTE_LABELS, ClassCompletenessEvaluator.PARAMETER_LABELS, ClassCompletenessEvaluator.EXAMPLE_LABELS, ]: if any(first_line.startswith(label.lower()) for label in labels): return False return True @staticmethod def evaluate_attributes(docstring: str) -> bool: """ Evaluates if the docstring has attribute documentation. Args: docstring (str): The docstring to evaluate. Returns: bool: True if attributes section exists, False otherwise. """ # Check if any attribute label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in ClassCompletenessEvaluator.ATTRIBUTE_LABELS ) @staticmethod def evaluate_parameters(docstring: str) -> bool: """ Evaluates if the docstring has constructor parameter documentation. Args: docstring (str): The docstring to evaluate. Returns: bool: True if parameters section exists, False otherwise. """ # Check if any parameter label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in ClassCompletenessEvaluator.PARAMETER_LABELS ) @staticmethod def evaluate_examples(docstring: str) -> bool: """ Evaluates if the docstring has usage examples. Args: docstring (str): The docstring to evaluate. Returns: bool: True if examples section exists, False otherwise. 
""" # Check if any example label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in ClassCompletenessEvaluator.EXAMPLE_LABELS ) def _has_attributes(self, node: ast.ClassDef) -> bool: """ Checks if the class has attributes by looking for class variables, instance variables in __init__, or enum values. Args: node (ast.ClassDef): The class definition node. Returns: bool: True if class has attributes, False otherwise. """ # Check for class variables has_class_vars = any( isinstance(item, (ast.AnnAssign, ast.Assign)) for item in node.body ) # Check for instance variables in __init__ has_instance_vars = False for item in node.body: if isinstance(item, ast.FunctionDef) and item.name == "__init__": has_instance_vars = any( isinstance(stmt, ast.Assign) and isinstance(stmt.targets[0], ast.Attribute) and isinstance(stmt.targets[0].value, ast.Name) and stmt.targets[0].value.id == "self" for stmt in ast.walk(item) ) break # Check if it's an Enum is_enum = ( hasattr(node, "bases") and node.bases and any( isinstance(base, ast.Name) and base.id == "Enum" for base in node.bases ) ) return has_class_vars or has_instance_vars or is_enum def _get_required_sections(self, node: ast.ClassDef) -> List[str]: """ Determines which sections are required for the class docstring. Args: node (ast.ClassDef): The class definition node. Returns: List[str]: List of required section names. """ required = ["summary", "description"] if self._has_attributes(node): required.append("attributes") # Check if __init__ has parameters beyond self if self._has_init_parameters(node): required.append("parameters") # Examples are required for public classes if not node.name.startswith("_"): required.append("examples") return required def _has_init_parameters(self, node: ast.ClassDef) -> bool: """ Checks if the class __init__ method has parameters beyond self. Args: node (ast.ClassDef): The class definition node. Returns: bool: True if __init__ has parameters beyond self. 
""" for item in node.body: if isinstance(item, ast.FunctionDef) and item.name == "__init__": args = [arg for arg in item.args.args if arg.arg != "self"] return bool(args or item.args.kwonlyargs) return False def evaluate(self, node: ast.ClassDef) -> float: """ Evaluates the completeness of a class docstring. Checks for: 1. One-liner summary 2. Description 3. Attributes documentation 4. Parameters documentation (if __init__ has parameters beyond self) 5. Usage examples Args: node (ast.ClassDef): The class definition node to evaluate. Returns: float: The completeness score between 0 and 1. """ # Get required sections for this class first self.required_sections = self._get_required_sections(node) # Reset scores and update requirements self.element_scores = {key: False for key in self.element_scores} self.element_required = { key: key in self.required_sections for key in self.element_scores } docstring = ast.get_docstring(node) if not docstring: self.score = 0.0 return self.score # Evaluate each element if "summary" in self.required_sections: self.element_scores["summary"] = self.evaluate_summary(docstring) if "description" in self.required_sections: self.element_scores["description"] = self.evaluate_description(docstring) if "parameters" in self.required_sections: self.element_scores["parameters"] = self.evaluate_parameters(docstring) if "attributes" in self.required_sections: self.element_scores["attributes"] = self.evaluate_attributes(docstring) if "examples" in self.required_sections: self.element_scores["examples"] = self.evaluate_examples(docstring) # Calculate weighted score considering requirements total_weight = 0.0 weighted_score = 0.0 for (key, score), weight, required in zip( self.element_scores.items(), self.weights, self.element_required.values() ): if required: total_weight += weight if score: weighted_score += weight self.score = weighted_score / total_weight if total_weight > 0 else 0.0 return self.score def evaluate_using_string(self, docstring: str, 
element_required: Dict) -> Dict: """ Evaluates a raw docstring string against the requested elements. Args: docstring (str): The docstring to evaluate. element_required (Dict): Keys name the elements to check. Returns: Dict: Maps each requested element to True if it was found, False otherwise. """ # Start with every requested element marked as missing element_scores = {key: False for key in element_required} if not docstring: return element_scores # Evaluate each element for key in element_required: if key == "summary": element_scores[key] = self.evaluate_summary(docstring) elif key == "description": element_scores[key] = self.evaluate_description(docstring) elif key == "parameters": element_scores[key] = self.evaluate_parameters(docstring) elif key == "attributes": element_scores[key] = self.evaluate_attributes(docstring) elif key == "examples": element_scores[key] = self.evaluate_examples(docstring) return element_scores class FunctionCompletenessEvaluator(CompletenessEvaluator): """ Evaluator for function/method docstring completeness. This evaluator checks for the presence of required elements in function docstrings including summary, description, arguments, returns, raises, and examples. Attributes: score (float): The overall completeness score from 0 to 1. element_scores (Dict[str, bool]): Individual scores for each docstring element. element_required (Dict[str, bool]): Whether each element is required. weights (List[float]): Weights for each element in scoring. required_sections (List[str]): List of required sections for the current function.
""" # Valid section labels (case-insensitive) ARGS_LABELS = {"args:", "arguments:", "parameters:", "params:"} RETURNS_LABELS = { "returns:", "return:", "return value:", "return type:", "yields:", "yield:", } RAISES_LABELS = {"raises:", "exceptions:", "throws:"} EXAMPLE_LABELS = { "example:", "examples:", "usage:", "usage example:", "usage examples:", } def __init__(self): super().__init__( name="Function Completeness Evaluator", description="Evaluates the completeness of function docstrings", ) # Initialize element scores and requirements elements = ["summary", "description", "args", "returns", "raises", "examples"] self.element_scores = {el: False for el in elements} self.element_required = { el: False for el in elements } # Will be set during evaluation self.weights = [1 / len(elements)] * len(elements) # Equal weights by default # Verify dictionaries have same keys in same order assert list(self.element_scores.keys()) == list(self.element_required.keys()) assert len(self.element_scores) == len(self.weights) self.required_sections: List[str] = [] @staticmethod def evaluate_summary(docstring: str) -> bool: """ Evaluates if the docstring has a proper one-liner summary. Args: docstring (str): The docstring to evaluate. Returns: bool: True if summary exists, False otherwise. """ lines = docstring.strip().split("\n") return bool(lines and lines[0].strip()) @staticmethod def evaluate_description(docstring: str) -> bool: """ Evaluates if the docstring has a proper description. Args: docstring (str): The docstring to evaluate. Returns: bool: True if description exists, False otherwise. 
""" # Split docstring into chunks by empty lines chunks = [] current_chunk = [] for line in docstring.strip().split("\n"): if not line.strip(): if current_chunk: chunks.append(current_chunk) current_chunk = [] else: current_chunk.append(line.strip()) if current_chunk: chunks.append(current_chunk) # Need at least 2 chunks (summary and description) if len(chunks) < 2: return False # Check if second chunk starts with any other section label description_chunk = chunks[1] if not description_chunk: return False first_line = description_chunk[0].lower() for labels in [ FunctionCompletenessEvaluator.ARGS_LABELS, FunctionCompletenessEvaluator.RETURNS_LABELS, FunctionCompletenessEvaluator.RAISES_LABELS, FunctionCompletenessEvaluator.EXAMPLE_LABELS, ]: if any(first_line.startswith(label.lower()) for label in labels): return False return True @staticmethod def evaluate_args(docstring: str) -> bool: """ Evaluates if the docstring has argument documentation. Args: docstring (str): The docstring to evaluate. Returns: bool: True if arguments section exists, False otherwise. """ # Check if any argument label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in FunctionCompletenessEvaluator.ARGS_LABELS ) @staticmethod def evaluate_returns(docstring: str) -> bool: """ Evaluates if the docstring has return value or yield documentation. Args: docstring (str): The docstring to evaluate. Returns: bool: True if returns/yields section exists, False otherwise. """ # Check if any return label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in FunctionCompletenessEvaluator.RETURNS_LABELS ) @staticmethod def evaluate_raises(docstring: str) -> bool: """ Evaluates if the docstring has exception documentation. Args: docstring (str): The docstring to evaluate. Returns: bool: True if raises section exists, False otherwise. 
""" # Check if any raise label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in FunctionCompletenessEvaluator.RAISES_LABELS ) @staticmethod def evaluate_examples(docstring: str) -> bool: """ Evaluates if the docstring has usage examples. Args: docstring (str): The docstring to evaluate. Returns: bool: True if examples section exists, False otherwise. """ # Check if any example label appears anywhere in the docstring return any( label.lower() in docstring.lower() for label in FunctionCompletenessEvaluator.EXAMPLE_LABELS ) def evaluate(self, node: ast.FunctionDef) -> float: """ Evaluates the completeness of a function docstring. Checks for: 1. One-liner summary 2. Description 3. Arguments documentation (if has arguments) 4. Returns documentation (if has return) 5. Raises documentation (if has raise statements) 6. Examples (if not private) Args: node (ast.FunctionDef): The function definition node to evaluate. Returns: float: The completeness score between 0 and 1. 
""" # Skip __init__ methods if node.name == "__init__": self.score = 1.0 return self.score # Get required sections for this function first self.required_sections = self._get_required_sections(node) # Reset scores and update requirements self.element_scores = {key: False for key in self.element_scores} self.element_required = { key: key in self.required_sections for key in self.element_scores } docstring = ast.get_docstring(node) if not docstring: self.score = 0.0 return self.score # Evaluate each element if "summary" in self.required_sections: self.element_scores["summary"] = self.evaluate_summary(docstring) if "description" in self.required_sections: self.element_scores["description"] = self.evaluate_description(docstring) if "args" in self.required_sections: self.element_scores["args"] = self.evaluate_args(docstring) if "returns" in self.required_sections: self.element_scores["returns"] = self.evaluate_returns(docstring) if "raises" in self.required_sections: self.element_scores["raises"] = self.evaluate_raises(docstring) if "examples" in self.required_sections: self.element_scores["examples"] = self.evaluate_examples(docstring) # Calculate weighted score considering requirements total_weight = 0.0 weighted_score = 0.0 for (key, score), weight, required in zip( self.element_scores.items(), self.weights, self.element_required.values() ): if required: total_weight += weight if score: weighted_score += weight self.score = weighted_score / total_weight if total_weight > 0 else 0.0 return self.score def evaluate_using_string(self, docstring: str, element_required: Dict) -> Dict: """ """ # Get required sections for this class first # Reset scores and update requirements element_scores = {key: False for key in element_required} if not docstring: return element_scores # Evaluate each element for key in element_required: if key == "summary": element_scores[key] = self.evaluate_summary(docstring) elif key == "description": element_scores[key] = 
self.evaluate_description(docstring) elif key == "args": element_scores[key] = self.evaluate_args(docstring) elif key == "returns": element_scores[key] = self.evaluate_returns(docstring) elif key == "raises": element_scores[key] = self.evaluate_raises(docstring) elif key == "examples": element_scores[key] = self.evaluate_examples(docstring) return element_scores def _get_required_sections(self, node: ast.FunctionDef) -> List[str]: """ Determines which sections are required for the function docstring. Args: node (ast.FunctionDef): The function definition node. Returns: List[str]: List of required section names. """ required = ["summary", "description"] # Check if function has arguments beyond just 'self' args = [arg for arg in node.args.args if arg.arg != "self"] if args or node.args.kwonlyargs: required.append("args") # Check if function has returns if self._has_return_statement(node): required.append("returns") # Check if function has raise statements if self._has_raise_statement(node): required.append("raises") # Check if function is public (not starting with _) if not node.name.startswith("_"): required.append("examples") return required def _has_return_statement(self, node: ast.FunctionDef) -> bool: """ Checks if the function has any meaningful return statements or yields. A return statement is considered meaningful if it: 1. Returns a value other than None 2. Uses yield or yield from (generator function) 3. Has an explicit return None statement Args: node (ast.FunctionDef): The function definition node. Returns: bool: True if the function has a meaningful return value or is a generator. 
""" has_explicit_return = False for child in ast.walk(node): if isinstance(child, ast.Return): if child.value is not None: # Return with any value (including None) has_explicit_return = True if ( not isinstance(child.value, ast.Constant) or child.value.value is not None ): return True elif isinstance(child, (ast.Yield, ast.YieldFrom)): # Function is a generator return True return has_explicit_return def _has_raise_statement(self, node: ast.FunctionDef) -> bool: """ Checks if the function has any uncaught raise statements that bubble up to caller. Args: node (ast.FunctionDef): The function definition node. Returns: bool: True if the function has any uncaught raise statements. """ for child in ast.walk(node): if isinstance(child, ast.Raise): # Check if this raise is inside a try-except block parent = child while parent != node: if isinstance(parent, ast.ExceptHandler): # Exception is caught, skip this raise break parent = next( p for p in ast.walk(node) if any( isinstance(c, type(parent)) and c is parent for c in ast.iter_child_nodes(p) ) ) else: # No except handler found, exception bubbles up return True # Also check any function calls that may raise for child in ast.walk(node): if isinstance(child, ast.Call): # Here we could recursively check called functions # but for now we'll assume any uncaught function call # could potentially raise try: parent = child while parent != node: if isinstance(parent, ast.ExceptHandler): break parent = next( p for p in ast.walk(node) if any( isinstance(c, type(parent)) and c is parent for c in ast.iter_child_nodes(p) ) ) else: return True except StopIteration: continue return False ================================================ FILE: src/evaluator/evaluation_common.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates """Common utilities and classes for docstring evaluation.""" from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass from enum import Enum class ScoreLevel(Enum): """Defines the possible score levels for docstring evaluation.""" POOR = 1 FAIR = 2 GOOD = 3 VERY_GOOD = 4 EXCELLENT = 5 @dataclass class SummaryEvaluationExample: """Stores an example of docstring summary evaluation with different quality levels.""" function_signature: str summaries: Dict[ScoreLevel, str] explanations: Dict[ScoreLevel, str] @dataclass class DescriptionEvaluationExample: """Stores an example of docstring description evaluation with different quality levels.""" function_signature: str descriptions: Dict[ScoreLevel, str] explanations: Dict[ScoreLevel, str] @dataclass class ParameterEvaluationExample: """Stores an example of docstring parameter evaluation with different quality levels.""" parameters: Dict[str, str] quality_examples: Dict[ScoreLevel, Dict[str, str]] explanations: Dict[ScoreLevel, str] ================================================ FILE: src/evaluator/helper/context_finder.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates from typing import List, Dict, Optional, Tuple import os import ast import json from pathlib import Path import re class UsageLocation: """Represents a location where a function/class/method is used.""" def __init__(self, file_path: str, line_number: int, usage_type: str, repo_path: Optional[str] = None, signature: Optional[str] = None): self.file_path = file_path self.line_number = line_number self.usage_type = usage_type # 'function', 'class', 'method', or 'staticmethod' self.repo_path = repo_path self.signature = signature def to_dict(self) -> Dict: """Convert to dictionary for JSON serialization.""" return { 'file_path': self.file_path, 'line_number': self.line_number, 'usage_type': self.usage_type, 'repo_path': self.repo_path, 'signature': self.signature } @classmethod def from_dict(cls, data: Dict) -> 'UsageLocation': """Create from dictionary, restoring repo_path and signature when present.""" return cls(data['file_path'], data['line_number'], data['usage_type'], data.get('repo_path'), data.get('signature')) class ContextSearcher: """ Searches for usage of functions, classes, and methods in a Python project. Caches results to avoid repeated searches. """ def __init__(self, repo_path: str): """ Initialize the searcher.
Args: repo_path: Path to the repository root """ self.repo_path = Path(repo_path) self.cache_dir = os.path.join('data', 'evaluator' , 'search_cache') os.makedirs(self.cache_dir, exist_ok=True) def _get_cache_key(self, file_path: str, signature: str) -> str: """Generate a cache key for the search.""" import hashlib # Create a unique key based on file path and signature key = f"{file_path}:{signature}" return hashlib.md5(key.encode()).hexdigest() def _load_from_cache(self, cache_key: str) -> Optional[List[UsageLocation]]: """Load search results from cache if available.""" cache_file = self.cache_dir + f"/{cache_key}.json" if os.path.exists(cache_file): with open(cache_file) as f: data = json.load(f) return [UsageLocation.from_dict(loc) for loc in data] return None def _save_to_cache(self, cache_key: str, locations: List[UsageLocation]): """Save search results to cache.""" cache_file = self.cache_dir + f"/{cache_key}.json" with open(cache_file, 'w') as f: json.dump([loc.to_dict() for loc in locations], f, indent=2) def find_usages(self, target_file: str, signature: str) -> List[UsageLocation]: """ Find all usages of a function/class/method in the repository. 
Args: target_file: Relative path to the file containing the target signature: The signature of the function/class/method Returns: List of UsageLocation objects """ cache_key = self._get_cache_key(target_file, signature) # Try to load from cache first cached_results = self._load_from_cache(cache_key) if cached_results is not None: return cached_results # Parse signature to get name and type name, usage_type = self._parse_signature(signature) locations = [] # Walk through all Python files in the repo for root, _, files in os.walk(self.repo_path): for file in files: if not file.endswith('.py'): continue file_path = Path(root) / file rel_path = file_path.relative_to(self.repo_path) # Skip the target file itself if str(rel_path) == target_file: continue try: with open(file_path) as f: content = f.read() # Find all usages in this file file_locations = self._find_usages_in_file( content, str(rel_path), name, usage_type ) # Add repo path and signature to each location for loc in file_locations: loc.repo_path = str(self.repo_path) loc.signature = signature locations.extend(file_locations) except Exception as e: print(f"Error processing {file_path}: {e}") # Cache the results self._save_to_cache(cache_key, locations) return locations def _parse_signature(self, signature: str) -> Tuple[str, str]: """Parse a signature to get name and type.""" signature = signature.strip() # Split into lines to check for decorators is_static = '@staticmethod' in signature # remove @staticmethod decorator if is_static: signature = signature.replace('@staticmethod', '').strip() if signature.startswith('class '): return signature.split()[1].split('(')[0].split(':')[0], 'class' elif signature.startswith('def '): name = signature.split()[1].split('(')[0] if name == '__init__': return None, 'method' # Skip __init__ methods if is_static: return name, 'staticmethod' return name, 'function' if '(self' not in signature else 'method' raise ValueError(f"Invalid signature: {signature}") def 
_find_usages_in_file(self, content: str, file_path: str, name: str, usage_type: str) -> List[UsageLocation]: """Find all usages in a single file.""" locations = [] tree = ast.parse(content) for node in ast.walk(tree): # For function calls and static methods if usage_type in ('function', 'method', 'staticmethod'): if usage_type == 'staticmethod': if isinstance(node, ast.Assign): if isinstance(node.value, ast.Call): if isinstance(node.value.func, ast.Attribute) and node.value.func.attr == name: locations.append(UsageLocation( file_path, node.lineno, usage_type )) elif isinstance(node, ast.Call): if usage_type == 'function' and isinstance(node.func, ast.Name): if node.func.id == name: locations.append(UsageLocation( file_path, node.lineno, usage_type )) elif usage_type == 'method' and isinstance(node.func, ast.Attribute): if node.func.attr == name: locations.append(UsageLocation( file_path, node.lineno, usage_type )) # For class instantiation elif usage_type == 'class': if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): if node.func.id == name: locations.append(UsageLocation( file_path, node.lineno, usage_type )) return locations class ContextPreparer: """ Prepares context for example evaluation by extracting relevant code from usage locations. """ def __init__(self, repo_path: str): """ Initialize the preparer. Args: repo_path: Path to the repository root """ self.repo_path = Path(repo_path) self.searcher = ContextSearcher(repo_path) def prepare_contexts(self, target_file: str, signature: str) -> List[Tuple[str, str]]: """ Prepare context for all usages of a function/class/method. 
Args: target_file: Relative path to the file containing the target signature: The signature of the function/class/method Returns: List of tuples (context_code, ground_truth) where: - context_code is the code leading up to the usage - ground_truth is the actual usage line """ locations = self.searcher.find_usages(target_file, signature) contexts = [] for location in locations: context, ground_truth = self._prepare_single_context(location) if context and ground_truth: contexts.append((context, ground_truth)) return contexts def _prepare_single_context(self, location: UsageLocation) -> Tuple[Optional[str], Optional[str]]: """Prepare context for a single usage location.""" file_path = self.repo_path / location.file_path with open(file_path) as f: lines = f.readlines() # Get the ground truth lines ground_truth_lines = [] i = location.line_number - 1 # Keep adding lines until one contains a closing parenthesis while i < len(lines): line = lines[i].strip() ground_truth_lines.append(line) if ')' in line: break i += 1 ground_truth = '\n'.join(ground_truth_lines) # Get the context (all lines up to the usage) context_lines = lines[:location.line_number - 1] # Remove trailing empty lines while context_lines and not context_lines[-1].strip(): context_lines.pop() context = ''.join(context_lines) return context, ground_truth ================================================ FILE: src/evaluator/helpfulness_attributes.py ================================================ # Copyright (c) Meta Platforms, Inc.
and affiliates from typing import Dict, Any, List, Optional, Tuple import re from dataclasses import dataclass from enum import Enum class ScoreLevel(Enum): """Defines the possible score levels for docstring evaluation.""" POOR = 1 FAIR = 2 GOOD = 3 VERY_GOOD = 4 EXCELLENT = 5 @dataclass class EvaluationExample: """Stores an example of docstring attribute evaluation with different quality levels.""" class_signature: str init_function: str attributes: Dict[str, str] quality_examples: Dict[ScoreLevel, Dict[str, str]] explanations: Dict[ScoreLevel, str] class DocstringAttributeEvaluator: """ Evaluates the quality of Python docstring attribute descriptions using predefined criteria. This class assesses how well attribute descriptions in docstrings convey the purpose, lifecycle, and usage context of class attributes, going beyond mere type information to provide meaningful guidance about attribute roles and behaviors. """ def __init__(self): """Initialize the evaluator with predefined criteria and examples.""" self.criteria = self._initialize_criteria() self.examples = self._initialize_examples() def _initialize_criteria(self) -> Dict[str, Any]: """ Set up the evaluation criteria for attribute descriptions. The criteria define five quality levels, from mere type repetition (1) to excellent usage guidance and context (5). Returns: Dict containing the evaluation criteria and descriptions for each score level. """ return { 'description': ( 'Evaluate how effectively the attribute descriptions convey the purpose, ' 'lifecycle, and usage context of class attributes. High-quality descriptions ' 'should go beyond type information to provide meaningful guidance about ' 'attribute roles, initialization, modification patterns, and relationships ' 'with class behavior.' 
), 'score_criteria': { ScoreLevel.POOR: ( 'The attribute descriptions merely restate the attribute types or ' 'convert the type hints to natural language without adding any ' 'meaningful information about purpose or lifecycle.' ), ScoreLevel.FAIR: ( 'The descriptions provide basic information about attribute purpose ' 'but lack details about initialization, modification, or usage patterns. ' 'They may use vague language or miss important details.' ), ScoreLevel.GOOD: ( 'The descriptions explain attribute purpose and include some key ' 'information about initialization or usage patterns, but might miss ' 'important lifecycle details or relationships with class behavior.' ), ScoreLevel.VERY_GOOD: ( 'The descriptions clearly explain purpose, initialization, and common ' 'usage patterns. They may note important relationships with class ' 'methods and document any special handling or constraints.' ), ScoreLevel.EXCELLENT: ( 'The descriptions provide comprehensive guidance including purpose, ' 'initialization, modification patterns, relationships with class ' 'behavior, and any special considerations. They help users understand ' 'both how and when to interact with the attributes.' ) } } def _initialize_examples(self) -> List[EvaluationExample]: """ Set up concrete examples of attribute descriptions at different quality levels. Each example includes class and __init__ signatures with corresponding attribute descriptions at different quality levels, along with explanations of the ratings. Returns: List of EvaluationExample objects containing the example cases. """ return [ EvaluationExample( class_signature="class DataProcessor:", init_function='''def __init__(self, config: Dict[str, Any]): """Initialize the data processor. 
Args: config: Configuration dictionary for the processor """ self.config = config self.data_cache = {} self.is_initialized = False self.stats = defaultdict(int) self._lock = threading.Lock()''', attributes={ "config": "Configuration settings for the processor", "data_cache": "Cache for processed data", "is_initialized": "Whether the processor is initialized", "stats": "Processing statistics", "_lock": "Thread synchronization lock" }, quality_examples={ ScoreLevel.POOR: { "config": "Dictionary of configuration", "data_cache": "Dictionary for cache", "is_initialized": "Boolean flag", "stats": "Dictionary of statistics", "_lock": "Threading lock object" }, ScoreLevel.FAIR: { "config": "Configuration settings for processing", "data_cache": "Cache storage for processed items", "is_initialized": "Tracks initialization status", "stats": "Counts of processed items", "_lock": "Lock for thread safety" }, ScoreLevel.GOOD: { "config": "Configuration dictionary controlling processing behavior. Set at initialization", "data_cache": "Cache of processed items to avoid recomputation. Cleared with reset()", "is_initialized": "Flag indicating if setup() has been called successfully", "stats": "Counters tracking number of items processed, errors, cache hits etc", "_lock": "Thread lock ensuring thread-safe access to shared resources" }, ScoreLevel.VERY_GOOD: { "config": "Configuration dictionary controlling processing behavior. Set at initialization and accessed by all processing methods. Read-only after initialization", "data_cache": "Cache of processed items to avoid recomputation. Cleared with reset(). Keys are item IDs, values are processed results", "is_initialized": "Flag indicating if setup() has been called successfully. Methods will raise RuntimeError if called before initialization", "stats": "Counters tracking processing metrics (items processed, errors, cache hits etc). 
Updated by process() and reset by clear_stats()", "_lock": "Thread lock ensuring thread-safe access to cache and stats. Used internally by all public methods" }, ScoreLevel.EXCELLENT: { "config": "Configuration dictionary controlling processing behavior. Set at initialization and accessed by all processing methods. Read-only after initialization. Must contain 'batch_size' and 'max_cache_size' keys. See CONFIG_SCHEMA for full specification", "data_cache": "Cache of processed items to avoid recomputation. Cleared with reset(). Keys are item IDs, values are processed results. Limited to max_cache_size items with LRU eviction. Thread-safe access via _lock", "is_initialized": "Flag indicating if setup() has been called successfully. Methods will raise RuntimeError if called before initialization. Set to True by setup() and False by reset(). Thread-safe access via _lock", "stats": "Counters tracking processing metrics (items processed, errors, cache hits etc). Updated by process() and reset by clear_stats(). Access via get_stats() for thread-safe snapshot. Used for monitoring and auto-scaling decisions", "_lock": "Thread lock ensuring thread-safe access to cache and stats. Used internally by all public methods. Reentrant lock allowing nested acquisition by same thread. 
Consider using async methods for high-concurrency scenarios" } }, explanations={ ScoreLevel.POOR: "These descriptions merely restate the attribute types without adding value", ScoreLevel.FAIR: "Provides basic purpose but lacks lifecycle and usage guidance", ScoreLevel.GOOD: "Includes initialization context and some usage patterns but could be more comprehensive", ScoreLevel.VERY_GOOD: "Clear purpose, initialization, and usage patterns with thread-safety context", ScoreLevel.EXCELLENT: "Comprehensive guidance including constraints, thread-safety, and practical usage tips" } ) ] def get_evaluation_prompt(self, class_signature: str, init_function: str, attribute_descriptions: Dict[str, str]) -> str: """ Generates a prompt for LLM evaluation of attribute descriptions. Args: class_signature: The complete class signature. init_function: The complete __init__ function including docstring. attribute_descriptions: Dict mapping attribute names to their descriptions. Returns: A formatted prompt string that can be sent to an LLM for evaluation. """ example = self.examples[0] # Use first example as reference prompt = [ "Please evaluate the following Python docstring attribute descriptions based on these criteria:", "", "", f"Class signature:\n{class_signature}", "", f"Init function:\n{init_function}", "", "", "", "Attribute descriptions to evaluate:", ] for attr, desc in attribute_descriptions.items(): prompt.append(f"{attr}: {desc}") prompt.append("") prompt.extend([ "", "", "Evaluation criteria:", self.criteria['description'], "", "Score levels:", ]) # Add criteria for each score level for level in ScoreLevel: prompt.append(f"{level.value}. 
{self.criteria['score_criteria'][level]}") prompt.append("") # Add example prompt.extend([ "", "", "Example for reference:", f"Class: {example.class_signature}", f"Init:\n{example.init_function}", "", "Attribute descriptions at different quality levels:", ]) for level in ScoreLevel: prompt.extend([ f"Level {level.value}:", *[f"{attr}: {desc}" for attr, desc in example.quality_examples[level].items()], f"Explanation: {example.explanations[level]}", "" ]) prompt.append("") prompt.extend([ "", "", "IMPORTANT INSTRUCTIONS FOR ANALYSIS:", "1. Analyze how well each attribute description provides meaningful information beyond type hints", "2. Consider completeness of lifecycle documentation (initialization, modification, access patterns)", "3. Look for helpful context about relationships with class behavior", "4. Check for thread-safety and special handling documentation where relevant", "", "", "", "Please structure your response as follows:", "1. Analyze each attribute description's strengths and weaknesses", "2. Compare against the criteria and example quality levels", "3. Suggest specific improvements for weaker descriptions", "4. Provide your score (1-5) enclosed in <score></score> tags", "", "", "Remember: Do not rush to assign a score. Take time to analyze thoroughly and justify your reasoning.", "No need to provide Suggestions for Improvement", "The score should reflect your careful analysis and should be the last part of your response." ]) return "\n".join(prompt) def parse_llm_response(self, response: str) -> Tuple[int, str]: """ Extracts the numerical score and full analysis from an LLM's response. Args: response: The complete response text from the LLM. Returns: A tuple containing: - The numerical score (1-5) - The full analysis text Raises: ValueError: If no valid score is found or if multiple scores are found. """ # Extract score from XML tags score_matches = re.findall(r'<score>(\d)</score>', response) if not score_matches: raise ValueError("No valid score found in LLM response.
Response must include a score in <score></score> tags.") if len(score_matches) > 1: raise ValueError("Multiple scores found in LLM response. Expected exactly one score.") score = int(score_matches[0]) if score < 1 or score > 5: raise ValueError(f"Invalid score value: {score}. Score must be between 1 and 5.") # Remove the score tags from the analysis text analysis = re.sub(r'<score>\d</score>', '', response).strip() return score, analysis def get_criteria_description(self) -> str: """Returns the main criteria description.""" return self.criteria['description'] def get_score_criteria(self, level: ScoreLevel) -> str: """Returns the criteria description for a specific score level.""" return self.criteria['score_criteria'][level] def get_examples(self) -> List[EvaluationExample]: """Returns all evaluation examples.""" return self.examples ================================================ FILE: src/evaluator/helpfulness_description.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict, Any, List, Tuple from dataclasses import dataclass from enum import Enum import re from src.evaluator.evaluation_common import ScoreLevel class DescriptionAspect(Enum): """Defines the different aspects of docstring description evaluation.""" MOTIVATION = "motivation" USAGE_SCENARIOS = "usage_scenarios" INTEGRATION = "integration" FUNCTIONALITY = "functionality" @dataclass class AspectCriteria: """Stores criteria for a single evaluation aspect.""" description: str score_criteria: Dict[ScoreLevel, str] example_good: str example_poor: str class DocstringDescriptionEvaluator: """ Evaluates the quality of Python docstring descriptions across multiple aspects. This evaluator analyzes docstring descriptions based on four key aspects: 1. Motivation/Purpose explanation 2. Usage scenarios and conditions 3. System integration and interactions 4.
Functionality overview Each aspect is scored independently on a scale of 1-5, providing a comprehensive assessment of the description's effectiveness. """ def __init__(self): """Initialize the evaluator with predefined criteria for each aspect.""" self.criteria = self._initialize_criteria() def _initialize_criteria(self) -> Dict[DescriptionAspect, AspectCriteria]: """ Set up the evaluation criteria for each aspect of docstring descriptions. Returns: Dictionary mapping aspects to their evaluation criteria. """ return { DescriptionAspect.MOTIVATION: AspectCriteria( description="How well does the description explain the reason or motivation behind the code?", score_criteria={ ScoreLevel.POOR: "No explanation of why the code exists or its purpose", ScoreLevel.FAIR: "Basic purpose stated but without context or reasoning", ScoreLevel.GOOD: "Clear explanation of purpose with some context", ScoreLevel.VERY_GOOD: "Thorough explanation of purpose with business/technical context", ScoreLevel.EXCELLENT: "Comprehensive explanation of purpose, context, and value proposition" }, example_good=( "This cache manager addresses the performance bottleneck in our API " "responses by reducing database load during peak hours, while ensuring " "data freshness for critical operations." ), example_poor="This is a cache manager for storing data." ), DescriptionAspect.USAGE_SCENARIOS: AspectCriteria( description="How effectively does it describe when and how to use the code?", score_criteria={ ScoreLevel.POOR: "No information about usage scenarios", ScoreLevel.FAIR: "Basic usage information without specific scenarios", ScoreLevel.GOOD: "Some key usage scenarios described", ScoreLevel.VERY_GOOD: "Detailed usage scenarios with common cases", ScoreLevel.EXCELLENT: "Comprehensive coverage of use cases, including edge cases" }, example_good=( "Use this validator when processing user-submitted data, especially " "for high-stakes operations like financial transactions. 
It handles " "various edge cases including partial submissions and legacy formats." ), example_poor="Validates data according to rules." ), DescriptionAspect.INTEGRATION: AspectCriteria( description="How well does it explain integration with other system components?", score_criteria={ ScoreLevel.POOR: "No mention of system integration", ScoreLevel.FAIR: "Minimal reference to other components", ScoreLevel.GOOD: "Basic explanation of main interactions", ScoreLevel.VERY_GOOD: "Clear description of integration points and dependencies", ScoreLevel.EXCELLENT: "Comprehensive overview of system interactions and data flow" }, example_good=( "This service interfaces with the UserAuth system for validation, " "writes logs to CloudWatch, and triggers notifications through SNS. " "It serves as a crucial link between the frontend and payment processor." ), example_poor="Processes data and sends it to other services." ), DescriptionAspect.FUNCTIONALITY: AspectCriteria( description="How clearly does it explain the functionality without excessive technical detail?", score_criteria={ ScoreLevel.POOR: "No explanation of functionality", ScoreLevel.FAIR: "Overly technical or vague explanation", ScoreLevel.GOOD: "Basic explanation of main functionality", ScoreLevel.VERY_GOOD: "Clear, balanced explanation of functionality", ScoreLevel.EXCELLENT: "Perfect balance of clarity and technical detail" }, example_good=( "Processes incoming customer data by first validating format and required fields, " "then enriching with relevant historical data, and finally " "generating risk scores using configurable criteria." ), example_poor="Processes data using various functions and algorithms." ) } def get_evaluation_prompt(self, code_implementation: str, docstring: str, eval_type: str = None) -> str: """ Generates a prompt for LLM evaluation of docstring descriptions. 
Args: code_implementation: The function or class implementation docstring: The docstring to evaluate eval_type: The type of code component (class, function, method). If not provided, it will be determined from code_implementation. Returns: Prompt for LLM evaluation """ # Determine eval_type if not provided if eval_type is None: if code_implementation.strip().startswith("class "): eval_type = "class" else: eval_type = "method" if code_implementation.partition("(")[2].lstrip().startswith("self") else "function" # Extract description from docstring (everything after the summary) description = self._extract_description(docstring) if not description: return "The docstring does not have a description section to evaluate." prompt = ["# Docstring Description Evaluation", ""] prompt.extend([ "## Code Component", f"```python", f"{code_implementation}", f"```", "", ]) prompt.extend([ "## Docstring Description to Evaluate", f"```", f"{description}", f"```", "", ]) # Add evaluation criteria prompt.extend([ "## Evaluation Criteria", "Please evaluate the above docstring description across these four aspects:", "" ]) for aspect in DescriptionAspect: criteria = self.criteria[aspect] prompt.extend([ f"### {aspect.value.title()}", f"{criteria.description}", "", "Score levels:", "", ]) for level in ScoreLevel: prompt.append(f"{level.value}.
{criteria.score_criteria[level]}") prompt.extend([ "", "Examples:", f"Good: \"{criteria.example_good}\"", f"Poor: \"{criteria.example_poor}\"", "", ]) # Add output format instructions prompt.extend([ "## Output Format", "Please evaluate the description and provide your assessment in this format:", "", "```", "Motivation: [score 1-5]", "Usage Scenarios: [score 1-5]", "Integration: [score 1-5]", "Functionality: [score 1-5]", "", "Overall: [average of the scores, rounded to nearest integer]", "", "Suggestions: [2-3 concrete suggestions for improvement focusing on the weakest aspects]", "```", ]) return "\n".join(prompt) def parse_llm_response(self, response: str) -> Tuple[int, str]: """ Extracts scores and suggestions from an LLM's response. Args: response: The complete response text from the LLM. Returns: Tuple of (overall_score, suggestions) Raises: ValueError: If required information is missing or invalid. """ # Default score if we can't find explicit scores default_score = 3 # If the response indicates no description section if "docstring does not have a description section" in response: return default_score, "Add a description section to the docstring." # Try to extract an overall score first (easiest) overall_pattern = r"Overall:\s*\[?(\d)\.?\d*\]?" 
overall_matches = re.findall(overall_pattern, response, re.IGNORECASE) if overall_matches: overall_score = int(overall_matches[0]) else: # If we can't find an explicit overall score, use a default overall_score = default_score # Extract suggestions # Look for several common patterns suggestion_patterns = [ r"Suggestions:\s*(.+?)(?:\n\n|\Z)", # Format in prompt r"<suggestions?>(.*?)</suggestions?>", # XML tags r"suggestions?:?\s*\n\s*(.+?)(?:\n\n|\Z)", # Common formats ] for pattern in suggestion_patterns: suggestion_matches = re.findall(pattern, response, re.DOTALL | re.IGNORECASE) if suggestion_matches: suggestion = suggestion_matches[0].strip() break else: # Default suggestion if none found suggestion = "Consider adding more detail to the description section." return overall_score, suggestion def _extract_description(self, docstring: str) -> str: """ Extract the description part from a docstring. The description is everything after the summary line (first line) and before any parameter sections, return sections, etc. Args: docstring: The complete docstring Returns: The extracted description, or empty string if none found """ if not docstring: return "" # Split into lines and remove empty lines at start/end lines = [line.strip() for line in docstring.strip().split('\n')] if not lines: return "" # Skip the first line (summary) lines = lines[1:] # Find where the parameters section or other sections begin section_markers = ['Args:', 'Parameters:', 'Arguments:', 'Returns:', 'Raises:', 'Yields:', 'Examples:'] description_lines = [] for line in lines: # Stop if we hit a section marker if any(line.strip().startswith(marker) for marker in section_markers): break description_lines.append(line) # Join and strip to get the description description = '\n'.join(description_lines).strip() return description ================================================ FILE: src/evaluator/helpfulness_evaluator.py ================================================ # Copyright (c) Meta Platforms, Inc.
and affiliates import json import random import os import sys from pathlib import Path from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass # Add the project root to path project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from src.evaluator.helpfulness_summary import DocstringSummaryEvaluator from src.evaluator.helpfulness_description import DocstringDescriptionEvaluator from src.evaluator.helpfulness_parameters import DocstringParametersEvaluator from src.agent.llm.openai_llm import OpenAILLM @dataclass class EvaluationResult: """Store the results of a single evaluation.""" system: str component_id: str aspect: str score: int suggestion: str class DocstringHelpfulnessEvaluator: """Evaluates the helpfulness of docstrings generated by different systems.""" SYSTEMS = [ "copy_paste_codellama34b", "copy_paste_gpt4o_mini", "docassist-codellama34b", "docassist-gpt4o_mini", "fim-codellama13b", ] ASPECTS = ["summary", "description", "parameters"] def __init__(self, data_path: str, output_dir: str, api_key: str, model: str = "gpt-4o"): """Initialize the evaluator. Args: data_path: Path to the completeness evaluation data output_dir: Directory to store evaluation results api_key: OpenAI API key model: LLM model to use for evaluation """ self.data_path = data_path self.output_dir = output_dir self.llm = OpenAILLM(api_key=api_key, model=model) # Initialize evaluators for each aspect self.evaluators = { "summary": DocstringSummaryEvaluator(), "description": DocstringDescriptionEvaluator(), "parameters": DocstringParametersEvaluator() } # Load evaluation data with open(self.data_path, 'r') as f: self.data = json.load(f) # Create output directory if it doesn't exist os.makedirs(self.output_dir, exist_ok=True) def sample_components(self, n: int = 50, seed: int = 42) -> List[str]: """Randomly sample code components where all systems have valid docstrings. 
Args: n: Number of components to sample seed: Random seed for reproducibility Returns: List of component IDs """ random.seed(seed) # Filter components where all systems have valid docstrings valid_components = [] for component_id, component_data in self.data.items(): # Check if all systems have docstrings has_all_docstrings = True for system in self.SYSTEMS: if system not in component_data.get("docstrings", {}): has_all_docstrings = False break # Check if docstring is not empty docstring = component_data["docstrings"].get(system, {}).get("docstring", "") if not docstring or docstring == "example string": has_all_docstrings = False break if has_all_docstrings: valid_components.append(component_id) # Sample n components if len(valid_components) < n: print(f"Warning: Only {len(valid_components)} components have valid docstrings for all systems") return valid_components return random.sample(valid_components, n) def evaluate_component(self, component_id: str) -> List[EvaluationResult]: """Evaluate docstrings from all systems for a given component. 
Args: component_id: Component ID Returns: List of evaluation results """ component_data = self.data[component_id] results = [] component_type = component_data.get("type", "function") source_code = component_data.get("source_code", "") for system in self.SYSTEMS: if system not in component_data.get("docstrings", {}): continue system_data = component_data["docstrings"][system] docstring = system_data.get("docstring", "") # Skip if docstring is empty or the example placeholder if not docstring or docstring == "example string": continue print(f" Evaluating system: {system}") # Evaluate each aspect for aspect in self.ASPECTS: # Check if the aspect is present in the docstring element_scores = system_data.get("element_scores", {}) if aspect not in element_scores or not element_scores[aspect]: print(f" Skipping aspect '{aspect}' - not present in docstring") continue print(f" Evaluating aspect: {aspect}") try: # Get the evaluator for this aspect evaluator = self.evaluators[aspect] # Create prompt for evaluation prompt = evaluator.get_evaluation_prompt(source_code, docstring, component_type) # Call LLM for evaluation messages = [ self.llm.format_message("system", "You are an expert docstring quality evaluator."), self.llm.format_message("user", prompt) ] response = self.llm.generate(messages, temperature=0.1, max_tokens=1024) # Parse response score, suggestion = evaluator.parse_llm_response(response) print(f" Score: {score}") # Store result result = EvaluationResult( system=system, component_id=component_id, aspect=aspect, score=score, suggestion=suggestion ) results.append(result) except Exception as e: print(f" Error evaluating {aspect}: {str(e)}") # Continue with other evaluations return results def run_evaluation(self, n_samples: int = 50, seed: int = 42) -> Dict[str, Any]: """Run the helpfulness evaluation on sampled components. 
Args: n_samples: Number of components to sample seed: Random seed for reproducibility Returns: Evaluation results """ # Sample components component_ids = self.sample_components(n_samples, seed) # Evaluate each component all_results = [] for component_id in component_ids: print(f"Evaluating component: {component_id}") results = self.evaluate_component(component_id) all_results.extend(results) # Organize results results_dict = { "metadata": { "n_samples": len(component_ids), "seed": seed, "systems": self.SYSTEMS, "aspects": self.ASPECTS }, "component_ids": component_ids, "results": [ { "system": r.system, "component_id": r.component_id, "aspect": r.aspect, "score": r.score, "suggestion": r.suggestion } for r in all_results ] } # Save results to file output_path = os.path.join(self.output_dir, "helpfulness_evaluation_results.json") with open(output_path, 'w') as f: json.dump(results_dict, f, indent=2) # Generate statistics stats = self.calculate_statistics(results_dict) # Save statistics to file stats_path = os.path.join(self.output_dir, "helpfulness_evaluation_stats.md") with open(stats_path, 'w') as f: f.write(self.format_statistics_markdown(stats)) return results_dict def calculate_statistics(self, results: Dict[str, Any]) -> Dict[str, Any]: """Calculate statistics from evaluation results. 
Args: results: Evaluation results Returns: Statistics """ stats = { "overall": {}, "by_system": {}, "by_aspect": {}, "by_system_and_aspect": {} } # Calculate overall average scores = [r["score"] for r in results["results"]] stats["overall"]["average_score"] = sum(scores) / len(scores) if scores else 0 stats["overall"]["count"] = len(scores) # Calculate average by system for system in self.SYSTEMS: system_scores = [r["score"] for r in results["results"] if r["system"] == system] stats["by_system"][system] = { "average_score": sum(system_scores) / len(system_scores) if system_scores else 0, "count": len(system_scores) } # Calculate average by aspect for aspect in self.ASPECTS: aspect_scores = [r["score"] for r in results["results"] if r["aspect"] == aspect] stats["by_aspect"][aspect] = { "average_score": sum(aspect_scores) / len(aspect_scores) if aspect_scores else 0, "count": len(aspect_scores) } # Calculate average by system and aspect for system in self.SYSTEMS: stats["by_system_and_aspect"][system] = {} for aspect in self.ASPECTS: scores = [r["score"] for r in results["results"] if r["system"] == system and r["aspect"] == aspect] stats["by_system_and_aspect"][system][aspect] = { "average_score": sum(scores) / len(scores) if scores else 0, "count": len(scores) } return stats def format_statistics_markdown(self, stats: Dict[str, Any]) -> str: """Format statistics as markdown. 
Args: stats: Statistics Returns: Markdown representation of statistics """ md = "# Docstring Helpfulness Evaluation Results\n\n" # Overall statistics md += "## Overall Statistics\n\n" md += f"- Average Score: {stats['overall']['average_score']:.2f}\n" md += f"- Number of Evaluations: {stats['overall']['count']}\n\n" # By system md += "## Results by System\n\n" md += "| System | Average Score | Count |\n" md += "| ------ | ------------- | ----- |\n" for system, system_stats in stats["by_system"].items(): md += f"| {system} | {system_stats['average_score']:.2f} | {system_stats['count']} |\n" md += "\n" # By aspect md += "## Results by Aspect\n\n" md += "| Aspect | Average Score | Count |\n" md += "| ------ | ------------- | ----- |\n" for aspect, aspect_stats in stats["by_aspect"].items(): md += f"| {aspect} | {aspect_stats['average_score']:.2f} | {aspect_stats['count']} |\n" md += "\n" # By system and aspect md += "## Results by System and Aspect\n\n" md += "| System | Aspect | Average Score | Count |\n" md += "| ------ | ------ | ------------- | ----- |\n" for system, aspects in stats["by_system_and_aspect"].items(): for aspect, aspect_stats in aspects.items(): md += f"| {system} | {aspect} | {aspect_stats['average_score']:.2f} | {aspect_stats['count']} |\n" return md def main(): """Run the docstring helpfulness evaluation.""" import yaml # Configuration data_path = "experiments/eval/results/completeness_evaluation_cleaned.json" output_dir = "experiments/eval/results/helpfulness" # Get API key from config with open("config/agent_config.yaml", 'r') as f: config = yaml.safe_load(f) api_key = config["llm"]["api_key"] model = config["llm"]["model"] # Run evaluation evaluator = DocstringHelpfulnessEvaluator(data_path, output_dir, api_key, model) evaluator.run_evaluation(n_samples=50, seed=42) if __name__ == "__main__": main() ================================================ FILE: src/evaluator/helpfulness_evaluator_ablation.py
================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import json import random import os import sys from pathlib import Path from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass # Add the project root to path project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from src.evaluator.helpfulness_summary import DocstringSummaryEvaluator from src.evaluator.helpfulness_description import DocstringDescriptionEvaluator from src.evaluator.helpfulness_parameters import DocstringParametersEvaluator from src.agent.llm.openai_llm import OpenAILLM @dataclass class EvaluationResult: """Store the results of a single evaluation.""" system: str component_id: str aspect: str score: int suggestion: str class DocstringHelpfulnessEvaluatorAblation: """Evaluates the helpfulness of docstrings generated by different systems.""" SYSTEMS = [ "docassist-codellama34b-random-file", "docassist-codellama34b-random-node", "docassist-gpt4o_mini-random-file", "docassist-gpt4o_mini-random-node", "docassist-codellama34b", "docassist-gpt4o_mini", ] ASPECTS = ["summary", "description", "parameters"] def __init__(self, data_path: str, output_dir: str, api_key: str, model: str = "gpt-4o"): """Initialize the evaluator. 
Args: data_path: Path to the completeness evaluation data output_dir: Directory to store evaluation results api_key: OpenAI API key model: LLM model to use for evaluation """ self.data_path = data_path self.output_dir = output_dir self.llm = OpenAILLM(api_key=api_key, model=model) # Initialize evaluators for each aspect self.evaluators = { "summary": DocstringSummaryEvaluator(), "description": DocstringDescriptionEvaluator(), "parameters": DocstringParametersEvaluator() } # Load evaluation data with open(self.data_path, 'r') as f: self.data = json.load(f) # Create output directory if it doesn't exist os.makedirs(self.output_dir, exist_ok=True) def sample_components(self, n: Optional[int] = 50, seed: int = 42) -> List[str]: """Randomly sample code components where all systems have valid docstrings. Args: n: Number of components to sample. If None, return all valid components. seed: Random seed for reproducibility Returns: List of component IDs """ random.seed(seed) # Filter components where all systems have valid docstrings valid_components = [] for component_id, component_data in self.data.items(): # Check if all systems have docstrings has_all_docstrings = True for system in self.SYSTEMS: if system not in component_data.get("docstrings", {}): has_all_docstrings = False break # Check if docstring is not empty docstring = component_data["docstrings"].get(system, {}).get("docstring", "") if not docstring or docstring == "example string": has_all_docstrings = False break if has_all_docstrings: valid_components.append(component_id) # If n is None, return all valid components if n is None: print(f"Using all {len(valid_components)} valid components") return valid_components # Sample n components if len(valid_components) < n: print(f"Warning: Only {len(valid_components)} components have valid docstrings for all systems") return valid_components return random.sample(valid_components, n) def evaluate_component(self, component_id: str) -> List[EvaluationResult]: """Evaluate 
docstrings from all systems for a given component. Args: component_id: Component ID Returns: List of evaluation results """ component_data = self.data[component_id] results = [] component_type = component_data.get("type", "function") source_code = component_data.get("source_code", "") for system in self.SYSTEMS: if system not in component_data.get("docstrings", {}): continue system_data = component_data["docstrings"][system] docstring = system_data.get("docstring", "") # Skip if docstring is empty or the example placeholder if not docstring or docstring == "example string": continue print(f" Evaluating system: {system}") # Evaluate each aspect for aspect in self.ASPECTS: # Check if the aspect is present in the docstring element_scores = system_data.get("element_scores", {}) if aspect not in element_scores or not element_scores[aspect]: print(f" Skipping aspect '{aspect}' - not present in docstring") continue print(f" Evaluating aspect: {aspect}") try: # Get the evaluator for this aspect evaluator = self.evaluators[aspect] # Create prompt for evaluation prompt = evaluator.get_evaluation_prompt(source_code, docstring, component_type) # Call LLM for evaluation messages = [ self.llm.format_message("system", "You are an expert docstring quality evaluator."), self.llm.format_message("user", prompt) ] response = self.llm.generate(messages, temperature=0.1, max_tokens=1024) # Parse response score, suggestion = evaluator.parse_llm_response(response) print(f" Score: {score}") # Store result result = EvaluationResult( system=system, component_id=component_id, aspect=aspect, score=score, suggestion=suggestion ) results.append(result) except Exception as e: print(f" Error evaluating {aspect}: {str(e)}") # Continue with other evaluations return results def run_evaluation(self, n_samples: int = 50, seed: int = 42) -> Dict[str, Any]: """Run the helpfulness evaluation on sampled components. 
Args: n_samples: Number of components to sample seed: Random seed for reproducibility Returns: Evaluation results """ # Sample components component_ids = self.sample_components(n_samples, seed) # Evaluate each component all_results = [] for component_id in component_ids: print(f"Evaluating component: {component_id}") results = self.evaluate_component(component_id) all_results.extend(results) # Organize results results_dict = { "metadata": { "n_samples": len(component_ids), "seed": seed, "systems": self.SYSTEMS, "aspects": self.ASPECTS }, "component_ids": component_ids, "results": [ { "system": r.system, "component_id": r.component_id, "aspect": r.aspect, "score": r.score, "suggestion": r.suggestion } for r in all_results ] } # Save results to file output_path = os.path.join(self.output_dir, "helpfulness_evaluation_results.json") with open(output_path, 'w') as f: json.dump(results_dict, f, indent=2) # Generate statistics stats = self.calculate_statistics(results_dict) # Save statistics to file stats_path = os.path.join(self.output_dir, "helpfulness_evaluation_stats.md") with open(stats_path, 'w') as f: f.write(self.format_statistics_markdown(stats)) return results_dict def calculate_statistics(self, results: Dict[str, Any]) -> Dict[str, Any]: """Calculate statistics from evaluation results. 
Args: results: Evaluation results Returns: Statistics """ stats = { "overall": {}, "by_system": {}, "by_aspect": {}, "by_system_and_aspect": {} } # Calculate overall average scores = [r["score"] for r in results["results"]] stats["overall"]["average_score"] = sum(scores) / len(scores) if scores else 0 stats["overall"]["count"] = len(scores) # Calculate average by system for system in self.SYSTEMS: system_scores = [r["score"] for r in results["results"] if r["system"] == system] stats["by_system"][system] = { "average_score": sum(system_scores) / len(system_scores) if system_scores else 0, "count": len(system_scores) } # Calculate average by aspect for aspect in self.ASPECTS: aspect_scores = [r["score"] for r in results["results"] if r["aspect"] == aspect] stats["by_aspect"][aspect] = { "average_score": sum(aspect_scores) / len(aspect_scores) if aspect_scores else 0, "count": len(aspect_scores) } # Calculate average by system and aspect for system in self.SYSTEMS: stats["by_system_and_aspect"][system] = {} for aspect in self.ASPECTS: scores = [r["score"] for r in results["results"] if r["system"] == system and r["aspect"] == aspect] stats["by_system_and_aspect"][system][aspect] = { "average_score": sum(scores) / len(scores) if scores else 0, "count": len(scores) } return stats def format_statistics_markdown(self, stats: Dict[str, Any]) -> str: """Format statistics as markdown. 
Args: stats: Statistics Returns: Markdown representation of statistics """ md = "# Docstring Helpfulness Evaluation Results\n\n" # Overall statistics md += "## Overall Statistics\n\n" md += f"- Average Score: {stats['overall']['average_score']:.2f}\n" md += f"- Number of Evaluations: {stats['overall']['count']}\n\n" # By system md += "## Results by System\n\n" md += "| System | Average Score | Count |\n" md += "| ------ | ------------- | ----- |\n" for system, system_stats in stats["by_system"].items(): md += f"| {system} | {system_stats['average_score']:.2f} | {system_stats['count']} |\n" md += "\n" # By aspect md += "## Results by Aspect\n\n" md += "| Aspect | Average Score | Count |\n" md += "| ------ | ------------- | ----- |\n" for aspect, aspect_stats in stats["by_aspect"].items(): md += f"| {aspect} | {aspect_stats['average_score']:.2f} | {aspect_stats['count']} |\n" md += "\n" # By system and aspect md += "## Results by System and Aspect\n\n" md += "| System | Aspect | Average Score | Count |\n" md += "| ------ | ------ | ------------- | ----- |\n" for system, aspects in stats["by_system_and_aspect"].items(): for aspect, aspect_stats in aspects.items(): md += f"| {system} | {aspect} | {aspect_stats['average_score']:.2f} | {aspect_stats['count']} |\n" return md def main(): """Run the docstring helpfulness ablation evaluation.""" import yaml # Configuration data_path = "experiments/eval/results/completeness_evaluation_ablation_cleaned.json" output_dir = "experiments/eval/results/helpfulness_ablation" # Get API key from config with open("config/agent_config.yaml", 'r') as f: config = yaml.safe_load(f) api_key = config["llm"]["api_key"] model = config["llm"]["model"] # Run evaluation evaluator = DocstringHelpfulnessEvaluatorAblation(data_path, output_dir, api_key, model) evaluator.run_evaluation(n_samples=50, seed=42) if __name__ == "__main__": main() ================================================ FILE: src/evaluator/helpfulness_examples.py
================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from typing import Dict, Any, List, Optional, Tuple, Union from dataclasses import dataclass from abc import ABC, abstractmethod import ast import re def get_callable_name(node: Union[ast.Name, ast.Attribute]) -> str: """ Extract the name of a callable whether it's an ast.Name or ast.Attribute. """ if isinstance(node, ast.Name): # e.g., "my_function" return node.id elif isinstance(node, ast.Attribute): # e.g., "some_module.my_function" # node.value.id -> "some_module", node.attr -> "my_function" return node.attr else: raise ValueError(f"Unsupported node type for function/class: {type(node)}") @dataclass class FunctionCallExample: """Stores an example of function usage with context and expected output.""" context_code: str # Code leading up to the function call function_signature: str # The complete function signature docstring_example: str # Only the example part of the docstring expected_call: str # The expected function call line(s) @dataclass class ClassCallExample: """Stores an example of class instantiation with context and expected output.""" context_code: str # Code leading up to class instantiation class_signature: str # The class signature init_signature: str # The __init__ method signature docstring_example: str # Only the example part of the docstring expected_call: str # The expected instantiation line(s) @dataclass class MethodCallExample: """Stores an example of method usage with context and expected output.""" context_code: str # Code leading up to method call method_signature: str # The method signature docstring_example: str # Only the example part of the docstring expected_call: str # The expected method call line(s) class BaseExampleEvaluator(ABC): """ Base class for evaluating docstring examples. 
This class provides the foundation for evaluating how well docstring examples enable users to correctly use the code without needing to understand its implementation. """ @abstractmethod def get_evaluation_prompt(self, context_code: str, signature: str, example: str) -> str: """ Generates a prompt for LLM to predict the next line(s) of code. Args: context_code: The code leading up to where the prediction should be made signature: The complete signature of the function/class/method example: The example part of the docstring Returns: A formatted prompt string for the LLM """ pass @abstractmethod def evaluate_prediction(self, prediction: str, ground_truth: str) -> Tuple[bool, str]: """ Evaluates if the predicted usage matches the ground truth. Args: prediction: The LLM's predicted line(s) of code ground_truth: The expected line(s) of code Returns: A tuple containing: - Boolean indicating if the prediction is correct - String explaining the evaluation result """ pass class FunctionExampleEvaluator(BaseExampleEvaluator): """ Evaluates the quality of function docstring examples by testing if they enable correct function usage prediction. """ def get_evaluation_prompt(self, context_code: str, signature: str, example: str) -> str: """ Generates a prompt for LLM to predict the next line of function usage. Args: context_code: The code leading up to the function call signature: The complete function signature example: The example part of the docstring Returns: A formatted prompt string that can be sent to an LLM for prediction """ prompt = [ "Given the following context, predict ONLY the next line of code that calls the function.", "Your prediction should be based solely on the function signature and example provided.", "", "Function signature:", signature, "", "Example from docstring:", example, "", "Context code leading up to function call:", context_code, "", "IMPORTANT INSTRUCTIONS:", "1. Predict ONLY the next line(s) that calls the function", "2. 
Base your prediction solely on the signature and example", "3. Include ONLY the function call, no additional explanation", "4. If the function call spans multiple lines, include all necessary lines", "5. Ensure the prediction is valid Python syntax", "", "Your prediction should be enclosed in tags", ] return "\n".join(prompt) def evaluate_prediction(self, prediction: str, ground_truth: str) -> Tuple[bool, str]: """ Evaluates if the predicted function call matches the ground truth. Performs robust parsing of both prediction and ground truth to compare: 1. Function name 2. Argument names and their order 3. Argument values (when they are literals) Args: prediction: The LLM's predicted function call ground_truth: The expected function call Returns: Tuple containing: - Boolean indicating if the prediction is correct - String explaining why the prediction was correct or incorrect """ # Parse both prediction and ground truth into AST pred_ast = ast.parse(prediction.strip()).body[0].value truth_ast = ast.parse(ground_truth.strip()).body[0].value # Verify it's a function call if not isinstance(pred_ast, ast.Call) or not isinstance(truth_ast, ast.Call): return False, "Not a valid function call" # Check function name pred_name = get_callable_name(pred_ast.func) truth_name = get_callable_name(truth_ast.func) if pred_name != truth_name: return False, f"Mismatch: expected '{truth_name}', got '{pred_name}'" # Get argument information pred_args = { kw.arg: kw.value for kw in pred_ast.keywords } truth_args = { kw.arg: kw.value for kw in truth_ast.keywords } # Check positional arguments if len(pred_ast.args) != len(truth_ast.args): return False, "Mismatched number of positional arguments" # Check keyword arguments if set(pred_args.keys()) != set(truth_args.keys()): return False, "Mismatched keyword argument names" # Check argument order for positional args for i, (p_arg, t_arg) in enumerate(zip(pred_ast.args, truth_ast.args)): if not self._compare_ast_nodes(p_arg, t_arg): return 
False, f"Positional argument {i+1} mismatch" # Check keyword argument values for arg_name, t_value in truth_args.items(): p_value = pred_args[arg_name] if not self._compare_ast_nodes(p_value, t_value): return False, f"Keyword argument '{arg_name}' value mismatch" return True, "Function call matches expected usage" def _compare_ast_nodes(self, node1: ast.AST, node2: ast.AST) -> bool: """ Helper method to compare two AST nodes. Args: node1: First AST node node2: Second AST node Returns: Boolean indicating if the nodes are equivalent """ # For literals (strings, numbers, booleans, None): ast.Constant replaces the deprecated ast.Str, ast.Num and ast.NameConstant aliases if isinstance(node1, ast.Constant): return isinstance(node2, ast.Constant) and node1.value == node2.value # For variable names if isinstance(node1, ast.Name) and isinstance(node2, ast.Name): return node1.id == node2.id # For attribute access (e.g., obj.attr) if isinstance(node1, ast.Attribute) and isinstance(node2, ast.Attribute): return node1.attr == node2.attr and self._compare_ast_nodes(node1.value, node2.value) # For lists/tuples if isinstance(node1, (ast.List, ast.Tuple)) and isinstance(node2, type(node1)): if len(node1.elts) != len(node2.elts): return False return all(self._compare_ast_nodes(e1, e2) for e1, e2 in zip(node1.elts, node2.elts)) return False class ClassExampleEvaluator(BaseExampleEvaluator): """ Evaluates the quality of class docstring examples by testing if they enable correct class instantiation prediction. """ def get_evaluation_prompt(self, context_code: str, signature: str, example: str) -> str: """ Generates a prompt for LLM to predict the class instantiation line.
Args: context_code: The code leading up to class instantiation signature: Combined class and __init__ signatures example: The example part of the docstring Returns: A formatted prompt string that can be sent to an LLM for prediction """ prompt = [ "Given the following context, predict ONLY the next line of code that creates a class instance.", "Your prediction should be based solely on the class signature and example provided.", "", "Class and __init__ signatures:", signature, "", "Example from docstring:", example, "", "Context code leading up to class instantiation:", context_code, "", "IMPORTANT INSTRUCTIONS:", "1. Predict ONLY the next line(s) that creates the class instance", "2. Base your prediction solely on the signatures and example", "3. Include ONLY the instantiation code, no additional explanation", "4. If the instantiation spans multiple lines, include all necessary lines", "5. Ensure the prediction is valid Python syntax", "", "Your prediction should be enclosed in tags", ] return "\n".join(prompt) def _compare_ast_nodes(self, node1: ast.AST, node2: ast.AST) -> bool: """ Example placeholder comparison method. You should implement your logic based on how you want to compare constant values, variable references, etc. """ if isinstance(node1, ast.Constant) and isinstance(node2, ast.Constant): return node1.value == node2.value # Extend your comparison logic here (e.g., for lists, dicts, names, etc.) return ast.dump(node1) == ast.dump(node2) def _get_func_name(self, node: Union[ast.Name, ast.Attribute]) -> str: """ Extract the function/class name whether it's `Name` or `Attribute`. - ast.Name: directly has `node.id` - ast.Attribute: the class name is in `node.attr` (e.g. 
`some_module.MyClass`) """ if isinstance(node, ast.Name): return node.id elif isinstance(node, ast.Attribute): return node.attr else: # If your code can handle more node types, add logic here raise ValueError(f"Unsupported node type for function/class: {type(node)}") def evaluate_prediction(self, prediction: str, ground_truth: str) -> Tuple[bool, str]: """ Evaluates if the predicted class instantiation matches the ground truth. Performs robust parsing of both prediction and ground truth to compare: 1. Class name 2. Constructor argument names and order 3. Argument values (when literals) Args: prediction: The LLM's predicted instantiation code ground_truth: The expected instantiation code Returns: Tuple containing: - Boolean indicating if the prediction is correct - String explaining why the prediction was correct or incorrect """ # Parse both prediction and ground truth into AST pred_ast = ast.parse(prediction.strip()).body[0].value truth_ast = ast.parse(ground_truth.strip()).body[0].value # Verify it's a class instantiation if not isinstance(pred_ast, ast.Call) or not isinstance(truth_ast, ast.Call): return False, "Not a valid class instantiation" # Safely extract the class name from both pred_func_name = self._get_func_name(pred_ast.func) truth_func_name = self._get_func_name(truth_ast.func) # Check class name if pred_func_name != truth_func_name: return False, f"Class name mismatch: expected {truth_func_name}, got {pred_func_name}" # Get argument information (keyword args) pred_args = {kw.arg: kw.value for kw in pred_ast.keywords} truth_args = {kw.arg: kw.value for kw in truth_ast.keywords} # Check positional arguments if len(pred_ast.args) != len(truth_ast.args): return False, "Mismatched number of positional arguments" # Check keyword arguments if set(pred_args.keys()) != set(truth_args.keys()): return False, "Mismatched keyword argument names" # Check argument order and values for positional args for i, (p_arg, t_arg) in enumerate(zip(pred_ast.args, 
truth_ast.args)): if not self._compare_ast_nodes(p_arg, t_arg): return False, f"Positional argument {i+1} mismatch" # Check keyword argument values for arg_name, t_value in truth_args.items(): p_value = pred_args[arg_name] if not self._compare_ast_nodes(p_value, t_value): return False, f"Keyword argument '{arg_name}' value mismatch" return True, "Class instantiation matches expected usage" def _compare_ast_nodes(self, node1: ast.AST, node2: ast.AST) -> bool: """Helper method to compare two AST nodes.""" # Reuse the same implementation as FunctionExampleEvaluator return FunctionExampleEvaluator._compare_ast_nodes(self, node1, node2) class MethodExampleEvaluator(BaseExampleEvaluator): """ Evaluates the quality of class method docstring examples by testing if they enable correct method call prediction. """ def get_evaluation_prompt(self, context_code: str, signature: str, example: str) -> str: """ Generates a prompt for LLM to predict the method call line. Args: context_code: The code leading up to method call signature: The method signature example: The example part of the docstring Returns: A formatted prompt string that can be sent to an LLM for prediction """ prompt = [ "Given the following context, predict ONLY the next line of code that calls the class method.", "Your prediction should be based solely on the method signature and example provided.", "", "Method signature:", "", signature, "", "", "Example from docstring:", "", example, "", "", "Context code leading up to method call:", "", context_code, "", "", "IMPORTANT INSTRUCTIONS:", "1. Predict ONLY the next line(s) that calls the method", "2. Base your prediction solely on the signature and example", "3. Include ONLY the method call, no additional explanation", "4. If the method call spans multiple lines, include all necessary lines", "5. 
Ensure the prediction is valid Python syntax", "", "Your prediction should be enclosed in tags", ] return "\n".join(prompt) def evaluate_prediction(self, prediction: str, ground_truth: str) -> Tuple[bool, str]: """ Evaluates if the predicted method call matches the ground truth. Performs robust parsing of both prediction and ground truth to compare: 1. Object and method names 2. Argument names and order 3. Argument values (when literals) Args: prediction: The LLM's predicted method call ground_truth: The expected method call Returns: Tuple containing: - Boolean indicating if the prediction is correct - String explaining why the prediction was correct or incorrect """ # Parse both prediction and ground truth into AST pred_ast = ast.parse(prediction.strip()).body[0].value truth_ast = ast.parse(ground_truth.strip()).body[0].value # Verify it's a method call if not isinstance(pred_ast, ast.Call) or not isinstance(truth_ast, ast.Call): return False, "Not a valid method call" # For method calls, we need to check both object and method names if not isinstance(pred_ast.func, ast.Attribute) or not isinstance(truth_ast.func, ast.Attribute): return False, "Not a valid method call (missing object reference)" # Check object name if not self._compare_ast_nodes(pred_ast.func.value, truth_ast.func.value): return False, "Object reference mismatch" # Check method name if pred_ast.func.attr != truth_ast.func.attr: return False, f"Method name mismatch: expected {truth_ast.func.attr}, got {pred_ast.func.attr}" # Get argument information pred_args = { kw.arg: kw.value for kw in pred_ast.keywords } truth_args = { kw.arg: kw.value for kw in truth_ast.keywords } # Check positional arguments if len(pred_ast.args) != len(truth_ast.args): return False, "Mismatched number of positional arguments" # Check keyword arguments if set(pred_args.keys()) != set(truth_args.keys()): return False, "Mismatched keyword argument names" # Check argument order for positional args for i, (p_arg, t_arg) in 
enumerate(zip(pred_ast.args, truth_ast.args)):
            if not self._compare_ast_nodes(p_arg, t_arg):
                return False, f"Positional argument {i+1} mismatch"

        # Check keyword argument values
        for arg_name, t_value in truth_args.items():
            p_value = pred_args[arg_name]
            if not self._compare_ast_nodes(p_value, t_value):
                return False, f"Keyword argument '{arg_name}' value mismatch"

        return True, "Method call matches expected usage"

    def _compare_ast_nodes(self, node1: ast.AST, node2: ast.AST) -> bool:
        """Helper method to compare two AST nodes."""
        # Reuse the same implementation as FunctionExampleEvaluator
        return FunctionExampleEvaluator._compare_ast_nodes(self, node1, node2)

================================================
FILE: src/evaluator/helpfulness_parameters.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Any, List, Optional, Tuple
import re
from dataclasses import dataclass
from enum import Enum

from src.evaluator.evaluation_common import ScoreLevel, ParameterEvaluationExample

class DocstringParametersEvaluator:
    """
    Evaluates the quality of Python docstring parameter descriptions using predefined criteria.

    This class assesses how well parameter descriptions in docstrings convey the purpose,
    constraints, and usage context of class initialization parameters, going beyond mere
    type information to provide meaningful guidance to users.
    """

    def __init__(self):
        """Initialize the evaluator with predefined criteria and examples."""
        self.criteria = self._initialize_criteria()
        self.examples = self._initialize_examples()

    def _initialize_criteria(self) -> Dict[str, Any]:
        """
        Set up the evaluation criteria for parameter descriptions.

        The criteria define five quality levels, from mere type repetition (1)
        to excellent usage guidance and context (5).

        Returns:
            Dict containing the evaluation criteria and descriptions for each score level.
        """
        return {
            'description': (
                'Evaluate how effectively the parameter descriptions convey the purpose, '
                'constraints, and usage context of class initialization parameters. '
                'High-quality descriptions should go beyond type information to provide '
                'meaningful guidance about parameter usage, valid values, and impact '
                'on class behavior.'
            ),
            'score_criteria': {
                ScoreLevel.POOR: (
                    'The parameter descriptions merely restate the parameter types or '
                    'convert the type hints to natural language without adding any '
                    'meaningful information about usage or purpose.'
                ),
                ScoreLevel.FAIR: (
                    'The descriptions provide basic information about parameter purpose '
                    'but lack details about constraints, valid values, or usage context. '
                    'They may use vague language or miss important details.'
                ),
                ScoreLevel.GOOD: (
                    'The descriptions explain parameter purpose and include some key '
                    'constraints or valid value ranges, but might miss edge cases or '
                    'lack examples where helpful.'
                ),
                ScoreLevel.VERY_GOOD: (
                    'The descriptions clearly explain purpose, constraints, and common '
                    'usage patterns. They may include examples for complex parameters '
                    'and note important edge cases or default behaviors.'
                ),
                ScoreLevel.EXCELLENT: (
                    'The descriptions provide comprehensive guidance including purpose, '
                    'constraints, examples, edge cases, and impact on class behavior, '
                    'while remaining concise and focused on the most important information. '
                    'They help users make informed decisions about parameter values.'
                )
            }
        }

    def _initialize_examples(self) -> List[ParameterEvaluationExample]:
        """
        Set up concrete examples of parameter descriptions at different quality levels.

        Each example includes class and __init__ signatures with corresponding
        parameter descriptions at different quality levels, along with explanations
        of the ratings.

        Returns:
            List of ParameterEvaluationExample objects containing the example cases.
""" return [ ParameterEvaluationExample( parameters={ "Model_entity_id": "Numeric identifier for the model entity", "Dist_pg": "Distributed process group for coordination", "Checkpoint_config": "Defines checkpoint saving intervals and retention", "Runtime_config": "Specifies resource or environmental constraints", "Train_module": "Orchestrates training steps and interfaces with checkpoints" }, quality_examples={ ScoreLevel.POOR: { "Model_entity_id": "the model entity ID", "Dist_pg": "The Process group", "Checkpoint_config": "The checkpoint Configuration", "Runtime_config": "The Runtime configuration", "Train_module": "The Training module" }, ScoreLevel.FAIR: { "Model_entity_id": "A number that identifies the model", "Dist_pg": "Process group for distributed operations", "Checkpoint_config": "Settings for checkpoint management", "Runtime_config": "Configuration for runtime behavior", "Train_module": "Module that manages the training process" }, ScoreLevel.GOOD: { "Model_entity_id": "identifier for the model entity.", "Dist_pg": "PyTorch distributed process group that handles communication between processes", "Checkpoint_config": "Configuration that determines when checkpoints are saved and how many are kept", "Runtime_config": "Specifies runtime parameters like memory limits and timeout settings", "Train_module": "Module that implements training logic and interacts with the checkpoint system" }, ScoreLevel.VERY_GOOD: { "Model_entity_id": "Unique numeric identifier for the model entity in the registry. Must be a valid registered model ID", "Dist_pg": "PyTorch distributed process group that coordinates operations across GPUs/nodes during training. Should match your distributed setup", "Checkpoint_config": "Controls checkpoint frequency, storage locations, and retention policies. Important for balancing disk usage with recovery capabilities", "Runtime_config": "Defines resource constraints and operational parameters. 
Must be configured appropriately for your hardware to avoid performance issues",
                        "Train_module": "Orchestrates the training workflow, manages state transitions, and defines what model components get checkpointed"
                    },
                    ScoreLevel.EXCELLENT: {
                        "Model_entity_id": "Unique integer ID for the model entity (e.g., 1014925). Should always be a 7-digit number. Must exist in the model registry before checkpointing; otherwise it will hit CheckpointNotFoundError and fail to load the checkpoint.",
                        "Dist_pg": "Distributed process group that handles collective operations for multi-GPU or multi-node setups. This setup must be consistent with the training configuration 'distributed_training_config'.",
                        "Checkpoint_config": "Specifies saving intervals, naming formats, and retention. Supports advanced features like asynchronous checkpointing. See examples in 'https://fb.workplace.com/groups/652446422242/preview'.",
                        "Runtime_config": "Contains environment constraints (e.g., memory, disk I/O) and concurrency policies. Ensures checkpointing does not stall training under restricted resources; otherwise it will hit CheckpointAccessError and fail to load the checkpoint.",
                        "Train_module": "Manages end-to-end training flow, triggers checkpoint saving at appropriate intervals, and provides context on what states/parameters to store."
}, }, explanations={ ScoreLevel.POOR: "Descriptions recite minimal type info, lacking usage or constraints", ScoreLevel.FAIR: "Provides a basic sense of the purpose for each parameter, but lacks detail", ScoreLevel.GOOD: "Covers core constraints and a bit of context, but some usage details are still missing", ScoreLevel.VERY_GOOD: "Explains relevant usage patterns, constraints, and environment needs", ScoreLevel.EXCELLENT: "Comprehensive coverage including resource impact, advanced usage scenarios, and constraints" } ) ] def get_evaluation_prompt(self, code_component: str, docstring: str, eval_type: str = None) -> str: """ Generates a prompt for LLM evaluation of parameter descriptions. Args: code_component: The code implementation (class or function/method) docstring: The docstring to evaluate eval_type: The type of code component (class, function, method). If not provided, it will be determined from code_component. Returns: Prompt for LLM evaluation """ # Determine eval_type if not provided if eval_type is None: if code_component.strip().startswith("class "): eval_type = "class" else: eval_type = "function" if "self" not in code_component.split("(")[0] else "method" assert eval_type in ["class", "function", "method"], "eval_type must be one of 'class', 'function', or 'method'" example = self.examples[0] # Use first example as reference # system prompt prompt = [ "Please evaluate the parameter description section for a docstring of a " + eval_type + " based on these criteria:"] # second part, the evaluation criteria prompt.extend([ "", "", "Evaluation criteria:", self.criteria['description'], "", "Score levels:", ]) # Add criteria for each score level for level in ScoreLevel: prompt.append(f"{level.value}. 
{self.criteria['score_criteria'][level]}")
        prompt.append("")

        # Add example
        prompt.extend([
            "",
            "",
            "Parameter descriptions at different quality levels:",
        ])

        for level in ScoreLevel:
            prompt.extend([
                f"Level {level.value}:",
                *[f"{param}: {desc}" for param, desc in example.quality_examples[level].items()],
                f"Explanation: {example.explanations[level]}",
                ""
            ])

        prompt.append("")

        # add focal code component and docstring
        prompt.extend([
            "",
            "",
            f"{code_component}",
            "",
            "",
            "",
            "Parameter descriptions to evaluate:",
            f"{docstring}",
            ""
        ])

        prompt.extend([
            "",
            "",
            "IMPORTANT INSTRUCTIONS FOR ANALYSIS:",
            "1. Analyze how well each parameter description provides meaningful information beyond type hints",
            "2. Consider completeness of constraint and valid value documentation",
            "3. Look for helpful context about parameter impact on code component's behavior",
            "4. Check for clear examples or guidance where appropriate",
            "",
            "",
            "",
            "Please structure your response as follows:",
            "1. Compare against the criteria and example quality levels",
            "2. Suggest specific improvements for weaker descriptions. Include your suggestions in <suggestions></suggestions> tags. No need to provide suggestions for excellent descriptions.",
            "3. Provide your score (1-5) enclosed in <score></score> tags",
            "",
            "",
            "Remember: Do not rush to assign a score. Take time to analyze thoroughly and justify your reasoning.",
            "The score should reflect your careful analysis and should be the last part of your response.",
        ])

        return "\n".join(prompt)

    def parse_llm_response(self, response: str) -> Tuple[int, str]:
        """
        Extracts the numerical score and suggestions from an LLM's response.

        Args:
            response: The complete response text from the LLM.

        Returns:
            A tuple containing:
            - The numerical score (1-5); defaults to 3 if no valid score is found
            - The suggestions for improvement; a generic suggestion if none is found
        """
        # Extract score, trying the XML tags first and looser formats afterwards
        score_patterns = [
            r'<score>(\d)</score>',  # XML tags
            r'score:\s*(\d)',  # Common format
            r'score\s*=\s*(\d)',  # Alternative format
            r'(\d)\s*/\s*5',  # Rating format
        ]

        # Try each pattern; stop at the first one that yields a score in range
        for pattern in score_patterns:
            score_matches = re.findall(pattern, response, re.IGNORECASE)
            if score_matches:
                score = int(score_matches[0])
                if 1 <= score <= 5:
                    break
        else:
            # If no valid score found, use a default
            score = 3

        # Extract suggestions - look for several common patterns
        suggestion_patterns = [
            r'<suggestions>(.*?)</suggestions>',  # XML tags
            r'suggestions?:\s*(.+?)(?:\n\n|\Z)',  # Common format
            r'improve?:?\s*(.+?)(?:\n\n|\Z)',  # Alternative format
        ]

        # Try each pattern
        for pattern in suggestion_patterns:
            suggestion_matches = re.findall(pattern, response, re.DOTALL | re.IGNORECASE)
            if suggestion_matches:
                suggestion = suggestion_matches[0].strip()
                break
        else:
            # Try to find any text that looks like suggestions
            lines = response.split('\n')
            for i, line in enumerate(lines):
                if "suggest" in line.lower() and i < len(lines) - 1:
                    suggestion = lines[i+1].strip()
                    break
            else:
                suggestion = "Consider adding more detailed parameter descriptions."

        return score, suggestion

    def get_criteria_description(self) -> str:
        """Returns the main criteria description."""
        return self.criteria['description']

    def get_score_criteria(self, level: ScoreLevel) -> str:
        """Returns the criteria description for a specific score level."""
        return self.criteria['score_criteria'][level]

    def get_examples(self) -> List[ParameterEvaluationExample]:
        """Returns all evaluation examples."""
        return self.examples

================================================
FILE: src/evaluator/helpfulness_summary.py
================================================
# Copyright (c) Meta Platforms, Inc.
and affiliates from typing import Dict, Any, List, Optional, Tuple import re from dataclasses import dataclass from enum import Enum from src.evaluator.evaluation_common import ScoreLevel, SummaryEvaluationExample class DocstringSummaryEvaluator: """ Evaluates the quality of Python docstring summaries using predefined criteria and examples. This class provides a structured way to assess how well a docstring's summary line conveys the purpose and value of a function or class. It includes detailed criteria for different quality levels and concrete examples to guide the evaluation process. """ def __init__(self): """Initialize the evaluator with predefined criteria and examples.""" self.criteria = self._initialize_criteria() self.examples = self._initialize_examples() def _initialize_criteria(self) -> Dict[str, Any]: """ Set up the evaluation criteria for docstring summaries. The criteria define five quality levels, from mere signature repetition (1) to excellent context and purpose explanation (5). Returns: Dict containing the evaluation criteria and descriptions for each score level. """ return { 'description': ( 'Evaluate how effectively the one-line summary conveys ' 'the purpose and value of the function/class while providing additional ' 'context beyond what is apparent from the signature. A high-quality ' 'summary should be concise yet informative, avoiding mere signature ' 'repetition while adding meaningful context about the "why" or ' 'higher-level purpose.' ), 'score_criteria': { ScoreLevel.POOR: ( 'The summary merely restates the function signature in natural ' 'language or is completely unrelated to the function purpose. ' 'The summary provides no additional information beyond what is ' 'already obvious from the function name and parameters.' ), ScoreLevel.FAIR: ( 'The summary provides minimal information beyond the signature, ' 'perhaps adding one minor detail but still failing to convey ' 'meaningful context or purpose. 
It may use vague or overly ' 'technical language that doesn\'t help understanding.' ), ScoreLevel.GOOD: ( 'The summary provides some useful context beyond the signature, ' 'touching on either the "why" or a key use case, but could be ' 'more specific or comprehensive. It gives readers a general idea ' 'but may leave out important context.' ), ScoreLevel.VERY_GOOD: ( 'The summary effectively communicates both what the function does ' 'and its higher-level purpose, using clear language that helps ' 'readers understand when/why to use it. It avoids technical ' 'jargon unless necessary.' ), ScoreLevel.EXCELLENT: ( 'The summary excellently balances conciseness with informativeness, ' 'clearly conveying the function\'s purpose, value, and context in ' 'business/practical terms. It helps readers immediately understand ' 'both what the function does and why it matters.' ) } } def _initialize_examples(self) -> List[SummaryEvaluationExample]: """ Set up concrete examples of docstring summaries at different quality levels. Each example includes a function signature and corresponding summaries at different quality levels, along with explanations of the ratings. Returns: List of SummaryEvaluationExample objects containing the example cases. """ return [ SummaryEvaluationExample( function_signature=( "def calculate_user_metrics(user_id: str, start_date: datetime, " "end_date: datetime) -> Dict[str, float]" ), summaries={ ScoreLevel.POOR: "Calculates metrics for a user between two dates.", ScoreLevel.FAIR: "Processes user metrics data through various calculation methods.", ScoreLevel.GOOD: "Analyzes user engagement patterns by computing daily interaction statistics.", ScoreLevel.VERY_GOOD: ( "Generates user engagement insights for quarterly reporting by " "processing daily interaction metrics." ), ScoreLevel.EXCELLENT: ( "Identifies at-risk users by analyzing engagement patterns " "against historical churn indicators." 
) }, explanations={ ScoreLevel.POOR: "This summary merely converts the function signature into a sentence, providing no additional value.", ScoreLevel.FAIR: "While this adds slightly more information than the signature, it remains vague and unhelpful.", ScoreLevel.GOOD: ( "This provides some context about the purpose (engagement analysis) " "but could be more specific about why we track this." ), ScoreLevel.VERY_GOOD: ( "This effectively communicates both what it does and why " "(quarterly reporting), giving clear context for its use." ), ScoreLevel.EXCELLENT: ( "This excellently conveys both the technical function and its " "business purpose (preventing churn) in a clear, meaningful way." ) } ), SummaryEvaluationExample( function_signature=( "class DatasetLoader:" ), summaries={ ScoreLevel.POOR: "A class that loads datasets.", ScoreLevel.FAIR: "Handles loading of data from various sources.", ScoreLevel.GOOD: "Provides unified interface for loading and validating datasets from multiple sources.", ScoreLevel.VERY_GOOD: ( "Streamlines dataset ingestion by providing a consistent interface " "for loading and validating data from diverse sources." ), ScoreLevel.EXCELLENT: ( "Ensures data quality and consistency by providing a unified interface " "for loading, validating, and preprocessing datasets across multiple " "formats and sources while handling common edge cases." ) }, explanations={ ScoreLevel.POOR: "Simply restates the class name without adding value.", ScoreLevel.FAIR: "Adds minimal information, remains vague about capabilities.", ScoreLevel.GOOD: ( "Provides context about key functionality but could better explain " "benefits and use cases." ), ScoreLevel.VERY_GOOD: ( "Clearly communicates purpose and value while highlighting key " "features and benefits." ), ScoreLevel.EXCELLENT: ( "Excellently balances technical capabilities with practical benefits, " "while highlighting key differentiators and value proposition." 
) } ) ] def get_evaluation_prompt(self, code_component: str, docstring: str, eval_type: str = None) -> str: """ Generates a prompt for LLM evaluation of docstring summaries. Args: code_component: The code implementation (class or function/method) docstring: The docstring to evaluate eval_type: The type of code component (class, function, method). If not provided, it will be determined from code_component. Returns: Prompt for LLM evaluation """ # Determine eval_type if not provided if eval_type is None: if code_component.strip().startswith("class "): eval_type = "class" else: eval_type = "function" if "self" not in code_component.split("(")[0] else "method" # Determine if input is a class or function signature is_class = eval_type == "class" # Select relevant example based on signature type relevant_example = next( example for example in self.examples if (example.function_signature.startswith('class') == is_class) ) prompt = [ "Please evaluate the summary part of a docstring of a " + eval_type + " based on these criteria:", ] # Add criteria for each score level for level in ScoreLevel: prompt.append(f"{level.value}. {self.criteria['score_criteria'][level]}") prompt.append("") # Add single relevant example prompt.extend([ "", "", "Summaries at different levels:", ]) for level in ScoreLevel: prompt.extend([ f"Level {level.value}: {relevant_example.summaries[level]}", f"Explanation: {relevant_example.explanations[level]}", "" ]) prompt.append("") # add the code component and the docstring prompt.extend([ "", "", f"{code_component}", "", ]) prompt.extend([ "", "", f"{docstring}", "", ]) prompt.extend([ "", "", "IMPORTANT INSTRUCTIONS FOR ANALYSIS:", "1. Take your time to analyze the relationship between the focal code component and the summary part of the docstring.", "2. Consider how much additional context and value the summary provides beyond the signature.", "3. Compare the summary against each score level's criteria methodically.", "4. 
Look for similarities with the provided example at each quality level.",
            "",
            "",
            "",
            "Please structure your response as follows:",
            "1. First explain your reasoning by comparing against the criteria",
            "2. If applicable, suggest specific improvements. Include your suggestions in <suggestions></suggestions> tags. No need to provide suggestions for excellent summaries.",
            "3. Finally, provide your score (1-5) enclosed in <score></score> tags",
            "",
            "",
            "Remember: Do not rush to assign a score. Take time to analyze thoroughly and justify your reasoning.",
            "The score should reflect your careful analysis and should be the last part of your response."
        ])

        return "\n".join(prompt)

    def parse_llm_response(self, response: str) -> Tuple[int, str]:
        """
        Extracts the numerical score and suggestions from an LLM's response.

        Args:
            response: The complete response text from the LLM.

        Returns:
            A tuple containing:
            - The numerical score (1-5); defaults to 3 if no valid score is found
            - The suggestions for improvement; a generic suggestion if none is found
        """
        # Extract score from various patterns
        score_patterns = [
            r'<score>(\d)</score>',  # XML tags
            r'score:\s*(\d)',  # Common format
            r'score\s*=\s*(\d)',  # Alternative format
            r'(\d)\s*/\s*5',  # Rating format
            r'level\s*(\d)',  # Level references
        ]

        # Try each pattern; stop at the first one that yields a score in range
        for pattern in score_patterns:
            score_matches = re.findall(pattern, response, re.IGNORECASE)
            if score_matches:
                score = int(score_matches[0])
                if 1 <= score <= 5:
                    break
        else:
            # If no valid score found, default to 3
            score = 3

        # Extract suggestions - look for several common patterns
        suggestion_patterns = [
            r'<suggestions>(.*?)</suggestions>',  # XML tags
            r'suggestions?:\s*(.+?)(?:\n\n|\Z)',  # Common format
            r'could be improved by:?\s*(.+?)(?:\n\n|\Z)',  # Alternative phrasing
            r'improvement:?\s*(.+?)(?:\n\n|\Z)',  # Another alternative
        ]

        # Try each pattern
        for pattern in suggestion_patterns:
            suggestion_matches = re.findall(pattern, response, re.DOTALL | re.IGNORECASE)
            if suggestion_matches:
                suggestion = suggestion_matches[0].strip()
                break
        else:
            # If we can't find a suggestion, extract sentences that seem like suggestions
            suggestion_sentences = []
            for sentence in re.split(r'[.!?]\s+', response):
                if any(word in sentence.lower() for word in ['could', 'should', 'might', 'consider', 'suggest', 'improve', 'better']):
                    suggestion_sentences.append(sentence.strip())

            if suggestion_sentences:
                suggestion = ' '.join(suggestion_sentences) + '.'
            else:
                # Default suggestion
                suggestion = "Consider adding more context and purpose to the summary."

        return score, suggestion

    def get_criteria_description(self) -> str:
        """Returns the main criteria description."""
        return self.criteria['description']

    def get_score_criteria(self, level: ScoreLevel) -> str:
        """Returns the criteria description for a specific score level."""
        return self.criteria['score_criteria'][level]

    def get_examples(self) -> List[SummaryEvaluationExample]:
        """Returns all evaluation examples."""
        return self.examples

================================================
FILE: src/evaluator/segment.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates
import re

def parse_google_style_docstring(docstring):
    """
    A robust parser for Google-style docstrings that have multiple possible
    labels for each section.

    For example, any of the labels in SECTION_LABELS["examples"] indicates
    the start of the "examples" section.
    """
    # Define all recognized sections. The key is the canonical name (lowercase).
    # The value is a set of synonyms (also lowercase).
SECTION_LABELS = { "summary": {"summary:", "short description:", "brief:", "overview:"}, "description": {"description:", "desc:", "details:", "detailed description:", "long description:"}, "parameters": {"parameters:", "params:", "args:", "arguments:", "keyword args:", "keyword arguments:", "**kwargs:"}, "attributes": {"attributes:", "members:", "member variables:", "instance variables:", "properties:", "vars:", "variables:"}, "returns": {"returns:", "return:", "return value:", "return values:"}, "raises": {"raises:", "exceptions:", "throws:", "raise:", "exception:", "throw:"}, "examples": {"example:", "examples:", "usage:", "usage example:", "usage examples:", "example usage:"}, } # Prepare a dictionary to hold the parsed content for each canonical key parsed_content = {key: [] for key in SECTION_LABELS.keys()} # Split by lines; if docstring uses Windows line endings, .splitlines() handles that gracefully lines = docstring.strip().splitlines() # -- 1) Fallback: no explicit sections at all in the entire docstring -- # If no recognized label appears anywhere, treat the first line as summary, rest as description. has_section_labels = False for line in lines: line_lower = line.strip().lower() for labels in SECTION_LABELS.values(): for label in labels: if line_lower.startswith(label): has_section_labels = True break if has_section_labels: break if has_section_labels: break if len(lines) > 0 and not has_section_labels: parsed_content["summary"] = [lines[0]] if len(lines) > 1: parsed_content["description"] = lines[1:] # Convert lists to single strings return {key: "\n".join(value).strip() for key, value in parsed_content.items()} # -- 2) Partial Fallback for the first line only -- # If the first line doesn't match any known label, treat it as summary and then # switch to "description" until an explicit label is found. 
current_section = None # keep track of which section we're in first_line = lines[0].strip().lower() if lines else "" if not any(first_line.startswith(label) for labels in SECTION_LABELS.values() for label in labels): if lines: # Save first line as summary parsed_content["summary"] = [lines[0]] # Make the current section "description" current_section = "description" lines = lines[1:] # We'll handle the rest below for line in lines: # We'll do a trimmed, lowercase version of the line to check for a header # but keep original_line if you want to preserve original indentation or case. trimmed_line = line.strip().lower() # Check if the trimmed line (minus trailing colon, if present) matches a known section # We'll also handle any trailing colon, extra spaces, etc. # e.g. " Parameters: " -> "parameters:" # We only match a line if it starts exactly with that label. # If you want more flexible matching (like partial lines), you can adapt this. matched_section = None for canonical_name, synonyms in SECTION_LABELS.items(): # Each synonym might be "parameters:", "args:", etc. # We'll see if the trimmed_line starts exactly with one of them. for synonym in synonyms: # If line starts with the synonym, we treat it as a new section. 
# Example: "PARAMETERS:" -> synonyms might contain "parameters:" in lowercase if trimmed_line.startswith(synonym): matched_section = canonical_name # Extract leftover text on the same line, after the label leftover = line.strip()[len(synonym):].strip() if leftover: parsed_content[matched_section].append(leftover) break if matched_section: break # If matched_section is not None, we found a new section header if matched_section is not None: # Switch to that section current_section = matched_section # No need to append the header line to content - we've already handled any content after the label else: # Otherwise, accumulate this line under the current section if we have one if current_section is not None: parsed_content[current_section].append(line) # Convert list of lines to a single string for each section, # with consistent line breaks, and strip extra whitespace for section in parsed_content: parsed_content[section] = "\n".join(parsed_content[section]).strip() return parsed_content # ------------------------------ Example Usage ------------------------------ if __name__ == "__main__": sample_docstring = """ Summary: Provides a utility for processing and managing data through a structured workflow. Description: This class is designed to facilitate data processing tasks by integrating with the `DataProcessor` class. It retrieves and manipulates data. Parameters: param1: This is the first parameter. param2: This is the second parameter. Attributes: data: Stores the current data. Example: ```python helper = HelperClass() helper.process_data() print(helper.data) ``` """ result = parse_google_style_docstring(sample_docstring) # Print out each section for section_name, content in result.items(): print("SECTION:", section_name.upper()) print("CONTENT:\n", content) print("-" * 40) ================================================ FILE: src/evaluator/truthfulness.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates
import ast
import json
import os
import re
import sys
from typing import List, Dict, Any, Set, Tuple
import google.generativeai as genai
from tqdm import tqdm
import pandas as pd
from collections import defaultdict

# Constants
SYSTEMS = [
    "copy_paste_codellama34b",
    "copy_paste_gpt4o_mini",
    "docassist-codellama34b",
    "docassist-gpt4o_mini",
    "fim-codellama13b",
]

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY is not set")

# Configure Gemini API
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-2.0-flash")

def extract_components_from_docstring(docstring: str) -> List[str]:
    """
    Extract code components (classes, methods, functions) mentioned in a docstring using Gemini API.

    Args:
        docstring: The docstring text to analyze

    Returns:
        List of code component names mentioned in the docstring
    """
    prompt = f"""
    Please extract all the non-common (very likely to be newly-defined in the repository) code components (classes, methods, functions) mentioned in the following docstring.
    Ignore the example part of the docstring if it exists (the code component you extract should not come from the example code).
    For example, "List" is a very common class, so it should not be included. On the other hand, "InMemoryCache" is not a common class, so it should be included.
    Return only a Python list of strings with the exact names. If no code components are mentioned, return an empty list.

    Docstring:
    ```
    {docstring}
    ```

    Format your response as a Python list wrapped in XML tags like this:
    <components>["ClassA", "method_b", "function_c"]</components>
    """

    try:
        response = model.generate_content(prompt)
        response_text = response.text.strip()

        # Extract list from XML tags (the tag name must match the one requested in the prompt)
        match = re.search(r'<components>(.*?)</components>', response_text, re.DOTALL)
        if match:
            list_str = match.group(1)
            try:
                # Parse the list as a Python literal; never execute model output
                components = ast.literal_eval(list_str)
                if isinstance(components, list):
                    return components
            except (ValueError, SyntaxError):
                # If parsing fails, extract strings manually
                components = re.findall(r'"([^"]*)"', list_str)
                return components

        # Fallback: try to extract using regex for regular list
        match = re.search(r'\[.*?\]', response_text, re.DOTALL)
        if match:
            list_str = match.group(0)
            try:
                # Parse the list as a Python literal; never execute model output
                components = ast.literal_eval(list_str)
                if isinstance(components, list):
                    return components
            except (ValueError, SyntaxError):
                # If parsing fails, extract strings manually
                components = re.findall(r'"([^"]*)"', list_str)
                return components

        # Fallback: try to find any mention of code looking elements
        components = re.findall(r'`([^`]+)`', docstring)
        return [c for c in components if not c.startswith('(') and not c.endswith(')')]
    except Exception as e:
        print(f"Error calling Gemini API: {e}")
        # Fallback: try to find any mention of code looking elements
        components = re.findall(r'`([^`]+)`', docstring)
        return [c for c in components if not c.startswith('(') and not c.endswith(')')]

def load_dependency_graph(repo_name: str) -> Dict[str, Any]:
    """
    Load the dependency graph for a given repository.
    Args:
        repo_name: Repository name

    Returns:
        Dependency graph data
    """
    file_path = f"output/dependency_graphs/{repo_name}_dependency_graph.json"
    try:
        with open(file_path, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Dependency graph not found: {file_path}")
        return {}

def check_component_existence(
    component_name: str,
    dependency_graph: Dict[str, Any],
    docstring_path: str
) -> Tuple[bool, bool]:
    """
    Check if a component exists in the dependency graph and if it's a cross-file reference.

    Args:
        component_name: Name of the component to check
        dependency_graph: Dependency graph data
        docstring_path: Path of the docstring's component

    Returns:
        Tuple of (exists, is_cross_file)
    """
    exists = False
    is_cross_file = False

    docstring_relative_path = None
    if "/" in docstring_path:
        # Extract the relative path from the docstring path
        # (drop the first segment and the trailing component name)
        parts = docstring_path.split("/")
        docstring_relative_path = "/".join(parts[1:-1])

    for comp_id, comp_data in dependency_graph.items():
        # Check if the component name is in the last segment of the ID
        if component_name in comp_id.split(".")[-1]:
            exists = True
            # Check if it's a cross-file reference
            if docstring_relative_path and "relative_path" in comp_data:
                comp_relative_path = comp_data["relative_path"]
                if docstring_relative_path != comp_relative_path:
                    is_cross_file = True
            break

    return exists, is_cross_file

def main():
    # Load completeness evaluation data
    print("Loading completeness evaluation data...")
    with open("experiments/eval/results/completeness_evaluation_cleaned.json", 'r') as f:
        completeness_data = json.load(f)

    results = {}
    # Cache of loaded graphs, so each component is checked against its own repo's graph
    dependency_graphs = {}

    # Process each component in the completeness data
    for component_path, component_data in tqdm(completeness_data.items()):
        if "docstrings" not in component_data:
            continue

        # Extract repo name
        parts = component_path.split("/")
        repo_name = parts[1]
        # Replace every "-" in the repo name with "_"
        repo_name = repo_name.replace("-", "_")

        # Load dependency graph for this repo (once)
        if repo_name not in dependency_graphs:
            print(f"Loading dependency graph for {repo_name}...")
            dependency_graphs[repo_name] = load_dependency_graph(repo_name)
            results[repo_name] = {}
        dependency_graph = dependency_graphs[repo_name]

        # For each system, analyze the docstring
        for system in SYSTEMS:
            if system not in component_data["docstrings"]:
                continue

            docstring = component_data["docstrings"][system]["docstring"]

            # Extract mentioned components from docstring
            components = extract_components_from_docstring(docstring)

            # Check existence of each component in the dependency graph
            component_results = []
            for comp in components:
                exists, is_cross_file = check_component_existence(
                    comp, dependency_graph, component_path
                )
                component_results.append({
                    "name": comp,
                    "exists": exists,
                    "is_cross_file": is_cross_file
                })

            # Store results
            if component_path not in results[repo_name]:
                results[repo_name][component_path] = {}

            results[repo_name][component_path][system] = {
                "mentioned_components": component_results,
                "total_mentions": len(components),
                "existing_mentions": sum(1 for c in component_results if c["exists"]),
                "cross_file_mentions": sum(1 for c in component_results if c["is_cross_file"])
            }

    # Save detailed results
    with open("experiments/eval/results/docstring_truthfulness_evaluation.json", 'w') as f:
        json.dump(results, f, indent=2)

    # Generate summary report
    generate_summary_report(results)

def generate_summary_report(results: Dict[str, Dict[str, Dict[str, Any]]]):
    """
    Generate a summary report comparing the five systems.
Args: results: The evaluation results """ # Aggregate statistics stats = { system: { "total_components_mentioned": 0, "existing_components": 0, "cross_file_mentions": 0, "docstrings_analyzed": 0 } for system in SYSTEMS } # Calculate statistics for repo_name, repo_data in results.items(): for component_path, comp_data in repo_data.items(): for system, system_data in comp_data.items(): if system in SYSTEMS: stats[system]["total_components_mentioned"] += system_data["total_mentions"] stats[system]["existing_components"] += system_data["existing_mentions"] stats[system]["cross_file_mentions"] += system_data["cross_file_mentions"] stats[system]["docstrings_analyzed"] += 1 # Calculate ratios for system in SYSTEMS: total = stats[system]["total_components_mentioned"] if total > 0: stats[system]["existence_ratio"] = stats[system]["existing_components"] / total else: stats[system]["existence_ratio"] = 0 if stats[system]["existing_components"] > 0: stats[system]["cross_file_ratio"] = stats[system]["cross_file_mentions"] / stats[system]["existing_components"] else: stats[system]["cross_file_ratio"] = 0 if stats[system]["docstrings_analyzed"] > 0: stats[system]["avg_mentions_per_doc"] = total / stats[system]["docstrings_analyzed"] else: stats[system]["avg_mentions_per_doc"] = 0 # Create markdown report report = "# Docstring Truthfulness Evaluation Report\n\n" # Table 1: Component Existence report += "## Component Existence Ratio (higher is better)\n\n" report += "| System | Components Mentioned | Existing Components | Existence Ratio |\n" report += "|--------|---------------------|---------------------|-----------------|\n" for system in SYSTEMS: report += f"| {system} | {stats[system]['total_components_mentioned']} | {stats[system]['existing_components']} | {stats[system]['existence_ratio']:.2%} |\n" # Table 2: Component Mentions report += "\n## Component Mention Frequency (higher is better)\n\n" report += "| System | Docstrings Analyzed | Total Components | Avg Mentions Per 
Doc |\n" report += "|--------|---------------------|------------------|-----------------------|\n" for system in SYSTEMS: report += f"| {system} | {stats[system]['docstrings_analyzed']} | {stats[system]['total_components_mentioned']} | {stats[system]['avg_mentions_per_doc']:.2f} |\n" # Table 3: Cross-file References report += "\n## Cross-file References (higher is better)\n\n" report += "| System | Existing Components | Cross-file References | Cross-file Ratio |\n" report += "|--------|---------------------|----------------------|-----------------|\n" for system in SYSTEMS: report += f"| {system} | {stats[system]['existing_components']} | {stats[system]['cross_file_mentions']} | {stats[system]['cross_file_ratio']:.2%} |\n" # Save the report with open("experiments/eval/results/docstring_truthfulness_report.md", 'w') as f: f.write(report) print("Summary report generated: docstring_truthfulness_report.md") if __name__ == "__main__": main() ================================================ FILE: src/visualizer/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates from .status import StatusVisualizer from .progress import ProgressVisualizer __all__ = ['StatusVisualizer', 'ProgressVisualizer'] ================================================ FILE: src/visualizer/progress.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Terminal-based progress visualization for docstring generation. This module provides a class for visualizing the progress of generating docstrings using a topologically sorted dependency graph. 
""" import sys import time import os from typing import Dict, List, Set, Optional from colorama import Fore, Back, Style, init from tqdm import tqdm class ProgressVisualizer: """Visualizes the progress of docstring generation in the terminal.""" def __init__(self, components: Dict[str, any], sorted_order: List[str]): """ Initialize the progress visualizer. Args: components: Dictionary of code components sorted_order: List of component IDs in topological order """ init() # Initialize colorama self.components = components self.sorted_order = sorted_order self.processed = set() # Set of processed component IDs self.current = None # Current component being processed self.progress_bar = None self.start_time = time.time() def initialize(self): """Initialize the visualization and show the initial state.""" self._clear_screen() self._print_header() # Create progress bar self.progress_bar = tqdm( total=len(self.sorted_order), desc="Generating docstrings", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" ) # Print initial component status self._print_component_status() def update(self, component_id: str = None, status: str = "processing"): """ Update the visualization with the current component status. 
Args: component_id: ID of the component being processed (or None) status: Status of the component ('processing', 'completed', or 'error') """ if component_id is not None: self.current = component_id if status == "completed": self.processed.add(component_id) self.progress_bar.update(1) # Update the visualization self._print_component_status() def finalize(self): """Finalize the visualization and show summary statistics.""" if self.progress_bar: self.progress_bar.close() # Calculate elapsed time elapsed = time.time() - self.start_time minutes, seconds = divmod(elapsed, 60) hours, minutes = divmod(minutes, 60) self._clear_screen() self._print_header() # Print summary print(f"\n{Fore.GREEN}Docstring Generation Complete!{Style.RESET_ALL}") print(f"Total components processed: {len(self.processed)}/{len(self.sorted_order)}") print(f"Time elapsed: {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}") print("\nComponents by type:") # Count components by type type_counts = {"function": 0, "method": 0, "class": 0} for comp_id in self.processed: comp_type = self.components[comp_id].component_type type_counts[comp_type] += 1 for comp_type, count in type_counts.items(): print(f" {comp_type.capitalize()}: {count}") print("\nGeneration complete. 
Results saved to repository files.") def _clear_screen(self): """Clear the terminal screen.""" sys.stdout.write("\033[2J\033[H") sys.stdout.flush() def _print_header(self): """Print the header with title and information.""" title = "Topological Docstring Generator" print(f"\n{Fore.CYAN}{Style.BRIGHT}{title}{Style.RESET_ALL}\n") print(f"Generating docstrings for {len(self.sorted_order)} code components in dependency order") print(f"Components will be processed in topological order to ensure all dependencies") print(f"have docstrings before dependent components.") def _print_component_status(self): """Print the current status of components in the dependency graph.""" if not self.current: return # Get the current component and its info current_comp = self.components.get(self.current) if not current_comp: return # Print current component information comp_type = current_comp.component_type.capitalize() file_path = current_comp.relative_path # Create a simplified name for display parts = self.current.split('.') if len(parts) > 2 and current_comp.component_type == "method": # For methods, show Class.method name = f"{parts[-2]}.{parts[-1]}" else: # For functions and classes, show just the name name = parts[-1] # Print status line print(f"\n{Fore.YELLOW}Currently processing: {Style.RESET_ALL}{comp_type} '{name}' in {file_path}") # Print dependency information if current_comp.depends_on: deps = [dep_id for dep_id in current_comp.depends_on if dep_id in self.components] if deps: print(f"{Fore.CYAN}Dependencies:{Style.RESET_ALL}") for dep_id in deps: dep = self.components.get(dep_id) if not dep: continue # Format the dependency name similarly parts = dep_id.split('.') if len(parts) > 2 and dep.component_type == "method": dep_name = f"{parts[-2]}.{parts[-1]}" else: dep_name = parts[-1] # Color based on processing status if dep_id in self.processed: status_color = Fore.GREEN status_text = "(processed)" else: status_color = Fore.RED status_text = "(not yet processed)" print(f" 
{status_color}{dep.component_type.capitalize()} '{dep_name}' {status_text}{Style.RESET_ALL}") # Add some space after the component status print() def show_dependency_stats(self): """Show statistics about the dependency graph.""" # Calculate dependency metrics (the default avoids a ValueError when there are no components) total_deps = sum(len(self.components[comp_id].depends_on) for comp_id in self.components) max_deps = max(((len(self.components[comp_id].depends_on), comp_id) for comp_id in self.components), default=(0, None)) avg_deps = total_deps / len(self.components) if self.components else 0 # Count components by type types = {"function": 0, "method": 0, "class": 0} for comp_id in self.components: comp_type = self.components[comp_id].component_type types[comp_type] += 1 print(f"\n{Fore.CYAN}Dependency Graph Statistics:{Style.RESET_ALL}") print(f"Total components: {len(self.components)}") print(f" Functions: {types['function']}") print(f" Methods: {types['method']}") print(f" Classes: {types['class']}") print(f"Average dependencies per component: {avg_deps:.2f}") print(f"Max dependencies: {max_deps[0]} (in component '{max_deps[1]}')") # Note the processing order print("\nComponents will be processed in topological order.") print() ================================================ FILE: src/visualizer/status.py ================================================ # Copyright (c) Meta Platforms, Inc.
and affiliates from typing import Dict, Set from colorama import Fore, Back, Style, init import sys import time import ast from agent.tool.ast import _get_component_name_from_code class StatusVisualizer: """Visualizes the workflow status of DocAssist agents in the terminal.""" def __init__(self): """Initialize the status visualizer.""" init() # Initialize colorama self.active_agent = None # Track only the currently active agent self._agent_art = { 'reader': [ "┌─────────┐", "│ READER │", "└─────────┘" ], 'searcher': [ "┌─────────┐", "│SEARCHER │", "└─────────┘" ], 'writer': [ "┌─────────┐", "│ WRITER │", "└─────────┘" ], 'verifier': [ "┌─────────┐", "│VERIFIER │", "└─────────┘" ] } self._status_message = "" self._current_component = "" self._current_file = "" def _clear_screen(self): """Clear the terminal screen.""" sys.stdout.write("\033[2J\033[H") sys.stdout.flush() def _get_agent_color(self, agent: str) -> str: """Get the color for an agent based on its state.""" return Fore.GREEN if agent == self.active_agent else Fore.WHITE def set_current_component(self, focal_component: str, file_path: str): """Set the current component being processed and display its information. Args: focal_component: The code component being processed file_path: Relative path to the file containing the component """ # Try to extract the component name from the code try: self._current_component = _get_component_name_from_code(focal_component) except Exception: # If parsing fails, just use a generic name self._current_component = "unknown component" self._current_file = file_path self._display_component_info() def _display_component_info(self): """Display information about the current component being processed.""" # print(f"\n{Fore.CYAN}Currently Processing:{Style.RESET_ALL}") print(f"Component: {self._current_component}") print(f"File: {self._current_file}\n") def update(self, active_agent: str, status_message: str = ""): """Update the visualization with the current active agent and status.
Args: active_agent: Name of the currently active agent status_message: Current status message to display """ self.active_agent = active_agent # Update the single active agent self._status_message = status_message self._clear_screen() # Build the visualization lines = [] # Add header # lines.append(f"{Fore.CYAN}DocAssist Workflow Status{Style.RESET_ALL}") # lines.append("") # Display current component info if available if self._current_component and self._current_file: lines.append(f"Processing: {self._current_component}") lines.append(f"File: {self._current_file}") lines.append("") # Input arrow to Reader # lines.append(" Input") # lines.append(" ↓") # First row: Reader and Searcher with loop for i in range(3): line = (f"{self._get_agent_color('reader')}{self._agent_art['reader'][i]}" f" ←→ " f"{self._get_agent_color('searcher')}{self._agent_art['searcher'][i]}" f"{Style.RESET_ALL}") lines.append(line) # Arrow from Reader to Writer # lines.append(" ↓") # Second row: Writer for i in range(3): line = (f" {self._get_agent_color('writer')}{self._agent_art['writer'][i]}{Style.RESET_ALL}") lines.append(line) # Arrow from Writer to Verifier # lines.append(" ↓") # Third row: Verifier with output for i in range(3): if i == 1: line = (f" {self._get_agent_color('verifier')}{self._agent_art['verifier'][i]}{Style.RESET_ALL} → Output") else: line = (f" {self._get_agent_color('verifier')}{self._agent_art['verifier'][i]}{Style.RESET_ALL}") lines.append(line) # # Feedback arrows from Verifier # lines.append(" ↑") # lines.append(" ↗ ↑") # Add status message if self._status_message: lines.append("") lines.append(f"{Fore.YELLOW}Status: {self._status_message}{Style.RESET_ALL}") # Print the visualization print("\n".join(lines)) sys.stdout.flush() def reset(self): """Reset the visualization state.""" self.active_agent = None self._status_message = "" self._current_component = "" self._current_file = "" self._clear_screen() ================================================ FILE: 
src/visualizer/web_bridge.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Web bridge for the docstring generation visualizers. This module provides adapters that connect the existing terminal-based visualizers to the web interface. When enabled, the visualizers will send updates to the web interface in addition to their normal terminal output. """ import threading import time import functools from typing import Dict, Any, Optional # Singleton pattern for the web socket manager class WebSocketManager: """Manages the connection to the web socket for sending visualization updates.""" _instance = None _socket = None _enabled = False def __new__(cls): if cls._instance is None: cls._instance = super(WebSocketManager, cls).__new__(cls) return cls._instance @classmethod def set_socket(cls, socket): """Set the socket.io instance for sending updates.""" cls._socket = socket cls._enabled = True @classmethod def is_enabled(cls): """Check if web visualization is enabled.""" return cls._enabled and cls._socket is not None @classmethod def emit(cls, event, data): """Emit an event to the web interface.""" if cls.is_enabled(): try: cls._socket.emit(event, data) except Exception as e: print(f"Error sending web update: {e}") @classmethod def disable(cls): """Disable web visualization.""" cls._enabled = False class WebStatusAdapter: """Adapter for the StatusVisualizer to send updates to the web interface.""" def __init__(self, original_visualizer): """ Initialize the web status adapter. 
Args: original_visualizer: The original StatusVisualizer instance """ self.original = original_visualizer self.socket_manager = WebSocketManager() # Store original methods to avoid recursion self._original_set_active_agent = original_visualizer.set_active_agent self._original_set_status_message = original_visualizer.set_status_message self._original_set_current_component = original_visualizer.set_current_component def set_active_agent(self, agent_name): """ Set the active agent and send update to web interface. Args: agent_name: Name of the active agent """ # Call the original method directly result = self._original_set_active_agent(agent_name) # Send update to web interface if self.socket_manager.is_enabled(): self.socket_manager.emit('status_update', { 'status': { 'active_agent': agent_name, 'status_message': self.original._status_message, 'current_component': self.original._current_component, 'current_file': self.original._current_file } }) return result def set_status_message(self, message): """ Set the status message and send update to web interface. Args: message: The status message """ # Call the original method directly result = self._original_set_status_message(message) # Send update to web interface if self.socket_manager.is_enabled(): self.socket_manager.emit('status_update', { 'status': { 'active_agent': self.original.active_agent, 'status_message': message, 'current_component': self.original._current_component, 'current_file': self.original._current_file } }) return result def set_current_component(self, focal_component, file_path): """ Set the current component being processed and send update to web interface. 
Args: focal_component: The component being processed file_path: The path to the file containing the component """ # Call the original method directly result = self._original_set_current_component(focal_component, file_path) # Send update to web interface if self.socket_manager.is_enabled(): self.socket_manager.emit('status_update', { 'status': { 'active_agent': self.original.active_agent, 'status_message': self.original._status_message, 'current_component': focal_component, 'current_file': file_path } }) # Special message format for the web interface to parse print(f"COMPONENT: {focal_component} in file {file_path}") return result class WebProgressAdapter: """Adapter for the ProgressVisualizer to send updates to the web interface.""" def __init__(self, original_visualizer): """ Initialize the web progress adapter. Args: original_visualizer: The original ProgressVisualizer instance """ self.original = original_visualizer self.socket_manager = WebSocketManager() # Store original methods to avoid recursion self._original_update = original_visualizer.update if hasattr(original_visualizer, 'mark_complete'): self._original_mark_complete = original_visualizer.mark_complete def update(self, component_id=None, status="processing"): """ Update the progress visualization and send update to web interface. 
Args: component_id: ID of the component being processed status: Status of the component """ # Call the original method directly result = self._original_update(component_id, status) # Send update to web interface if self.socket_manager.is_enabled(): # Get the component status from the original visualizer component_status = {} for comp_id in self.original.components: if comp_id in self.original.processed: component_status[comp_id] = "complete" elif comp_id == self.original.current: component_status[comp_id] = "in_progress" else: component_status[comp_id] = "not_started" self.socket_manager.emit('status_update', { 'progress': { 'total_components': len(self.original.sorted_order), 'processed_components': len(self.original.processed), 'current_component': self.original.current, 'component_status': component_status } }) # Special message format for the web interface to parse print(f"PROGRESS: {len(self.original.processed)}/{len(self.original.sorted_order)} components processed") return result def mark_complete(self, component_id): """ Mark a component as complete and send update to web interface. Args: component_id: ID of the component to mark as complete """ # Check if the original visualizer has mark_complete if not hasattr(self, '_original_mark_complete'): # Fall back to update; ProgressVisualizer.update only records the status "completed" return self.update(component_id, "completed") # Call the original method directly result = self._original_mark_complete(component_id) # Update web interface if self.socket_manager.is_enabled(): # Use the update method to send progress without counting the component twice self.update(component_id, "complete") return result def patch_visualizers(): """ Patch the existing visualizer classes to add web interface support. This function should be called before creating any visualizer instances to ensure they have web support. """ from .
import StatusVisualizer, ProgressVisualizer # Check if already patched to avoid double patching if hasattr(StatusVisualizer, '_web_patched'): return # Mark as patched StatusVisualizer._web_patched = True ProgressVisualizer._web_patched = True # Store the original __init__ methods original_status_init = StatusVisualizer.__init__ original_progress_init = ProgressVisualizer.__init__ # Create patched __init__ methods def patched_status_init(self, *args, **kwargs): original_status_init(self, *args, **kwargs) # Create adapter and store original methods adapter = WebStatusAdapter(self) # Replace methods with adapter methods self.set_active_agent = adapter.set_active_agent self.set_status_message = adapter.set_status_message self.set_current_component = adapter.set_current_component def patched_progress_init(self, *args, **kwargs): original_progress_init(self, *args, **kwargs) # Create adapter and store original methods adapter = WebProgressAdapter(self) # Replace methods with adapter methods self.update = adapter.update if hasattr(self, 'mark_complete'): self.mark_complete = adapter.mark_complete # Apply the patches StatusVisualizer.__init__ = patched_status_init ProgressVisualizer.__init__ = patched_progress_init ================================================ FILE: src/web/README.md ================================================ # DocAgent Web Interface A real-time web visualization system for the DocAgent docstring generation tool. ## Overview The DocAgent Web Interface provides an interactive web UI for running and tracking Python docstring generation. The application visualizes the agent-based docstring generation process in real time, allowing users to monitor progress, view code structure, track completeness metrics, and manage the configuration. ## Features - **Configuration Management**: Configure the docstring generation process (Repository Path, LLM settings, Flow Control, Docstring Options) through a web form.
Test LLM API connectivity before starting. - **Real-time Visualization**: Observe the docstring generation process as it happens. - **Agent Status Tracking**: View which agent (Reader, Searcher, Writer, Verifier) is currently active in the generation workflow via a visual graph. - **Repository Structure Visualization**: Interactive tree visualization of your Python codebase, highlighting files as they are processed (White: unprocessed, Yellow: processing, Green: completed). - **Dynamic Progress Tracking**: Real-time progress bars and component completion tracking. - **Completeness Metrics Visualization**: Visual representation of docstring completeness across your codebase, updated as the generation progresses (visible in the left sidebar). - **Log Viewer**: Consolidated view of the generation process logs. - **Process Control**: Start and stop the generation process via UI buttons. ## Architecture ### Backend The web application is built using: - **Flask**: Web framework for the backend server - **Socket.IO**: Real-time bidirectional communication between client and server - **Eventlet**: Asynchronous networking library for handling concurrent connections ### Frontend The frontend uses: - **Bootstrap 5**: CSS framework for responsive design - **D3.js**: Data visualization library for interactive repository and agent visualizations - **Socket.IO Client**: Real-time communication with the backend - **jQuery**: DOM manipulation and event handling ### Directory Structure ``` src/web/ ├── app.py - Main Flask application ├── config_handler.py - Handles configuration loading/saving ├── process_handler.py - Manages the docstring generation process ├── visualization_handler.py - Handles visualization state management ├── static/ - Static assets │ ├── css/ - CSS stylesheets │ │ └── style.css - Custom styling │ └── js/ - JavaScript files │ ├── completeness.js - Completeness visualization │ ├── config.js - Configuration handling │ ├── log-handler.js - Log display handling │ 
├── main.js - Main application logic │ ├── repo-structure.js - Repository structure visualization │ └── status-visualizer.js - Agent status visualization └── templates/ - HTML templates └── index.html - Main application page ``` ## Data Flow 1. User configures settings via the web form. 2. User clicks "Start Generation". 3. Flask backend spawns a subprocess running the `generate_docstrings.py` script (expected in the project root). 4. Process output (status updates, logs, metrics) is captured and parsed in real-time by the backend. 5. Parsed events are emitted via Socket.IO to the frontend. 6. Frontend components (Agent Status, Repo Structure, Logs, Progress, Completeness) update dynamically based on the received events. 7. User receives real-time feedback on the generation process. 8. User can stop the process using the "Stop Generation" button. ## Usage Guide ### 1. Starting the Web Interface Run the web application from the project root directory: ```bash python run_web_ui.py ``` By default, the web interface will be available at `http://127.0.0.1:5000`. You can customize the host and port: ```bash # Example: Run on port 8080, accessible externally python run_web_ui.py --host 0.0.0.0 --port 8080 ``` ### 2. Configuration The initial screen presents configuration options: - **Repository Path**: Path to the Python codebase for docstring generation. - **LLM Configuration**: Settings for the language model (Type, API Key, Model, Temperature, Max Tokens). Use the "Test API" button to verify credentials. - **Flow Control**: Advanced settings for the generation process. - **Docstring Options**: Control options like overwriting existing docstrings. ### 3. Starting the Generation Process 1. Fill in the configuration form accurately. 2. Click "Start Generation". 3. The interface will switch to the monitoring/visualization view. ### 4. 
Monitoring the Generation Process The visualization interface consists of several panels: - **Agent Status Panel**: Shows the current active agent in the workflow graph. - **Repository Structure Panel**: Displays the interactive codebase tree, highlighting the currently processed file. - **Logs and Progress Panel**: Shows real-time logs and overall progress. - **Completeness Panel (Sidebar)**: Shows statistics about docstring completeness. ### 5. Stopping the Process Click the "Stop Generation" button in the header to terminate the process early. ================================================ FILE: src/web/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Web application for docstring generation visualization. This module provides a web-based interface for configuring and visualizing the progress of docstring generation in a Python codebase. """ from .app import create_app __all__ = ['create_app'] ================================================ FILE: src/web/app.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Main Flask application for the docstring generation visualization. This module defines the Flask application, routes, and event handlers for the web-based docstring generation visualization system. """ import os import json import yaml import threading import eventlet from pathlib import Path from flask import Flask, render_template, request, jsonify, send_from_directory from flask_socketio import SocketIO # Patch standard library for async support with eventlet eventlet.monkey_patch() from . import config_handler from . import visualization_handler from . import process_handler def create_app(debug=True): """ Create and configure the Flask application. 
Args: debug: Whether to run the application in debug mode Returns: The configured Flask application instance """ app = Flask(__name__, static_folder='static', template_folder='templates') app.config['SECRET_KEY'] = 'docstring-generator-secret!' app.config['DEBUG'] = debug # Initialize SocketIO for real-time updates with async mode socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet') # Store application state app.config['APP_STATE'] = { 'is_running': False, 'config': {}, 'repo_path': '', 'process': None } # Routes @app.route('/') def index(): """Render the main application page.""" return render_template('index.html') @app.route('/api/default_config') def get_default_config(): """Get the default configuration from agent_config.yaml.""" return jsonify(config_handler.get_default_config()) @app.route('/api/test_api', methods=['POST']) def test_api(): """Test the LLM API connection with a simple query.""" data = request.json if not data or 'api_key' not in data or not data['api_key']: return jsonify({ 'status': 'error', 'message': 'API key is required' }) # Get the configuration llm_type = data.get('llm_type', 'claude') api_key = data.get('api_key', '') model = data.get('model', 'claude-3-5-haiku-latest') try: # Import the appropriate LLM client based on type if llm_type.lower() == 'claude': try: import anthropic client = anthropic.Anthropic(api_key=api_key) # Send a simple test message response = client.messages.create( model=model, max_tokens=100, messages=[ {"role": "user", "content": "Who are you? 
Please keep your answer very brief."} ] ) # Extract the response text if response and hasattr(response, 'content') and len(response.content) > 0: model_response = response.content[0].text else: model_response = "No response content" return jsonify({ 'status': 'success', 'message': 'Successfully connected to Claude API', 'model_response': model_response }) except Exception as e: return jsonify({ 'status': 'error', 'message': f'Error connecting to Claude API: {str(e)}' }) elif llm_type.lower() == 'openai': try: import openai client = openai.OpenAI(api_key=api_key) # Send a simple test message response = client.chat.completions.create( model=model, max_tokens=100, messages=[ {"role": "user", "content": "Who are you? Please keep your answer very brief."} ] ) # Extract the response text if response and hasattr(response, 'choices') and len(response.choices) > 0: model_response = response.choices[0].message.content else: model_response = "No response content" return jsonify({ 'status': 'success', 'message': 'Successfully connected to OpenAI API', 'model_response': model_response }) except Exception as e: return jsonify({ 'status': 'error', 'message': f'Error connecting to OpenAI API: {str(e)}' }) else: return jsonify({ 'status': 'error', 'message': f'Unsupported LLM type: {llm_type}' }) except ImportError as e: return jsonify({ 'status': 'error', 'message': f'Missing required dependency: {str(e)}' }) @app.route('/api/start', methods=['POST']) def start_generation(): """Start the docstring generation process.""" if app.config['APP_STATE']['is_running']: return jsonify({'status': 'error', 'message': 'Generation already in progress'}) data = request.json # Validate repo path repo_path = data['repo_path'] if not os.path.exists(repo_path): return jsonify({'status': 'error', 'message': f'Repository path not found: {repo_path}'}) # Save configuration try: config_path = config_handler.save_config(data['config']) except ValueError as e: return jsonify({'status': 'error', 
'message': str(e)}) # Store in application state app.config['APP_STATE']['config'] = data['config'] app.config['APP_STATE']['repo_path'] = repo_path app.config['APP_STATE']['is_running'] = True # Start the generation process thread = socketio.start_background_task( process_handler.start_generation_process, socketio, repo_path, config_path ) app.config['APP_STATE']['process'] = thread return jsonify({'status': 'success', 'message': 'Generation started'}) @app.route('/api/stop', methods=['POST']) def stop_generation(): """Stop the docstring generation process.""" if not app.config['APP_STATE']['is_running']: return jsonify({'status': 'error', 'message': 'No generation in progress'}) process_handler.stop_generation_process() app.config['APP_STATE']['is_running'] = False return jsonify({'status': 'success', 'message': 'Generation stopped'}) @app.route('/api/status') def get_status(): """Get the current status of the generation process.""" return jsonify({ 'is_running': app.config['APP_STATE']['is_running'], 'repo_path': app.config['APP_STATE']['repo_path'] }) @app.route('/api/completeness') def get_completeness(): """Get the current completeness evaluation of the repository.""" if not app.config['APP_STATE']['repo_path']: return jsonify({'status': 'error', 'message': 'No repository selected'}) results = visualization_handler.get_completeness_data(app.config['APP_STATE']['repo_path']) return jsonify(results) # Socket.IO event handlers @socketio.on('connect') def handle_connect(): """Handle client connection to Socket.IO.""" if app.config['APP_STATE']['is_running']: # Send current state to newly connected client socketio.emit('status_update', visualization_handler.get_current_status()) # Additional routes and event handlers can be added here return app, socketio ================================================ FILE: src/web/config_handler.py ================================================ # Copyright (c) Meta Platforms, Inc. 
and affiliates """ Configuration handler for the docstring generation web interface. This module handles reading, writing, and validating the configuration for the docstring generation process. """ import os import yaml import json import tempfile from pathlib import Path def get_default_config(): """ Get the default configuration from agent_config.yaml. Returns: Dictionary containing the default configuration """ default_config_path = Path('config/agent_config.yaml') if not default_config_path.exists(): return { 'llm': { 'type': 'claude', 'api_key': '', 'model': 'claude-3-5-haiku-latest', 'temperature': 0.1, 'max_tokens': 4096 }, 'flow_control': { 'max_reader_search_attempts': 2, 'max_verifier_rejections': 1, 'status_sleep_time': 1 }, 'docstring_options': { 'overwrite_docstrings': False } } with open(default_config_path, 'r') as f: config = yaml.safe_load(f) return config def validate_config(config): """ Validate that the configuration has the required fields. Args: config: Dictionary containing the configuration to validate Returns: Tuple of (is_valid, error_message) """ required_keys = ['llm', 'flow_control', 'docstring_options'] for key in required_keys: if key not in config: return False, f"Missing required configuration section: {key}" # Check specific required fields in llm section llm_required = ['type', 'api_key', 'model'] for key in llm_required: if key not in config['llm']: return False, f"Missing required field in llm section: {key}" return True, "" def save_config(config): """ Save the configuration to a temporary file for use by the generation process. 
Args: config: Dictionary containing the configuration to save Returns: Path to the saved configuration file """ # Validate configuration is_valid, error_message = validate_config(config) if not is_valid: raise ValueError(f"Invalid configuration: {error_message}") # Create a temporary file temp_dir = tempfile.gettempdir() config_file = os.path.join(temp_dir, 'docstring_generator_config.yaml') with open(config_file, 'w') as f: yaml.dump(config, f, default_flow_style=False) return config_file ================================================ FILE: src/web/process_handler.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Process handler for running the docstring generation. This module handles starting, monitoring, and stopping the docstring generation process, as well as capturing its output and sending it to the web interface. """ import os import sys import subprocess import threading import tempfile import signal import re from pathlib import Path from typing import Optional, Dict, Any from . import visualization_handler # Global variables to track the process process = None should_stop = False # Custom output handler to intercept and parse the output class OutputHandler(threading.Thread): """Thread to handle output from the docstring generation process.""" def __init__(self, process, socketio): """ Initialize the output handler. 
Args: process: The subprocess.Popen object for the docstring generation process socketio: The Flask-SocketIO instance for sending updates to clients """ threading.Thread.__init__(self) self.process = process self.socketio = socketio self.daemon = True def run(self): """Read output from the process and update the visualization state.""" global should_stop # Regular expressions for parsing different types of output status_regex = re.compile(r'STATUS: Agent: (\w+), Message: (.+)') component_regex = re.compile(r'COMPONENT: (.+) in file (.+)') progress_regex = re.compile(r'PROGRESS: (\d+)/(\d+) components processed') log_regex = re.compile(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - (\w+) - (\w+) - (.+)') # Additional regex to detect agent activity from regular logs agent_activity_regex = re.compile(r'(reader|writer|searcher|verifier)', re.IGNORECASE) docstring_update_regex = re.compile(r'Successfully updated docstring for (.+)|Completed docstring generation for (.+)', re.IGNORECASE) # Patterns to filter out visualization-related output from logs visualization_patterns = [ r'┌─+┐', # Box top r'│.*│', # Box content r'└─+┘', # Box bottom r'Agent:', # Agent status r'Status:', # Status message r'Component:', # Component info r'╔═+╗', # Double-line box top r'║.*║', # Double-line box content r'╚═+╝', # Double-line box bottom r'▶ ', # Progress indicators r'→ ', # Arrow indicators r'⦿', # Bullet indicators r'Processing component \d+/\d+', # Progress messages r'╡.*╞', # Table separators r'═+', # Table lines r'DocAgent (?:Workflow )?Status', # Workflow status header r'Processing: ', # Processing status line r'File: ', # File status line r'Active Agent: ', # Agent status line r'Status: ', # Status message line r'Workflow Input:', # Input section r'Component Name:', # Input component name r'File Path:', # Input file path r'Dependencies:', # Input dependencies r'Code:', # Input code r'^Input:', # Input header r'\[.*?\]', # Status messages in brackets ] visualization_filter = 
re.compile('|'.join(visualization_patterns)) # Read each line from the process output for line in iter(self.process.stdout.readline, b''): if should_stop: break # Decode the line try: line = line.decode('utf-8').rstrip() except UnicodeDecodeError: continue # Process workflow status lines separately to update agent status if 'Processing:' in line or 'File:' in line: if 'Processing:' in line: component = line.split('Processing:')[1].strip() if component: visualization_handler.update_component_focus(component, "") if 'File:' in line: file_path = line.split('File:')[1].strip() if file_path: # Update the current file without changing the component current_status = visualization_handler.get_current_status() if 'status' in current_status and current_status['status'].get('current_component'): visualization_handler.update_component_focus( current_status['status']['current_component'], file_path ) self.socketio.emit('status_update', visualization_handler.get_current_status()) # Add to log messages - filter out visualization if not visualization_filter.search(line): visualization_handler.add_log_message(line) self.socketio.emit('log_line', line) # Check for status updates status_match = status_regex.search(line) if status_match: agent, message = status_match.groups() visualization_handler.update_agent_status(agent, message) self.socketio.emit('status_update', visualization_handler.get_current_status()) continue # Check for agent activity in regular logs if not status_match: # Only check if we didn't already match a status agent_match = agent_activity_regex.search(line) if agent_match and ('active' in line.lower() or 'using' in line.lower() or 'processing' in line.lower()): # Extract agent name from logs agent = agent_match.group(1).capitalize() visualization_handler.update_agent_status(agent, "Processing") self.socketio.emit('status_update', visualization_handler.get_current_status()) # Check for component updates component_match = component_regex.search(line) if 
component_match: component, file_path = component_match.groups() visualization_handler.update_component_focus(component, file_path) visualization_handler.update_file_status(file_path, 'in_progress') self.socketio.emit('status_update', visualization_handler.get_current_status()) continue # Check for progress updates progress_match = progress_regex.search(line) if progress_match: processed, total = progress_match.groups() # We don't have the current component or component status from this regex, # so we'll just update the counts visualization_handler.update_progress(int(total), int(processed), '', {}) self.socketio.emit('status_update', visualization_handler.get_current_status()) continue # Also check for progress updates in normal log lines progress_in_log = re.search(r'Processing component (\d+)/(\d+)', line) if progress_in_log: current, total = progress_in_log.groups() visualization_handler.update_progress(int(total), int(current), '', {}) self.socketio.emit('status_update', visualization_handler.get_current_status()) # Check for docstring updates docstring_update_match = docstring_update_regex.search(line) if docstring_update_match: component = docstring_update_match.group(1) or docstring_update_match.group(2) # If this is a file path, extract it if component and '/' in component: file_path = component visualization_handler.update_file_status(file_path, 'complete') self.socketio.emit('status_update', visualization_handler.get_current_status()) # Emit a special event for docstring updates self.socketio.emit('docstring_updated', {'component': component}) # Try to extract component information from other log lines if 'Processing' in line and ':' in line and 'file' in line: parts = line.split('file') if len(parts) > 1: file_path = parts[1].strip() component = parts[0].split('Processing')[-1].strip() if component and file_path: visualization_handler.update_component_focus(component, file_path) visualization_handler.update_file_status(file_path, 'in_progress') 
self.socketio.emit('status_update', visualization_handler.get_current_status()) # Check for log messages log_match = log_regex.search(line) if log_match: _, level, message = log_match.groups() # If the message indicates completion of a file, update the file status if 'Completed docstring generation for' in message or 'Successfully updated docstring for' in message: # Try to extract the file path from the message file_match = re.search(r'for file (.+)$|for (.+)', message) if file_match: file_path = file_match.group(1) or file_match.group(2) if file_path and '.' in file_path: # Simple check to ensure it looks like a filename visualization_handler.update_file_status(file_path, 'complete') self.socketio.emit('status_update', visualization_handler.get_current_status()) # Emit a special event for docstring updates self.socketio.emit('docstring_updated', {'component': file_path}) self.socketio.emit('log_message', {'level': level, 'message': message}) def start_generation_process(socketio, repo_path: str, config_path: str): """ Start the docstring generation process. Args: socketio: The Flask-SocketIO instance for sending updates to clients repo_path: Path to the repository to generate docstrings for config_path: Path to the configuration file """ global process, should_stop should_stop = False # Set an initial status to show we're starting visualization_handler.update_agent_status("System", "Starting docstring generation...") socketio.emit('status_update', visualization_handler.get_current_status()) # Connect the socket to the web bridge try: from src.visualizer.web_bridge import WebSocketManager WebSocketManager.set_socket(socketio) except ImportError: socketio.emit('log_message', { 'level': 'warning', 'message': 'Web bridge not available. Some features may not work correctly.' 
}) # Get the repository structure and update the visualization state try: structure = visualization_handler.get_repo_structure(repo_path) socketio.emit('status_update', visualization_handler.get_current_status()) socketio.emit('log_message', { 'level': 'info', 'message': f'Repository structure loaded with {len(structure["children"])} top-level items' }) except Exception as e: socketio.emit('log_message', { 'level': 'error', 'message': f'Error loading repository structure: {str(e)}' }) # Find the generate_docstrings.py script script_path = Path(__file__).parent.parent.parent / 'generate_docstrings.py' if not script_path.exists(): socketio.emit('error', { 'message': f'Could not find docstring generation script at {script_path}' }) return # Start the process try: # Launch the generator as a subprocess; stderr is merged into the stdout pipe, which OutputHandler reads as bytes. Line buffering (bufsize=1) is only supported in text mode, so the default buffer size is used here. process = subprocess.Popen( [sys.executable, str(script_path), '--repo-path', repo_path, '--config-path', config_path, '--enable-web'], # Enable web integration stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=False ) # Start the output handler handler = OutputHandler(process, socketio) handler.start() # Wait for the process to complete return_code = process.wait() if return_code == 0: socketio.emit('complete', { 'message': 'Docstring generation completed successfully' }) else: socketio.emit('error', { 'message': f'Docstring generation failed with return code {return_code}' }) except Exception as e: socketio.emit('error', { 'message': f'Error starting docstring generation process: {str(e)}' }) finally: process = None def stop_generation_process(): """ Stop the docstring generation process. 
Returns: True if the process was stopped, False otherwise """ global process, should_stop if process is None: return False should_stop = True try: # Disconnect from the web bridge try: from src.visualizer.web_bridge import WebSocketManager WebSocketManager.disable() except ImportError: pass # Try to terminate the process gracefully first process.terminate() # Wait for up to 5 seconds for the process to terminate try: process.wait(timeout=5) except subprocess.TimeoutExpired: # If the process didn't terminate, kill it process.kill() return True except Exception as e: print(f"Error stopping process: {e}") return False ================================================ FILE: src/web/run.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Entry point for running the docstring generation visualization web application. This script creates and starts the Flask application for visualizing the docstring generation process. """ import os import sys import argparse from pathlib import Path from .app import create_app def main(): """ Parse command line arguments and start the web application. 
""" parser = argparse.ArgumentParser(description='Start the docstring generation visualization web application') parser.add_argument('--host', default='127.0.0.1', help='Host to bind the server to') parser.add_argument('--port', type=int, default=5000, help='Port to bind the server to') parser.add_argument('--debug', action='store_true', help='Run the application in debug mode') args = parser.parse_args() # Create the Flask application app, socketio = create_app(debug=args.debug) print(f"Starting docstring generation visualization web application on http://{args.host}:{args.port}") print("Press Ctrl+C to stop the server") # Start the server socketio.run(app, host=args.host, port=args.port, debug=args.debug, allow_unsafe_werkzeug=True) if __name__ == '__main__': # Add the parent directory to the path so we can import the module sys.path.insert(0, str(Path(__file__).parent.parent.parent)) main() ================================================ FILE: src/web/static/css/style.css ================================================ /* Copyright (c) Meta Platforms, Inc. 
and affiliates */ /* Main layout styles */ body { overflow-x: hidden; } .sidebar { transition: width 0.3s ease; box-shadow: 2px 0 5px rgba(0, 0, 0, 0.1); } /* Header logo styles */ .header-logo { max-height: 30px; margin-right: 10px; } /* Transition for main content when sidebar changes */ .main-content-transition { transition: all 0.3s ease; } /* Status visualizer styles */ .agent-box { border: 1px solid #ccc; border-radius: 5px; padding: 10px; margin-bottom: 10px; text-align: center; transition: all 0.3s ease; } .agent-box.active { border-color: #198754; box-shadow: 0 0 5px rgba(25, 135, 84, 0.5); background-color: rgba(25, 135, 84, 0.1); } .agent-box h3 { margin-top: 5px; font-size: 1.2rem; } .component-info { margin-top: 20px; padding: 10px; background-color: #f8f9fa; border-radius: 5px; border-left: 3px solid #007bff; } /* Agent workflow visualization styles */ #agent-workflow { min-height: 200px; } .workflow-node circle { fill: #ffffff; /* White background by default */ stroke: #6c757d; stroke-width: 1.5px; transition: all 0.3s ease; } .workflow-node.active circle { fill: #198754; /* Green background when active */ stroke: #0d6efd; stroke-width: 2px; } .workflow-link { stroke: #adb5bd; stroke-width: 2px; fill: none; marker-end: url(#arrowhead); } .workflow-label { font-size: 12px; text-anchor: middle; dominant-baseline: middle; fill: #212529; pointer-events: none; transition: all 0.3s ease; } .workflow-node.active .workflow-label { fill: #fff; font-weight: bold; } .workflow-text-label { font-size: 14px; text-anchor: middle; dominant-baseline: middle; fill: #666; font-weight: bold; } /* Repository structure styles */ .repo-node { cursor: pointer; transition: all 0.2s ease; } .repo-node:hover { filter: brightness(0.9); } .repo-node-label { font-size: 0.9rem; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; } .repo-node-complete { fill: #198754; /* Green */ } .repo-node-in-progress { fill: #ffc107; /* Yellow */ } .repo-node-not-started { fill: 
#f8f9fa; /* Light grey */ } .repo-node-focus { stroke: #dc3545; /* Red */ stroke-width: 2; } /* Log container styles */ #log-container { font-family: monospace; font-size: 0.85rem; line-height: 1.5; background-color: #212529; color: #f8f9fa; border-radius: 5px; height: 250px; max-height: 250px; } .log-line { margin-bottom: 2px; white-space: pre-wrap; word-break: break-word; } .log-info { color: #f8f9fa; } .log-warning { color: #ffc107; } .log-error { color: #dc3545; } .log-debug { color: #6c757d; } /* Completeness table styles */ .completeness-table { font-size: 0.9rem; } .progress-cell { width: 100px; } .progress-bar-mini { height: 10px; margin-top: 5px; border-radius: 5px; } /* Animation for focus transitions */ @keyframes pulse { 0% { transform: scale(1); opacity: 1; } 50% { transform: scale(1.05); opacity: 0.8; } 100% { transform: scale(1); opacity: 1; } } .highlight-focus { animation: pulse 1s; } ================================================ FILE: src/web/static/js/completeness.js ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates /** * Completeness visualization for the docstring generation web application. * * This file provides functions for rendering and updating the completeness * visualization in the web interface. */ /** * Update the completeness view with the evaluation results. * * @param {Object} completenessData - The completeness evaluation data from the server */ function updateCompletenessView(completenessData) { if (!completenessData || !completenessData.files) { $('#completeness-data').html(`
No completeness data available
`); return; } // Calculate overall statistics const totalFiles = completenessData.files.length; let totalClasses = 0; let totalClassesWithDocs = 0; let totalFunctions = 0; let totalFunctionsWithDocs = 0; completenessData.files.forEach(file => { if (file.classes) { totalClasses += file.classes.length; totalClassesWithDocs += file.classes.filter(c => c.has_docstring).length; } if (file.functions) { totalFunctions += file.functions.length; totalFunctionsWithDocs += file.functions.filter(f => f.has_docstring).length; } }); const classCompleteness = totalClasses > 0 ? Math.round((totalClassesWithDocs / totalClasses) * 100) : 0; const functionCompleteness = totalFunctions > 0 ? Math.round((totalFunctionsWithDocs / totalFunctions) * 100) : 0; const totalComponents = totalClasses + totalFunctions; const totalComponentsWithDocs = totalClassesWithDocs + totalFunctionsWithDocs; const overallCompleteness = totalComponents > 0 ? Math.round((totalComponentsWithDocs / totalComponents) * 100) : 0; // Create the HTML for the completeness view let html = `
Overall Completeness: ${overallCompleteness}%
${overallCompleteness}%
Classes: ${totalClassesWithDocs}/${totalClasses} (${classCompleteness}%)
Functions: ${totalFunctionsWithDocs}/${totalFunctions} (${functionCompleteness}%)
Files (${totalFiles})
`; // Sort files by completeness (lowest first) const sortedFiles = [...completenessData.files].sort((a, b) => { const aTotal = (a.classes?.length || 0) + (a.functions?.length || 0); const aWithDocs = (a.classes?.filter(c => c.has_docstring).length || 0) + (a.functions?.filter(f => f.has_docstring).length || 0); const aPercentage = aTotal > 0 ? (aWithDocs / aTotal) : 1; const bTotal = (b.classes?.length || 0) + (b.functions?.length || 0); const bWithDocs = (b.classes?.filter(c => c.has_docstring).length || 0) + (b.functions?.filter(f => f.has_docstring).length || 0); const bPercentage = bTotal > 0 ? (bWithDocs / bTotal) : 1; return aPercentage - bPercentage; }); // Add rows for each file sortedFiles.forEach(file => { const classes = file.classes || []; const functions = file.functions || []; const classesWithDocs = classes.filter(c => c.has_docstring).length; const functionsWithDocs = functions.filter(f => f.has_docstring).length; const totalInFile = classes.length + functions.length; const totalWithDocsInFile = classesWithDocs + functionsWithDocs; const fileCompleteness = totalInFile > 0 ? Math.round((totalWithDocsInFile / totalInFile) * 100) : 100; // Determine the row color based on completeness let rowClass = ''; if (fileCompleteness === 100) { rowClass = 'table-success'; } else if (fileCompleteness >= 50) { rowClass = 'table-warning'; } else { rowClass = 'table-danger'; } html += ` `; }); html += `
File Classes Functions Completeness
${file.file.split('/').pop()} ${classesWithDocs}/${classes.length} ${functionsWithDocs}/${functions.length}
${fileCompleteness}%
`; // Update the completeness data container $('#completeness-data').html(html); } ================================================ FILE: src/web/static/js/config.js ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates /** * Configuration handling for the docstring generation web application. * * This file provides functions for loading and saving configuration for the * docstring generation process. */ /** * Load the default configuration from the server. */ function loadDefaultConfig() { $.ajax({ url: '/api/default_config', type: 'GET', success: function(config) { applyConfigToForm(config); }, error: function(xhr, status, error) { console.error('Error loading default configuration:', error); showMessage('warning', 'Failed to load default configuration. Using fallback values.'); } }); } /** * Apply a configuration object to the form inputs. * * Numeric fields fall back to their defaults with `??` rather than `||`, * so that an explicit 0 (for example, a temperature of 0) is preserved. * * @param {Object} config - The configuration object to apply */ function applyConfigToForm(config) { // Set LLM configuration if (config.llm) { $('#llm-type').val(config.llm.type || 'claude'); $('#llm-api-key').val(config.llm.api_key || ''); $('#llm-model').val(config.llm.model || 'claude-3-5-haiku-latest'); $('#llm-temperature').val(config.llm.temperature ?? 0.1); $('#llm-max-tokens').val(config.llm.max_tokens || 4096); } // Set flow control configuration if (config.flow_control) { $('#max-reader-search-attempts').val(config.flow_control.max_reader_search_attempts ?? 2); $('#max-verifier-rejections').val(config.flow_control.max_verifier_rejections ?? 1); $('#status-sleep-time').val(config.flow_control.status_sleep_time ?? 1); } // Set docstring options if (config.docstring_options) { $('#overwrite-docstrings').prop('checked', config.docstring_options.overwrite_docstrings || false); } } /** * Build a configuration object from the form inputs. 
 * * @returns {Object} The configuration object */ function buildConfigFromForm() { return { llm: { type: $('#llm-type').val(), api_key: $('#llm-api-key').val(), model: $('#llm-model').val(), temperature: parseFloat($('#llm-temperature').val()), max_tokens: parseInt($('#llm-max-tokens').val(), 10) }, flow_control: { max_reader_search_attempts: parseInt($('#max-reader-search-attempts').val(), 10), max_verifier_rejections: parseInt($('#max-verifier-rejections').val(), 10), status_sleep_time: parseFloat($('#status-sleep-time').val()) }, docstring_options: { overwrite_docstrings: $('#overwrite-docstrings').is(':checked') } }; } ================================================ FILE: src/web/static/js/log-handler.js ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates /** * Log message handler for the docstring generation web application. * * This file provides functions for displaying and managing log messages * in the web interface. */ // Maximum number of log lines to keep in the UI const MAX_LOG_LINES = 5000; /** * Add a log message to the log container. * * @param {string} level - The log level (info, warning, error, debug) * @param {string} message - The log message to display */ function addLogMessage(level, message) { // Create a CSS class based on the log level let logClass = 'log-info'; switch (level.toLowerCase()) { case 'warning': case 'warn': logClass = 'log-warning'; break; case 'error': case 'critical': logClass = 'log-error'; break; case 'debug': logClass = 'log-debug'; break; } // Create the log line element const logLine = $(`
`); logLine.text(message); // Add the log line to the log content $('#log-content').append(logLine); // Trim log lines if necessary const logLines = $('#log-content .log-line'); if (logLines.length > MAX_LOG_LINES) { // Remove the oldest lines logLines.slice(0, logLines.length - MAX_LOG_LINES).remove(); } // Scroll to the bottom of the log container const logContainer = $('#log-container'); logContainer.scrollTop(logContainer[0].scrollHeight); } ================================================ FILE: src/web/static/js/main.js ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates /** * Main JavaScript for the docstring generation web application. * * This file provides the main functionality for the web interface, including * event handling, configuration, and communication with the server. */ // Global state variables let socket = null; let processRunning = false; let startTime = 0; let timerInterval = null; let apiTestModal = null; // Document ready handler $(document).ready(function() { // Load default configuration loadDefaultConfig(); // Set up form submission handler $('#config-form').on('submit', function(e) { e.preventDefault(); startGeneration(); }); // Set up test API button handler $('#test-api-button').on('click', function() { testApiConnection(); }); // Initialize the API test modal apiTestModal = new bootstrap.Modal(document.getElementById('api-test-modal')); // Check if a process is already running checkProcessStatus(); // Initialize the agent workflow visualization initAgentWorkflow(); // Handle window resize $(window).on('resize', function() { initAgentWorkflow(); }); }); /** * Test the API connection with the configured settings. */ function testApiConnection() { // Show the modal apiTestModal.show(); // Set the modal content to loading state $('#api-test-result').html(`
Testing API...

Testing API connection...

`); // Get the API configuration const config = { llm_type: $('#llm-type').val(), api_key: $('#llm-api-key').val(), model: $('#llm-model').val() }; // Send a test request to the server $.ajax({ url: '/api/test_api', type: 'POST', contentType: 'application/json', data: JSON.stringify(config), success: function(response) { if (response.status === 'success') { $('#api-test-result').html(`
API Connection Successful

${response.message || 'The API connection is working correctly.'}


Response from model:

${response.model_response || 'No response provided.'}

`); } else { $('#api-test-result').html(`
API Connection Failed

${response.message || 'Failed to connect to the API.'}


Please check your API key and other settings.

`); } }, error: function(xhr, status, error) { $('#api-test-result').html(`
API Connection Failed

Error: ${error}


Please check your API key and other settings.

`); } }); } /** * Check if a process is already running. */ function checkProcessStatus() { $.ajax({ url: '/api/status', type: 'GET', success: function(data) { processRunning = data.is_running; if (processRunning) { // Process is running, switch to the running view showRunningView(); // Connect to Socket.IO setupSocketHandlers(); // Start the timer startTimer(); // Load completeness data initially loadCompletenessData(); } else { // Show the configuration view showConfigView(); } }, error: function(xhr, status, error) { console.error('Error checking process status:', error); showMessage('error', 'Error checking process status: ' + error); } }); } /** * Set up Socket.IO event handlers. */ function setupSocketHandlers() { // Create Socket.IO connection if it doesn't exist if (!socket) { socket = io(); // Status update handler socket.on('status_update', function(data) { console.log('Status update received:', data); if (data.status) { updateStatusVisualizer(data.status); } if (data.repo_structure) { updateRepoStructure(data.repo_structure); } }); // Log message handler socket.on('log_message', function(data) { addLogMessage(data.level, data.message); // If this is a docstring generation success message, refresh completeness if (data.message && ( data.message.includes('Successfully updated docstring for') || data.message.includes('Completed docstring generation for') )) { // Wait a brief moment for file changes to be detected setTimeout(loadCompletenessData, 500); } }); // Raw log message handler (for system prints) socket.on('log_line', function(data) { addLogMessage('info', data); // Check if this is a message about docstring generation if (typeof data === 'string' && ( data.includes('Successfully updated docstring') || data.includes('Completed docstring generation') )) { // Refresh the completeness data setTimeout(loadCompletenessData, 500); } }); // Error handler socket.on('error', function(data) { addLogMessage('error', data.message); showMessage('error', 
data.message); }); // Completion handler socket.on('complete', function(data) { processRunning = false; $('#start-button').prop('disabled', false).text('Start Generation'); addLogMessage('info', data.message); showMessage('success', 'Docstring generation completed'); stopTimer(); // Final completeness refresh loadCompletenessData(); }); // Disconnection handler socket.on('disconnect', function() { addLogMessage('warning', 'Connection to server lost'); }); } } /** * Start the docstring generation process. */ function startGeneration() { if (processRunning) { showMessage('warning', 'Generation already in progress'); return; } // Get the repository path const repoPath = $('#repo-path').val(); if (!repoPath) { showMessage('error', 'Please enter a repository path'); return; } // Disable the start button $('#start-button').prop('disabled', true).text('Starting...'); // Get the configuration const config = buildConfigFromForm(); // Clear log content $('#log-content').empty(); // Send the request to start generation $.ajax({ url: '/api/start', type: 'POST', contentType: 'application/json', data: JSON.stringify({ repo_path: repoPath, config: config }), success: function(data) { if (data.status === 'success') { // Mark as running processRunning = true; // Show the running view showRunningView(); // Connect to Socket.IO setupSocketHandlers(); // Start the timer startTimer(); // Make the completeness section visible and load initial data $('#completeness-section').removeClass('d-none'); loadCompletenessData(); // Show success message showMessage('success', data.message); } else { // Show error message showMessage('error', data.message); $('#start-button').prop('disabled', false).text('Start Generation'); } }, error: function(xhr, status, error) { showMessage('error', 'Error starting generation: ' + error); $('#start-button').prop('disabled', false).text('Start Generation'); } }); } /** * Stop the docstring generation process. 
*/ function stopGeneration() { if (!processRunning) { showMessage('warning', 'No generation in progress'); return; } // Confirm stop if (!confirm('Are you sure you want to stop the docstring generation process?')) { return; } // Send the request to stop generation $.ajax({ url: '/api/stop', type: 'POST', success: function(data) { if (data.status === 'success') { processRunning = false; $('#start-button').prop('disabled', false).text('Start Generation'); showMessage('success', data.message); stopTimer(); // Add log message addLogMessage('warning', 'Generation process stopped by user'); } else { showMessage('error', data.message); } }, error: function(xhr, status, error) { showMessage('error', 'Error stopping generation: ' + error); } }); } /** * Show the configuration view. */ function showConfigView() { $('#main-content').addClass('d-none'); $('#sidebar').removeClass('col-md-3').addClass('col-md-12'); $('#config-section').removeClass('d-none'); $('#completeness-section').addClass('d-none'); } /** * Show the running view. */ function showRunningView() { $('#config-section').addClass('d-none'); $('#completeness-section').removeClass('d-none'); $('#sidebar').removeClass('col-md-12').addClass('col-md-3'); $('#main-content').removeClass('d-none'); // Make sure the agent workflow is initialized setTimeout(function() { initAgentWorkflow(); }, 100); // Add a stop button to the header if ($('#stop-button').length === 0) { $('header').append(` `); // Add click handler $('#stop-button').on('click', function() { stopGeneration(); }); } } /** * Show a message to the user. * * @param {string} type - The type of message (success, error, warning, info) * @param {string} message - The message to show */ function showMessage(type, message) { // Create alert if it doesn't exist if ($('#alert-container').length === 0) { $('body').append(`
`); } // Create a unique ID for the alert const id = 'alert-' + Date.now(); // Add the alert to the container $('#alert-container').append(` `); // Automatically remove the alert after 5 seconds setTimeout(() => { $(`#${id}`).alert('close'); }, 5000); } /** * Start the timer. */ function startTimer() { // Set the start time startTime = Date.now(); // Clear any existing timer if (timerInterval) { clearInterval(timerInterval); } // Update every second timerInterval = setInterval(() => { const elapsedSeconds = Math.floor((Date.now() - startTime) / 1000); const minutes = Math.floor(elapsedSeconds / 60); const seconds = elapsedSeconds % 60; // Format as MM:SS const formattedTime = `${minutes.toString().padStart(2, '0')}:${seconds.toString().padStart(2, '0')}`; // Update the display $('#progress-time').text(`Elapsed: ${formattedTime}`); }, 1000); } /** * Stop the timer. */ function stopTimer() { if (timerInterval) { clearInterval(timerInterval); timerInterval = null; } } /** * Load completeness data from the server. */ function loadCompletenessData() { // Only load data if the completeness section is visible if ($('#completeness-section').hasClass('d-none')) { return; } $.ajax({ url: '/api/completeness', type: 'GET', success: function(response) { if (response.status === 'success' && response.data) { updateCompletenessView(response.data); } else { $('#completeness-data').html(`
${response.message || 'Failed to load completeness data'}
`); } }, error: function(xhr, status, error) { console.error('Error loading completeness data:', error); $('#completeness-data').html(`
Error loading completeness data: ${error}
`); } }); } ================================================ FILE: src/web/static/js/repo-structure.js ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates /** * Repository structure visualization for the docstring generation web application. * * This file provides functions for rendering and updating the repository structure * visualization using D3.js. */ // Store the current repository structure let currentRepoStructure = null; // Keep track of the current focus path let currentFocusPath = null; // D3 visualization settings const margin = { top: 20, right: 20, bottom: 20, left: 20 }; let width = 600; let height = 500; let nodeRadius = 7; let maxLabelLength = 20; /** * Update the repository structure visualization. * * @param {Object} repoStructure - The repository structure object from the server */ function updateRepoStructure(repoStructure) { // If there's no repo structure, show placeholder if (!repoStructure || !repoStructure.tree || Object.keys(repoStructure.tree).length === 0) { $('#repo-structure').html(`

No repository structure available

`); return; } // Store the previous focus path const prevFocusPath = currentFocusPath; // Update the current state currentRepoStructure = repoStructure; currentFocusPath = repoStructure.focus_path; // Update dimensions based on container size const container = document.getElementById('repo-structure'); width = container.clientWidth - margin.left - margin.right; height = container.clientHeight - margin.top - margin.bottom; // Clear existing visualization $('#repo-structure').empty(); // Create SVG container const svg = d3.select('#repo-structure') .append('svg') .attr('width', width + margin.left + margin.right) .attr('height', height + margin.top + margin.bottom) .append('g') .attr('transform', `translate(${margin.left},${margin.top})`); // Create hierarchy from the data const root = d3.hierarchy(repoStructure.tree); // Set node size based on number of nodes to avoid overlapping const nodeCount = root.descendants().length; const dynamicRadius = Math.max(3, Math.min(7, 10 - Math.log(nodeCount))); nodeRadius = dynamicRadius; // Create tree layout const treeLayout = d3.tree() .size([height, width - 160]); // Compute the tree layout treeLayout(root); // Add links between nodes svg.selectAll('.link') .data(root.links()) .enter() .append('path') .attr('class', 'link') .attr('d', d => { return `M${d.source.y},${d.source.x} C${(d.source.y + d.target.y) / 2},${d.source.x} ${(d.source.y + d.target.y) / 2},${d.target.x} ${d.target.y},${d.target.x}`; }) .attr('fill', 'none') .attr('stroke', '#ccc') .attr('stroke-width', 1.5); // Add nodes const nodes = svg.selectAll('.node') .data(root.descendants()) .enter() .append('g') .attr('class', 'node') .attr('transform', d => `translate(${d.y},${d.x})`) .attr('id', d => `node-${d.data.path.replace(/[\/\.]/g, '_')}`); // Add ID for easier selection // Add node circles nodes.append('circle') .attr('r', nodeRadius) .attr('class', d => { let classes = 'repo-node '; // Add status class if (d.data.type === 'file') { classes += 
`repo-node-${d.data.status || 'not-started'}`; } else { // For directories, determine status based on children const hasCompleteChildren = d.descendants().slice(1).some(node => node.data.type === 'file' && node.data.status === 'complete'); const hasInProgressChildren = d.descendants().slice(1).some(node => node.data.type === 'file' && node.data.status === 'in_progress'); if (hasCompleteChildren && !hasInProgressChildren) { classes += 'repo-node-complete'; } else if (hasInProgressChildren) { classes += 'repo-node-in-progress'; } else { classes += 'repo-node-not-started'; } } // Add focus class if this is the focused node if (d.data.path === currentFocusPath) { classes += ' repo-node-focus'; } return classes; }) .style('fill', d => { if (d.data.type === 'dir') { // Check children status for directory coloring const completeCount = d.descendants().slice(1).filter(node => node.data.type === 'file' && node.data.status === 'complete').length; const totalFiles = d.descendants().slice(1).filter(node => node.data.type === 'file').length; const progress = totalFiles > 0 ? completeCount / totalFiles : 0; // Use color gradient based on completion percentage if (progress === 1) return '#198754'; // All complete - green if (progress > 0) return '#ffc107'; // Some complete - yellow return '#6c757d'; // None complete - grey } else { // Colors for files based on status return d.data.status === 'complete' ? '#198754' : d.data.status === 'in_progress' ? '#ffc107' : '#f8f9fa'; } }) .style('stroke', d => d.data.path === currentFocusPath ? '#dc3545' : '#6c757d') .style('stroke-width', d => d.data.path === currentFocusPath ? 2 : 1); // Add node labels nodes.append('text') .attr('dy', 3) .attr('x', d => d.children ? -nodeRadius * 1.5 : nodeRadius * 1.5) .attr('text-anchor', d => d.children ? 
'end' : 'start') .attr('class', 'repo-node-label') .text(d => { const name = d.data.name; if (name.length > maxLabelLength) { return name.substring(0, maxLabelLength - 3) + '...'; } return name; }) .append('title') // Add tooltip with full name .text(d => d.data.name); // Find the focused node if it exists if (currentFocusPath) { const focusedNode = root.descendants().find(d => d.data.path === currentFocusPath); if (focusedNode) { // If focus has changed, trigger the zoom animation if (prevFocusPath !== currentFocusPath) { zoomToNode(svg, focusedNode, width, height); } } } } /** * Zoom to a specific node in the visualization. * * @param {Object} svg - The D3 SVG selection * @param {Object} node - The node to zoom to * @param {number} width - The width of the container * @param {number} height - The height of the container */ function zoomToNode(svg, node, width, height) { // Calculate the scale factor based on how deep the node is in the tree const depth = node.depth; const scale = Math.max(1, Math.min(2, 1 + depth * 0.2)); // Calculate translation to center the node const x = node.x; const y = node.y; const tx = width/2 - y * scale; const ty = height/2 - x * scale; // Apply the zoom transformation svg.transition() .duration(750) .attr('transform', `translate(${margin.left + tx},${margin.top + ty}) scale(${scale})`); // Add a highlight animation to the node const nodeId = `#node-${node.data.path.replace(/[\/\.]/g, '_')} circle`; d3.select(nodeId) .classed('highlight-focus', true) .transition() .duration(750) .on('end', function() { d3.select(this).classed('highlight-focus', false); }); } /** * Update the status of a file in the repository structure. 
* * @param {string} file_path - The path of the file to update * @param {string} status - The new status (not_started, in_progress, complete) */ function updateFileStatus(file_path, status) { if (!currentRepoStructure) return; // Find the file in the tree function updateNodeStatus(node) { if (node.path === file_path) { // Only update if the status is actually changing if (node.status !== status) { node.status = status; return true; } return false; } if (node.children) { for (const child of node.children) { if (updateNodeStatus(child)) { return true; } } } return false; } // Update the node status if (updateNodeStatus(currentRepoStructure.tree)) { // If the file status has changed, update the visualization if (status === 'in_progress') { currentRepoStructure.focus_path = file_path; } updateRepoStructure(currentRepoStructure); } } // Initialize the visualization when the document is ready $(document).ready(function() { // If we receive a docstring_updated event, update the repository structure if (socket) { socket.on('docstring_updated', function(data) { if (data.component && currentRepoStructure) { updateFileStatus(data.component, 'complete'); } }); } }); // Handle window resize to update visualization $(window).on('resize', function() { if (currentRepoStructure) { updateRepoStructure(currentRepoStructure); } }); ================================================ FILE: src/web/static/js/status-visualizer.js ================================================ // Copyright (c) Meta Platforms, Inc. and affiliates /** * Status visualizer for the docstring generation web application. * * This file provides functions for rendering and updating the agent status * visualization in the web interface. 
*/ // Define the agent workflow structure const agentWorkflow = { nodes: [ { id: "reader", label: "Reader", x: 150, y: 80, isAgent: true }, { id: "searcher", label: "Searcher", x: 350, y: 80, isAgent: true }, { id: "writer", label: "Writer", x: 150, y: 200, isAgent: true }, { id: "verifier", label: "Verifier", x: 350, y: 200, isAgent: true } ], labels: [ { id: "input", label: "Input", x: 30, y: 140 }, { id: "output", label: "Output", x: 470, y: 140 } ], links: [ { source: "input", target: "reader" }, { source: "reader", target: "searcher" }, { source: "searcher", target: "reader" }, { source: "reader", target: "writer" }, { source: "writer", target: "verifier" }, { source: "verifier", target: "output" }, { source: "verifier", target: "reader" } ] }; // Keep track of the current active agent let currentActiveAgent = null; // Initialize the agent workflow visualization function initAgentWorkflow() { const container = document.getElementById('agent-workflow'); if (!container) return; // Check if container is visible and has dimensions const width = container.clientWidth || 600; const height = container.clientHeight || 200; // Clear any existing content d3.select(container).selectAll("*").remove(); // Create SVG container const svg = d3.select(container) .append("svg") .attr("width", width) .attr("height", height) .append("g") .attr("transform", `translate(${Math.max(0, (width - 500) / 2)}, 0)`); // Add arrowhead marker definition svg.append("defs").append("marker") .attr("id", "arrowhead") .attr("viewBox", "0 -5 10 10") .attr("refX", 20) .attr("refY", 0) .attr("markerWidth", 6) .attr("markerHeight", 6) .attr("orient", "auto") .append("path") .attr("d", "M0,-5L10,0L0,5") .attr("fill", "#adb5bd"); // Helper function to get node coordinates by id function getNodeCoords(id) { const agentNode = agentWorkflow.nodes.find(n => n.id === id); if (agentNode) return { x: agentNode.x, y: agentNode.y }; const labelNode = agentWorkflow.labels.find(n => n.id === id); if (labelNode) 
return { x: labelNode.x, y: labelNode.y }; return null; } // Draw links svg.selectAll(".workflow-link") .data(agentWorkflow.links) .enter() .append("path") .attr("class", "workflow-link") .attr("d", d => { const source = getNodeCoords(d.source); const target = getNodeCoords(d.target); if (!source || !target) return ""; // Create curved paths const dx = target.x - source.x; const dy = target.y - source.y; const dr = Math.sqrt(dx * dx + dy * dy) * 1.5; return `M${source.x},${source.y}A${dr},${dr} 0 0,1 ${target.x},${target.y}`; }); // Draw agent nodes (circles) const nodes = svg.selectAll(".workflow-node") .data(agentWorkflow.nodes) .enter() .append("g") .attr("class", d => `workflow-node ${d.id}`) .attr("transform", d => `translate(${d.x}, ${d.y})`); // Add node circles for agents nodes.append("circle") .attr("r", 35); // Add node labels for agents nodes.append("text") .attr("class", "workflow-label") .attr("dy", ".35em") .text(d => d.label); // Add non-agent labels (input/output) const textLabels = svg.selectAll(".workflow-text") .data(agentWorkflow.labels) .enter() .append("g") .attr("class", d => `workflow-text ${d.id}`) .attr("transform", d => `translate(${d.x}, ${d.y})`); // Add text for non-agent nodes textLabels.append("text") .attr("class", "workflow-text-label") .attr("dy", ".35em") .attr("text-anchor", "middle") .style("font-size", "14px") .style("font-weight", "bold") .style("fill", "#444") .text(d => d.label); // Add event listeners to highlight nodes on hover nodes.on("mouseover", function() { d3.select(this).style("opacity", 0.8); }).on("mouseout", function() { d3.select(this).style("opacity", 1); }); // If we have a stored active agent, highlight it if (currentActiveAgent) { updateAgentWorkflow(currentActiveAgent); } console.log("Agent workflow initialized with dimensions:", width, "x", height); } // Ensure the workflow is initialized as soon as the document is ready $(document).ready(function() { // Delay initialization slightly to ensure DOM is 
fully ready setTimeout(initAgentWorkflow, 100); // Also handle window resize $(window).on('resize', function() { initAgentWorkflow(); }); // Poll to ensure the graph is visible (workaround for tabs/containers that might be hidden initially) let checkCount = 0; const checkInterval = setInterval(function() { const container = document.getElementById('agent-workflow'); if (container && container.clientWidth > 0 && container.clientHeight > 0) { initAgentWorkflow(); clearInterval(checkInterval); } else if (checkCount > 20) { // Stop after 20 attempts (10 seconds) clearInterval(checkInterval); } checkCount++; }, 500); }); /** * Update the status visualizer with the current status. * * @param {Object} status - The status object from the server */ function updateStatusVisualizer(status) { console.log("Updating status visualizer with:", status); // Update the agent workflow visualization updateAgentWorkflow(status.active_agent); // If there's no active agent, show placeholder if (!status.active_agent) { $('#status-visualizer').html(`

No active agent

`); return; } // Update component info and status message let statusHtml = `
Processing with ${status.active_agent}
`; if (status.status_message) { statusHtml += `
${status.status_message}
`; } if (status.current_component) { statusHtml += `
Current Processing Component: ${status.current_component}
Current Processing File: ${status.current_file}
`; } $('#status-visualizer').html(statusHtml); } /** * Update the agent workflow visualization with the active agent. * * @param {string} activeAgent - The name of the active agent */ function updateAgentWorkflow(activeAgent) { // Store the active agent currentActiveAgent = activeAgent; // Make sure the workflow is initialized if ($('#agent-workflow svg').length === 0) { initAgentWorkflow(); return; // The initialization will handle setting the active agent } console.log("Updating agent workflow with active agent:", activeAgent); // Remove active class from all nodes d3.selectAll(".workflow-node").classed("active", false); if (!activeAgent) { return; } // Skip non-agent entities if (activeAgent.toLowerCase() === 'system' || activeAgent.toLowerCase() === 'input' || activeAgent.toLowerCase() === 'output') { return; } // Normalize the agent name to lowercase const agentLower = activeAgent.toLowerCase(); // Map certain agent names to our workflow nodes let nodeId = null; if (agentLower.includes('reader')) nodeId = 'reader'; else if (agentLower.includes('searcher')) nodeId = 'searcher'; else if (agentLower.includes('writer')) nodeId = 'writer'; else if (agentLower.includes('verifier')) nodeId = 'verifier'; // Add active class to the current agent node if (nodeId) { const node = d3.select(`.workflow-node.${nodeId}`); if (!node.empty()) { node.classed("active", true); console.log("Activated node:", nodeId); // Briefly animate the node to draw attention node.select("circle") .transition() .duration(300) .attr("r", 40) .transition() .duration(300) .attr("r", 35); } else { console.warn("Could not find node for agent:", activeAgent, "mapped to:", nodeId); } } else { console.warn("Could not map agent name to a node:", activeAgent); } } /** * Update the progress information. 
* * @param {Object} progress - The progress object from the server */ function updateProgress(progress) { // Calculate percentage const total = progress.total_components || 0; const processed = progress.processed_components || 0; const percentage = total > 0 ? Math.floor((processed / total) * 100) : 0; // Update progress bar $('#progress-bar').css('width', `${percentage}%`); $('#progress-bar').attr('aria-valuenow', percentage); $('#progress-bar').text(`${percentage}%`); // Update progress text $('#progress-text').text(`${processed}/${total} components processed`); } ================================================ FILE: src/web/templates/index.html ================================================ DocAgent - Docstring Generation

DocAgent

Agent Status

No active agent

Repository Structure

No repository selected

Logs Elapsed: 00:00
================================================ FILE: src/web/visualization_handler.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Visualization handler for the docstring generation web interface. This module provides functions to collect and format data for visualization in the web interface, including status updates, progress tracking, and repository structure visualization. """ import os import json import sys import subprocess from pathlib import Path from typing import Dict, List, Any # Singleton pattern to store current state class VisualizationState: """Singleton class to store the current visualization state.""" _instance = None def __new__(cls): if cls._instance is None: cls._instance = super(VisualizationState, cls).__new__(cls) cls._instance.status = { 'active_agent': None, 'status_message': '', 'current_component': '', 'current_file': '' } cls._instance.progress = { 'total_components': 0, 'processed_components': 0, 'current_component': '', 'component_status': {} } cls._instance.repo_structure = { 'tree': {}, 'focus_path': '' } cls._instance.log_messages = [] return cls._instance # Initialize the state state = VisualizationState() def get_current_status(): """ Get the current status of the docstring generation process. Returns: Dictionary with the current status information """ return { 'status': state.status, 'progress': state.progress, 'repo_structure': state.repo_structure } def update_agent_status(active_agent: str, status_message: str): """ Update the current agent status. Args: active_agent: The currently active agent (reader, searcher, writer, verifier) status_message: Status message describing what the agent is doing """ state.status['active_agent'] = active_agent state.status['status_message'] = status_message def update_component_focus(component_path: str, file_path: str): """ Update the current component being processed. 
Args: component_path: The path to the component being processed file_path: The path to the file containing the component """ state.status['current_component'] = component_path state.status['current_file'] = file_path state.repo_structure['focus_path'] = file_path def update_progress(total: int, processed: int, current: str, components_status: Dict[str, str]): """ Update the progress of the docstring generation process. Args: total: Total number of components to process processed: Number of components processed so far current: The component currently being processed components_status: Dictionary mapping component paths to their status """ state.progress['total_components'] = total state.progress['processed_components'] = processed state.progress['current_component'] = current state.progress['component_status'] = components_status def add_log_message(message: str): """ Add a log message to the visualization state. Args: message: The log message to add """ state.log_messages.append(message) # Keep only the latest 1000 messages if len(state.log_messages) > 1000: state.log_messages = state.log_messages[-1000:] def get_repo_structure(repo_path: str) -> Dict[str, Any]: """ Get the structure of the repository as a tree. 
Args: repo_path: Path to the repository Returns: Dictionary representing the repository structure """ tree = {'name': os.path.basename(repo_path), 'path': repo_path, 'type': 'dir', 'children': []} def build_tree(path, node): """Recursively build the tree structure.""" for item in os.listdir(path): item_path = os.path.join(path, item) # Skip hidden files and directories if item.startswith('.'): continue # Skip __pycache__ and other common non-Python directories if item in ['__pycache__', 'venv', 'env', '.git', '.idea', '.vscode']: continue if os.path.isdir(item_path): child = {'name': item, 'path': item_path, 'type': 'dir', 'children': []} build_tree(item_path, child) node['children'].append(child) elif item.endswith('.py'): node['children'].append({ 'name': item, 'path': item_path, 'type': 'file', 'status': 'not_started' # Possible values: not_started, in_progress, complete }) try: build_tree(repo_path, tree) except Exception as e: print(f"Error building repo structure: {e}") state.repo_structure['tree'] = tree return tree def update_file_status(file_path: str, status: str): """ Update the status of a file in the repository structure. Args: file_path: Path to the file status: New status of the file (not_started, in_progress, complete) """ def update_status(node): """Recursively update the status of the file in the tree.""" if node['type'] == 'file' and node['path'] == file_path: node['status'] = status return True if node['type'] == 'dir' and 'children' in node: for child in node['children']: if update_status(child): return True return False update_status(state.repo_structure['tree']) def get_completeness_data(repo_path: str) -> Dict[str, Any]: """ Get the completeness evaluation data for the repository. 
Args: repo_path: Path to the repository Returns: Dictionary containing the completeness evaluation results """ try: # Run the eval_completeness.py script to get the results eval_script_path = Path(__file__).parent.parent.parent / 'eval_completeness.py' if not eval_script_path.exists(): return { 'status': 'error', 'message': f'Evaluation script not found at {eval_script_path}' } # Create a simplified mock result for testing or when the script fails mock_results = { 'status': 'success', 'files': [] } # Get Python files in the repository all_python_files = [] for root, _, files in os.walk(repo_path): for file in files: if file.endswith('.py'): file_path = os.path.join(root, file) rel_path = os.path.relpath(file_path, repo_path) # Count functions and classes with simple parsing with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: content = f.read() # Simple counting of functions and classes functions = [] classes = [] function_count = content.count('def ') class_count = content.count('class ') # Simple docstring check (very basic) doc_count = content.count('"""') // 2 # Rough estimate # Create mock function and class objects for i in range(function_count): has_doc = i < doc_count functions.append({ 'name': f'function_{i}', 'has_docstring': has_doc }) for i in range(class_count): has_doc = i < (doc_count - function_count if doc_count > function_count else 0) classes.append({ 'name': f'class_{i}', 'has_docstring': has_doc }) mock_results['files'].append({ 'file': rel_path, 'functions': functions, 'classes': classes }) # Try to run the actual script try: cmd = [sys.executable, str(eval_script_path), '--repo-path', repo_path] result = subprocess.run( cmd, capture_output=True, text=True, timeout=30 # Add timeout to prevent hanging ) if result.returncode == 0 and result.stdout.strip(): try: data = json.loads(result.stdout) if 'files' in data and isinstance(data['files'], list): return { 'status': 'success', 'data': data } except json.JSONDecodeError: pass # 
Fall back to mock data # If script execution fails, use mock data but log the error print(f"Warning: Using mock completeness data. Script error: {result.stderr}") return { 'status': 'success', 'data': mock_results } except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: print(f"Error running completeness script: {e}") # Fall back to mock data return { 'status': 'success', 'data': mock_results } except Exception as e: print(f"Error evaluating completeness: {e}") return { 'status': 'error', 'message': f'Error evaluating completeness: {str(e)}' } ================================================ FILE: src/web_eval/README.md ================================================ # DocAgent - Docstring Evaluation System A web application for evaluating the quality of Python docstrings in your codebase, providing objective metrics and actionable feedback. ## Overview DocAgent is a powerful tool that analyzes Python docstrings in a repository and evaluates them based on two key metrics: 1. **Completeness**: Automatically checks if docstrings contain all required components (summary, description, arguments, returns, etc.) 2. **Helpfulness**: Uses LLM-based evaluation to assess how helpful and informative each docstring component is on a scale of 1-5 The system provides an intuitive web interface for configuring evaluation settings, viewing results, and getting actionable feedback to improve your codebase documentation.
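To make the completeness metric concrete, here is a minimal sketch of what such a check could look like. It is not the evaluator that ships in `src/evaluator/completeness.py`. The function name `check_completeness`, the section markers `Args:` and `Returns:`, and the rules for when a section counts as required are all assumptions made for illustration.

```python
import ast


def check_completeness(source: str) -> dict:
    """Map each function/class name to a dict of required-section checks.

    Hypothetical helper: the required sections and their markers are
    assumptions, not DocAgent's actual rules.
    """
    report = {}
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            continue
        doc = ast.get_docstring(node) or ""
        # Every component needs at least a non-empty docstring.
        checks = {"summary": bool(doc.strip())}
        if not isinstance(node, ast.ClassDef):
            # Require an Args section only if there are parameters besides self/cls.
            params = [a.arg for a in node.args.args if a.arg not in ("self", "cls")]
            if params:
                checks["args"] = "Args:" in doc
            # Require a Returns section only if some return statement yields a value
            # (this simple walk also sees returns inside nested functions).
            returns_value = any(
                isinstance(n, ast.Return) and n.value is not None
                for n in ast.walk(node)
            )
            if returns_value:
                checks["returns"] = "Returns:" in doc
        report[node.name] = checks
    return report


SAMPLE = '''
def add(a, b):
    """Add two numbers.

    Args:
        a: First operand.
        b: Second operand.
    """
    return a + b


class Empty:
    pass
'''

report = check_completeness(SAMPLE)
for name, checks in report.items():
    print(name, checks)
```

In this sample, `add` passes the summary and arguments checks but is flagged for a missing `Returns:` section, and `Empty` is flagged for having no docstring at all. A per-component score could then be taken as the fraction of required checks that pass.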
## Features - **Configuration Interface**: User-friendly setup for LLM API (OpenAI or Claude) and repository path - **API Connection Testing**: Verify API credentials before running evaluations - **Automated Completeness Evaluation**: Scan all Python files in a repository to check for required docstring components - **Interactive Results Dashboard**: View completeness scores for all classes and functions with detailed breakdowns - **On-demand Helpfulness Assessment**: Use LLM-powered evaluation for specific docstring components - **Visual Status Indicators**: Clear visual feedback for required vs. optional components and their quality - **Component-specific Evaluations**: Different criteria for evaluating summaries, descriptions, parameters, etc. - **Refresh Functionality**: Re-run evaluation after making code changes - **Detailed Explanations**: Get specific feedback on why a component received its score and how to improve it ## System Architecture DocAgent's web evaluation system consists of several key components: ``` src/web_eval/ │ ├── app.py # Main Flask application ├── helpers.py # Utility functions (parsing, extraction, etc.) ├── requirements.txt # Python dependencies ├── start_server.sh # Convenience script for starting the server ├── test_docstring_parser.py # Tests for the docstring parser │ ├── templates/ # HTML templates │ ├── index.html # Configuration page │ └── results.html # Results display page │ └── static/ # Static assets ├── css/ # CSS stylesheets ├── js/ # JavaScript files └── assets/ # Images and other assets ``` The system follows a Model-View-Controller architecture: - **Model**: Evaluation logic in the imported evaluator modules and parsing functions in helpers.py - **View**: HTML templates with Jinja2 for rendering the UI - **Controller**: Flask routes in app.py that handle requests and connect the model with views The application integrates with two key external components: 1. 
**DocAgent Evaluator Modules**: Core evaluation logic for assessing docstring quality 2. **LLM APIs**: OpenAI or Anthropic Claude for helpfulness evaluation ================================================ FILE: src/web_eval/app.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates import os import sys import ast import json import argparse from flask import Flask, render_template, request, jsonify, redirect, url_for from typing import Dict, Any, List # Add parent directory to path to import from src sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) # Import evaluation modules from evaluator.completeness import ClassCompletenessEvaluator, FunctionCompletenessEvaluator from evaluator.helpfulness_summary import DocstringSummaryEvaluator from evaluator.helpfulness_description import DocstringDescriptionEvaluator # from evaluator.helpfulness_arguments import DocstringArgumentEvaluator from evaluator.helpfulness_parameters import DocstringParametersEvaluator from evaluator.helpfulness_attributes import DocstringAttributeEvaluator # from evaluator.helpfulness_examples import DocstringExampleEvaluator # Import our helpers from src.web_eval.helpers import parse_llm_score_from_text, extract_docstring_component # Initialize Flask app app = Flask(__name__) app.config['SECRET_KEY'] = 'DocAgent-evaluation-system' # Add template filter for extracting docstring components @app.template_filter('extract_component') def extract_component_filter(docstring, component): """ Jinja2 template filter for extracting docstring components. Args: docstring: The full docstring component: The component to extract (summary, description, etc.) 
Returns: The extracted component, or empty string if not found """ result = extract_docstring_component(docstring, component) return result or "" # Global variable to store evaluation results evaluation_results = {} config = {} @app.route('/') def index(): """ Renders the configuration page (entry page). This page allows users to configure LLM settings and repository path. """ return render_template('index.html') @app.route('/test_api', methods=['POST']) def test_api(): """ Tests the LLM API connection by sending a simple query. Returns: JSON response with success/failure and any error message """ data = request.get_json() # Save config for later use global config config = { 'llm_type': data.get('llm_type'), 'api_key': data.get('api_key'), 'model': data.get('model'), 'temperature': float(data.get('temperature', 0.1)), 'max_output_tokens': int(data.get('max_output_tokens', 4096)) } # Test API connection based on LLM type try: if config['llm_type'] == 'openai': import openai openai.api_key = config['api_key'] response = openai.chat.completions.create( model=config['model'], messages=[{"role": "user", "content": "Who are you?"}], temperature=config['temperature'], max_tokens=100 ) return jsonify({"success": True, "response": response.choices[0].message.content}) elif config['llm_type'] == 'claude': from anthropic import Anthropic client = Anthropic(api_key=config['api_key']) response = client.messages.create( model=config['model'], max_tokens=100, temperature=config['temperature'], messages=[{"role": "user", "content": "Who are you?"}] ) return jsonify({"success": True, "response": response.content[0].text}) else: return jsonify({"success": False, "error": f"Unsupported LLM type: {config['llm_type']}"}) except Exception as e: return jsonify({"success": False, "error": str(e)}) @app.route('/evaluate', methods=['POST']) def evaluate(): """ Initiates the evaluation process for the specified repository. 
Returns: Redirects to the results page """ data = request.get_json() repo_path = data.get('repo_path') if not os.path.exists(repo_path): return jsonify({"success": False, "error": f"Repository path does not exist: {repo_path}"}) try: # Start evaluation global evaluation_results evaluation_results = process_directory(repo_path) return jsonify({"success": True, "redirect": url_for('results')}) except Exception as e: return jsonify({"success": False, "error": str(e)}) @app.route('/results') def results(): """ Renders the evaluation results page. """ return render_template('results.html', results=evaluation_results) @app.route('/evaluate_helpfulness', methods=['POST']) def evaluate_helpfulness(): """ Evaluates the helpfulness of a specific docstring component. Returns: JSON response with the helpfulness score """ data = request.get_json() component_type = data.get('component_type') # class or function component_name = data.get('component_name') docstring_part = data.get('docstring_part') # summary, description, etc. 
docstring_content = data.get('docstring_content') signature = data.get('signature', '') try: # Select appropriate evaluator based on docstring part evaluator = None if docstring_part == 'summary': evaluator = DocstringSummaryEvaluator() elif docstring_part == 'description': evaluator = DocstringDescriptionEvaluator() # elif docstring_part == 'arguments': # evaluator = DocstringArgumentsEvaluator() elif docstring_part == 'parameters': evaluator = DocstringParametersEvaluator() elif docstring_part == 'attributes': evaluator = DocstringAttributesEvaluator() elif docstring_part == 'examples': evaluator = DocstringExamplesEvaluator() else: return jsonify({"success": False, "error": f"Unsupported docstring part: {docstring_part}"}) # Generate prompt prompt = evaluator.get_evaluation_prompt(signature, docstring_content) # Call LLM API based on configured type if config['llm_type'] == 'openai': import openai openai.api_key = config['api_key'] response = openai.chat.completions.create( model=config['model'], messages=[{"role": "user", "content": prompt}], temperature=config['temperature'], max_tokens=config['max_output_tokens'] ) llm_response = response.choices[0].message.content elif config['llm_type'] == 'claude': from anthropic import Anthropic client = Anthropic(api_key=config['api_key']) response = client.messages.create( model=config['model'], max_tokens=config['max_output_tokens'], temperature=config['temperature'], messages=[{"role": "user", "content": prompt}] ) llm_response = response.content[0].text else: return jsonify({"success": False, "error": f"Unsupported LLM type: {config['llm_type']}"}) # Parse LLM response to get score score, explanation = parse_llm_score_from_text(llm_response) # Update evaluation results with helpfulness score if component_type == 'class': for cls in evaluation_results['classes']: if cls['name'] == component_name: if 'helpfulness_scores' not in cls: cls['helpfulness_scores'] = {} cls['helpfulness_scores'][docstring_part] = { 'score': 
score, 'explanation': explanation } break else: # function or method for func in evaluation_results['functions']: if func['name'] == component_name: if 'helpfulness_scores' not in func: func['helpfulness_scores'] = {} func['helpfulness_scores'][docstring_part] = { 'score': score, 'explanation': explanation } break return jsonify({ "success": True, "score": score, "explanation": explanation }) except Exception as e: return jsonify({"success": False, "error": str(e)}) @app.route('/refresh', methods=['POST']) def refresh_evaluation(): """ Refreshes the completeness evaluation results. Returns: JSON response indicating whether the re-evaluation succeeded """ data = request.get_json() repo_path = data.get('repo_path') try: # Re-run evaluation global evaluation_results evaluation_results = process_directory(repo_path) return jsonify({"success": True}) except Exception as e: return jsonify({"success": False, "error": str(e)}) def run_docstring_tests(source_file: str) -> Dict[str, Any]: """ Run comprehensive docstring evaluation tests on a Python source file. This function reads a Python file and evaluates docstrings for all classes, functions, and methods found within. It provides detailed evaluation results. 
Args: source_file: Path to the Python file to analyze Returns: Dictionary containing evaluation results for each found element """ with open(source_file, 'r', encoding='utf-8') as f: source = f.read() try: tree = ast.parse(source) except SyntaxError as e: return { 'status': 'error', 'message': f'Failed to parse {source_file}: {str(e)}' } results = { 'status': 'success', 'file': source_file, 'classes': [], 'functions': [], 'debug_info': {} } # Instantiate evaluators class_evaluator = ClassCompletenessEvaluator() func_evaluator = FunctionCompletenessEvaluator() # Process all nodes in the AST for node in ast.iter_child_nodes(tree): if isinstance(node, ast.ClassDef): # Get actual docstring content class_docstring = ast.get_docstring(node) or "" class_result = { 'name': node.name, 'type': 'class', 'docstring': class_docstring, 'signature': f"class {node.name}:", 'completeness_score': class_evaluator.evaluate(node), 'completeness_elements': class_evaluator.element_scores.copy(), 'element_required': class_evaluator.element_required.copy() } results['classes'].append(class_result) # Evaluate methods within the class for method in [n for n in ast.iter_child_nodes(node) if isinstance(n, ast.FunctionDef)]: # Skip __init__ methods for display purposes if method.name == '__init__': continue # Get actual method docstring content method_docstring = ast.get_docstring(method) or "" method_result = { 'name': f"{node.name}.{method.name}", 'type': 'method', 'docstring': method_docstring, 'signature': f"def {method.name}():", # Simplified signature 'completeness_score': func_evaluator.evaluate(method), 'completeness_elements': func_evaluator.element_scores.copy(), 'element_required': func_evaluator.element_required.copy() } results['functions'].append(method_result) elif isinstance(node, ast.FunctionDef): # Get actual function docstring content func_docstring = ast.get_docstring(node) or "" # Only process top-level functions func_result = { 'name': node.name, 'type': 'function', 
'docstring': func_docstring, 'signature': f"def {node.name}():", # Simplified signature 'completeness_score': func_evaluator.evaluate(node), 'completeness_elements': func_evaluator.element_scores.copy(), 'element_required': func_evaluator.element_required.copy() } results['functions'].append(func_result) # Add overall statistics results['statistics'] = { 'total_classes': len(results['classes']), 'total_functions': len(results['functions']), 'average_class_score': sum(r['completeness_score'] for r in results['classes']) / max(1, len(results['classes'])), 'average_function_score': sum(r['completeness_score'] for r in results['functions']) / max(1, len(results['functions'])) } return results def process_directory(directory_path: str) -> Dict[str, Any]: """ Process all Python files in a directory and its subdirectories. Args: directory_path: Path to the directory to analyze Returns: Dictionary containing aggregated evaluation results for all files """ # Initialize aggregate results aggregate_results = { 'status': 'success', 'directory': directory_path, 'files': [], 'file_results': [], 'classes': [], 'functions': [], 'statistics': { 'total_files': 0, 'successful_files': 0, 'failed_files': 0, 'total_classes': 0, 'total_functions': 0, 'average_class_score': 0.0, 'average_function_score': 0.0, 'overall_average_score': 0.0 } } # Find all Python files recursively python_files = [] for root, _, files in os.walk(directory_path): for file in files: if file.endswith('.py'): python_files.append(os.path.join(root, file)) if not python_files: aggregate_results['status'] = 'error' aggregate_results['message'] = f'No Python files found in {directory_path}' return aggregate_results aggregate_results['statistics']['total_files'] = len(python_files) # Process each Python file all_class_scores = [] all_function_scores = [] for py_file in python_files: file_result = run_docstring_tests(py_file) if file_result['status'] == 'success': aggregate_results['statistics']['successful_files'] = 
aggregate_results['statistics'].get('successful_files', 0) + 1 aggregate_results['file_results'].append(file_result) aggregate_results['files'].append(py_file) # Accumulate classes and functions with file path context for class_result in file_result['classes']: class_result['file'] = py_file aggregate_results['classes'].append(class_result) all_class_scores.append(class_result['completeness_score']) for func_result in file_result['functions']: func_result['file'] = py_file aggregate_results['functions'].append(func_result) all_function_scores.append(func_result['completeness_score']) # Update statistics aggregate_results['statistics']['total_classes'] += file_result['statistics']['total_classes'] aggregate_results['statistics']['total_functions'] += file_result['statistics']['total_functions'] else: aggregate_results['statistics']['failed_files'] = aggregate_results['statistics'].get('failed_files', 0) + 1 # Calculate average scores if all_class_scores: aggregate_results['statistics']['average_class_score'] = sum(all_class_scores) / len(all_class_scores) if all_function_scores: aggregate_results['statistics']['average_function_score'] = sum(all_function_scores) / len(all_function_scores) # Calculate overall average score (classes and functions combined) all_scores = all_class_scores + all_function_scores if all_scores: aggregate_results['statistics']['overall_average_score'] = sum(all_scores) / len(all_scores) return aggregate_results if __name__ == '__main__': # Parse command line arguments parser = argparse.ArgumentParser(description='Docstring Evaluation Web App') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host address to bind to (default: 0.0.0.0 - accessible from outside)') parser.add_argument('--port', type=int, default=5000, help='Port to run the server on (default: 5000)') parser.add_argument('--debug', action='store_true', help='Run in debug mode (default: False)') args = parser.parse_args() # Print access information if args.host == 
'0.0.0.0': print(f"\n🚀 DocAgent web server starting!") print(f"💻 Local access: http://localhost:{args.port}") print(f"🌐 Network access: http://<your-server-ip>:{args.port}") print(f" (Replace <your-server-ip> with your server's IP address)") if args.debug: print(f"⚠️ Running in debug mode - not recommended for production use") print("\nPress CTRL+C to stop the server\n") # Run the Flask app app.run(host=args.host, port=args.port, debug=args.debug) ================================================ FILE: src/web_eval/helpers.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates """ Helper functions for the DocAgent web application """ import re from typing import Tuple, Optional, Dict, List def parse_llm_score_from_text(text: str) -> Tuple[int, str]: """ Parse score and explanation from LLM response text. Args: text: The raw LLM response text Returns: Tuple containing (score, explanation) """ # Try to extract score from <score> tags score_match = re.search(r'<score>(\d+)</score>', text) if score_match: score = int(score_match.group(1)) else: # Try looking for the score in various formats score_patterns = [ r'score:?\s*(\d+)/5', r'score:?\s*(\d+)', r'rating:?\s*(\d+)/5', r'rating:?\s*(\d+)', r'(\d+)/5', r'I would rate this as a (\d+)', r'I would give this a (\d+)' ] for pattern in score_patterns: match = re.search(pattern, text, re.IGNORECASE) if match: score = int(match.group(1)) break else: # Default score if we can't find one score = 3 # Limit score to 1-5 range score = max(1, min(5, score)) # Extract explanation (everything except the score tags) explanation = re.sub(r'<score>\d+</score>', '', text).strip() # If explanation is very long, truncate it if len(explanation) > 500: explanation = explanation[:497] + "..." return score, explanation def parse_google_style_docstring(docstring: str) -> Dict[str, str]: """ A robust parser for Google-style docstrings that handles multiple possible labels for each section. 
Args: docstring: The docstring to parse Returns: Dictionary with canonical section names as keys and their content as values """ # If docstring is empty or None, return empty sections if not docstring: return {key: "" for key in ['summary', 'description', 'parameters', 'attributes', 'returns', 'raises', 'examples']} # Define all recognized sections. The key is the canonical name (lowercase). # The value is a set of synonyms (also lowercase). SECTION_LABELS = { "summary": {"summary:", "brief:", "overview:"}, "description": {"description:", "desc:", "details:", "long description:"}, "parameters": {"parameters:", "params:", "args:", "arguments:", "keyword args:", "keyword arguments:", "**kwargs:"}, "attributes": {"attributes:", "members:", "member variables:", "instance variables:", "properties:", "vars:", "variables:"}, "returns": {"returns:", "return:", "return value:", "return values:"}, "raises": {"raises:", "exceptions:", "throws:", "raise:", "exception:", "throw:"}, "examples": {"example:", "examples:", "usage:", "usage example:", "usage examples:", "example usage:"}, } # Prepare a dictionary to hold the parsed content for each canonical key parsed_content = {key: [] for key in SECTION_LABELS.keys()} # Split by lines; if docstring uses Windows line endings, .splitlines() handles that gracefully lines = docstring.strip().splitlines() # -- 1) Fallback: no explicit sections at all in the entire docstring -- # If no recognized label appears anywhere, treat the first line as summary, rest as description. 
has_section_labels = False for line in lines: line_lower = line.strip().lower() for labels in SECTION_LABELS.values(): for label in labels: if line_lower.startswith(label): has_section_labels = True break if has_section_labels: break if has_section_labels: break if len(lines) > 0 and not has_section_labels: parsed_content["summary"] = [lines[0]] if len(lines) > 1: parsed_content["description"] = lines[1:] # Convert lists to single strings return {key: "\n".join(value).strip() for key, value in parsed_content.items()} # We'll track the current section as we parse line by line current_section = None # -- 2) Partial Fallback for the first line only -- # If the first line doesn't match any known label, treat it as summary and then # switch to "description" until an explicit label is found. first_line = lines[0].strip().lower() if lines else "" if not any(first_line.startswith(label) for labels in SECTION_LABELS.values() for label in labels): if lines: # Save first line as summary parsed_content["summary"] = [lines[0]] # Make the current section "description" current_section = "description" lines = lines[1:] # We'll handle the rest below # -- 3) Main Parsing Loop -- for line in lines: trimmed_line = line.strip().lower() matched_section = None # Check if this line begins with a known label (case-insensitive) # If so, we identify that as a new section. 
for canonical_name, synonyms in SECTION_LABELS.items(): for synonym in synonyms: if trimmed_line.startswith(synonym): matched_section = canonical_name # Extract leftover text on the same line, after the label leftover = line.strip()[len(synonym):].strip() if leftover: parsed_content[matched_section].append(leftover) break if matched_section: break if matched_section is not None: # We found a new section header on this line current_section = matched_section # No need to append the header line to content - we've already handled any content after the label else: # Otherwise, continue appending lines to the current section if current_section is not None: parsed_content[current_section].append(line) # -- 4) Convert list of lines to single string, preserving line breaks -- for section in parsed_content: parsed_content[section] = "\n".join(parsed_content[section]).strip() return parsed_content def extract_docstring_component(docstring: str, component: str) -> Optional[str]: """ Extract a specific component from a docstring using the robust parser. Args: docstring: The full docstring text component: The component to extract (summary, description, etc.) 
Returns: The extracted component text, or None if not found """ if not docstring: return None # Map component name to canonical name used in the parser component_map = { 'summary': 'summary', 'description': 'description', # 'arguments': 'parameters', 'params': 'parameters', 'parameters': 'parameters', 'attributes': 'attributes', 'returns': 'returns', 'raises': 'raises', 'examples': 'examples' } canonical_component = component_map.get(component.lower(), component.lower()) # Parse the docstring parsed = parse_google_style_docstring(docstring) # Return the requested component if canonical_component in parsed: return parsed[canonical_component] or None return None ================================================ FILE: src/web_eval/requirements.txt ================================================ flask>=2.0.0 openai>=1.0.0 anthropic>=0.5.0 tabulate>=0.8.0 ================================================ FILE: src/web_eval/start_server.sh ================================================ #!/bin/bash # Copyright (c) Meta Platforms, Inc. 
and affiliates # Default values HOST="0.0.0.0" PORT="8080" DEBUG="" # Show help function show_help() { echo "Usage: ./start_server.sh [options]" echo "" echo "Options:" echo " -h, --host HOST Host address to bind to (default: 0.0.0.0)" echo " -p, --port PORT Port to run the server on (default: 8080)" echo " -d, --debug Run in debug mode" echo " --help Show this help message" echo "" echo "Examples:" echo " ./start_server.sh # Run on default host:port (0.0.0.0:8080)" echo " ./start_server.sh -p 9090 # Run on port 9090" echo " ./start_server.sh -h 127.0.0.1 # Run on localhost only" echo " ./start_server.sh -d # Run in debug mode" echo "" } # Parse command line arguments while [[ $# -gt 0 ]]; do case "$1" in -h|--host) HOST="$2" shift 2 ;; -p|--port) PORT="$2" shift 2 ;; -d|--debug) DEBUG="--debug" shift ;; --help) show_help exit 0 ;; *) echo "Unknown option: $1" show_help exit 1 ;; esac done # Display startup message echo "Starting DocAgent Web Server..." echo "Host: $HOST" echo "Port: $PORT" if [ -n "$DEBUG" ]; then echo "Mode: DEBUG (not recommended for production)" else echo "Mode: Production" fi echo "" # Run the Flask app with the specified options python app.py --host "$HOST" --port "$PORT" $DEBUG ================================================ FILE: src/web_eval/static/css/style.css ================================================ /* Copyright (c) Meta Platforms, Inc. 
and affiliates */ /* DocAgent - Docstring Evaluation System Styles */ /* General Styles */ body { background-color: #f8f9fa; } .card { border-radius: 0.5rem; overflow: hidden; } .card-header { border-bottom: none; } /* Table Styles */ .table { font-size: 0.9rem; } .table th { font-weight: 600; } .table-responsive { max-height: 70vh; overflow-y: auto; } /* Button Styles */ .evaluate-btn { font-size: 0.75rem; padding: 0.2rem 0.5rem; } /* Modal Styles */ .modal-content { border-radius: 0.5rem; overflow: hidden; } .modal-header { border-bottom: none; } .modal-footer { border-top: none; } /* Docstring content display */ pre#docstringContent { max-height: 300px; overflow-y: auto; font-size: 0.9rem; white-space: pre-wrap; } /* Badges */ .badge { font-weight: 500; padding: 0.35rem 0.65rem; } /* Alert Styles */ .alert { border-radius: 0.5rem; } /* Responsive Adjustments */ @media (max-width: 992px) { .table { font-size: 0.8rem; } .evaluate-btn { font-size: 0.7rem; padding: 0.15rem 0.4rem; } .badge { font-size: 0.7rem; padding: 0.25rem 0.5rem; } } /* Custom scrollbar */ ::-webkit-scrollbar { width: 8px; height: 8px; } ::-webkit-scrollbar-track { background: #f1f1f1; border-radius: 4px; } ::-webkit-scrollbar-thumb { background: #888; border-radius: 4px; } ::-webkit-scrollbar-thumb:hover { background: #555; } ================================================ FILE: src/web_eval/templates/index.html ================================================ DocAgent - Docstring Evaluation System

DocAgent - Docstring Evaluation System

LLM Configuration


Repository Configuration
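
The two sections of this page collect the values that `test_api()` and `evaluate()` in `app.py` read from the request JSON. The snippet below is a minimal sketch of that normalization step, assuming the page posts the same keys `app.py` reads. The helper name `build_llm_config` and the sample values are hypothetical and not part of the repository.

```python
from typing import Any, Dict


def build_llm_config(data: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize posted form values with the same keys and defaults as test_api().

    Hypothetical helper for illustration; app.py builds this dict inline.
    """
    return {
        "llm_type": data.get("llm_type"),
        "api_key": data.get("api_key"),
        "model": data.get("model"),
        # Form values may arrive as strings, so they are cast explicitly.
        "temperature": float(data.get("temperature", 0.1)),
        "max_output_tokens": int(data.get("max_output_tokens", 4096)),
    }


# Sample payload with placeholder values (not real credentials or model names).
form_data = {
    "llm_type": "openai",
    "api_key": "placeholder-key",
    "model": "example-model",
    "temperature": "0.2",
}
config = build_llm_config(form_data)
print(config["temperature"], config["max_output_tokens"])
```

Omitted fields fall back to the defaults used in `app.py` (temperature 0.1, 4096 output tokens), which is why the sample payload can leave `max_output_tokens` out.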

================================================ FILE: src/web_eval/templates/results.html ================================================ DocAgent - Evaluation Results

DocAgent - Evaluation Results

Return to Config
Repository: {{ results.directory }}
Total Files: {{ results.statistics.total_files }}
Total Classes: {{ results.statistics.total_classes }}
Total Functions/Methods: {{ results.statistics.total_functions }}
Overall Score: {{ '%.2f'|format(results.statistics.overall_average_score) }}

Classes

{% for class in results.classes %} {% endfor %}
Class Name Score Summary Description Parameters Attributes Examples File
{{ class.name }} {{ '%.2f'|format(class.completeness_score) }} {% if class.completeness_elements.summary %} {% if class.element_required.summary %} {% endif %} {% else %} {% if class.element_required.summary %} Required {% endif %} {% endif %} {% if class.completeness_elements.description %} {% if class.element_required.description %} {% endif %} {% else %} {% if class.element_required.description %} Required {% endif %} {% endif %} {% if class.completeness_elements.parameters %} {% if class.element_required.parameters %} {% endif %} {% else %} {% if class.element_required.parameters %} Required {% endif %} {% endif %} {% if class.completeness_elements.attributes %} {% if class.element_required.attributes %} {% endif %} {% else %} {% if class.element_required.attributes %} Required {% endif %} {% endif %} {% if class.completeness_elements.examples %} {% if class.element_required.examples %} {% endif %} {% else %} {% if class.element_required.examples %} Required {% endif %} {% endif %} {{ class.file.split('/')[-1] }}

Functions/Methods

{% for func in results.functions %} {% endfor %}
Function Name Type Score Summary Description Returns Raises Examples File
{{ func.name }} {{ func.type }} {{ '%.2f'|format(func.completeness_score) }} {% if func.completeness_elements.summary %} {% if func.element_required.summary %} {% endif %} {% else %} {% if func.element_required.summary %} Required {% endif %} {% endif %} {% if func.completeness_elements.description %} {% if func.element_required.description %} {% endif %} {% else %} {% if func.element_required.description %} Required {% endif %} {% endif %} {% if func.completeness_elements.returns %} {% if func.element_required.returns %} {% endif %} {% else %} {% if func.element_required.returns %} Required {% endif %} {% endif %} {% if func.completeness_elements.raises %} {% if func.element_required.raises %} {% endif %} {% else %} {% if func.element_required.raises %} Required {% endif %} {% endif %} {% if func.completeness_elements.examples %} {% if func.element_required.examples %} {% endif %} {% else %} {% if func.element_required.examples %} Required {% endif %} {% endif %} {{ func.file.split('/')[-1] }}
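
Each element cell in the two tables above follows the same rule: nothing is shown for an element that is not required, a marker is shown when a required element is present, and `Required` is shown when a required element is missing. The snippet below is a minimal sketch of that rule in Python. The helper names and the `"present"` label are hypothetical, because the template's actual marker for a satisfied element is not visible in this dump.

```python
from typing import Any, Dict, List


def cell_status(present: bool, required: bool) -> str:
    """Mirror the nested {% if %} blocks used for each element column."""
    if not required:
        return ""  # optional elements render an empty cell either way
    # "present" is a stand-in label; the template's real marker is unknown here.
    return "present" if present else "Required"


def row_statuses(item: Dict[str, Any], columns: List[str]) -> Dict[str, str]:
    """Compute the cell text for every element column of one class/function row."""
    elements = item.get("completeness_elements", {})
    required = item.get("element_required", {})
    return {
        col: cell_status(bool(elements.get(col)), bool(required.get(col)))
        for col in columns
    }


# Sample row shaped like the entries produced by run_docstring_tests().
sample = {
    "name": "Example.method",
    "completeness_elements": {"summary": True, "description": False, "examples": False},
    "element_required": {"summary": True, "description": True, "examples": False},
}
statuses = row_statuses(sample, ["summary", "description", "examples"])
print(statuses)
```

Computing the cell text in Python rather than in nested template conditionals would keep the template to a single lookup per column; whether that trade-off suits this project is a design choice, not something the source prescribes.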
================================================ FILE: src/web_eval/test_docstring_parser.py ================================================ #!/usr/bin/env python # Copyright (c) Meta Platforms, Inc. and affiliates # -*- coding: utf-8 -*- """Test script for the parse_google_style_docstring function.""" from helpers import parse_google_style_docstring, extract_docstring_component import json from typing import Dict, Any, Optional def test_and_print_result(test_name: str, docstring: str) -> Dict[str, Any]: """ Run a test case and print results in a formatted way. Args: test_name: The name of the test docstring: The docstring to parse Returns: The parsed docstring components """ print(f"\n{'=' * 80}") print(f"TEST: {test_name}") print(f"{'-' * 80}") print("INPUT DOCSTRING:") print(f"{'-' * 40}") print(docstring) print(f"{'-' * 40}") # Parse the docstring result = parse_google_style_docstring(docstring) # Print the result in a formatted way print("PARSED RESULT:") print(f"{'-' * 40}") for section, content in result.items(): if content: print(f"{section.upper()}:") print(f"{content!r}") print() print(f"{'-' * 40}") return result def test_extract_component(docstring: str) -> None: """ Test the extract_docstring_component function with a given docstring. 
Args: docstring: The docstring to test with """ print(f"\n{'=' * 80}") print("TESTING extract_docstring_component") print(f"{'-' * 80}") print("INPUT DOCSTRING:") print(f"{'-' * 40}") print(docstring) print(f"{'-' * 40}") # Test extracting different components components = ["summary", "description", "parameters", "arguments", "returns", "raises", "examples"] print("EXTRACTED COMPONENTS:") print(f"{'-' * 40}") for component in components: result = extract_docstring_component(docstring, component) print(f"{component.upper()}: {result!r}") print(f"{'-' * 40}") def main(): """Run all tests for the docstring parser.""" # Test 1: Standard Google-style docstring test_and_print_result( "Standard Google-style docstring", """This is the summary line. This is the extended description that spans multiple lines. Args: param1: Description of param1 param2: Description of param2 Returns: Description of the return value Raises: ValueError: If something goes wrong Examples: >>> example_function(1, 2) 3 """ ) # Test 2: Docstring with Google-style section markers and colons test_and_print_result( "Docstring with explicit Google-style section markers", """Summary: This is a summary on the same line as the marker. Description: This is a multi-line description. Args: param1: Description of param1 param2: Description of param2 Returns: Description of the return value Examples: Example 1 Example 2 """ ) # Test 3: Docstring with content on the same line as section headers test_and_print_result( "Docstring with content on the same line as section headers", """Summary: This is a summary on the same line. Description: This is a description on the same line. Args: These are args on the same line. param1: Description of param1 param2: Description of param2 Returns: This is the return value on the same line. Raises: These are exceptions on the same line. ValueError: If something goes wrong Examples: This is an example on the same line. 
>>> example_function(1, 2) 3 """ ) # Test 4: Docstring with alternative labels test_and_print_result( "Docstring with alternative section labels", """Brief: This is the summary with alternative label. Detailed Description: This is the description. Arguments: param1: Description of param1 param2: Description of param2 Return Value: Description of the return value Exceptions: ValueError: If something goes wrong Usage: >>> example_function(1, 2) 3 """ ) # Test 5: Docstring with no explicit section markers test_and_print_result( "Docstring with no explicit section markers", """This is just a simple docstring with no section markers. It has a second paragraph, but no explicit Args, Returns, etc. """ ) # Test 6: Empty docstring test_and_print_result( "Empty docstring", "" ) # Test 7: Single line docstring test_and_print_result( "Single line docstring", "This is a single line docstring." ) # Test 8: Docstring with unusual indentation test_and_print_result( "Docstring with unusual indentation", """ This is an indented summary. This description has extra indentation. Args: param1: Indented param param2: Indented param Returns: Indented return value """ ) # Test 9: Incomplete docstring with some sections missing test_and_print_result( "Incomplete docstring with some sections missing", """Summary: This is the summary. Args: param1: First parameter param2: Second parameter """ ) # Test 10: Docstring with uppercase section labels test_and_print_result( "Docstring with uppercase section labels", """SUMMARY: This is the summary. DESCRIPTION: This is the description. ARGS: param1: First parameter param2: Second parameter RETURNS: The return value. """ ) # Test 11: Docstring with mixed case section labels test_and_print_result( "Docstring with mixed case section labels", """Summary: This is the summary. Description: This is the description. Arguments: param1: First parameter param2: Second parameter ReTuRnS: The return value. 
""" ) # Test 12: Docstring with complex examples section test_and_print_result( "Docstring with complex examples section", """Summary: This function does something. Examples: >>> example_function(1, 2) 3 More complex example: ```python result = example_function( a=1, b=2 ) assert result == 3 ``` """ ) # Test 13: Docstring with parameters that look like section labels test_and_print_result( "Docstring with parameters that look like section labels", """Validates input parameters. Args: summary: A parameter named "summary" description: A parameter named "description" returns: A parameter named "returns" examples: A parameter named "examples" """ ) # Test 14: Docstring with non-standard sections test_and_print_result( "Docstring with non-standard sections", """Summary: This is the summary. Description: This is the description. Note: This is an important note. Warning: This is a warning. Args: param1: First parameter """ ) # Test 15: Docstring with section labels with extra spaces test_and_print_result( "Docstring with section labels with extra spaces", """Summary : This is the summary with extra spaces around the colon. Description : This is the description. Args : param1: First parameter """ ) # Test 16: Docstring with section label on a line by itself (no colon) # This is a tricky case! test_and_print_result( "Docstring with section label on a line by itself (no colon)", """This is the summary. Description This is the description. Arguments param1: First parameter param2: Second parameter Returns The return value. """ ) # Test 17: Docstring with Summary section without a colon test_and_print_result( "Docstring with Summary section without a colon", """Summary This is a summary without a colon after the section label. Description: This is the description. """ ) # Test 18: Docstring with multiple colons in the summary line test_and_print_result( "Docstring with multiple colons in the summary line", """Summary: This is a summary: with another colon in it. 
Description: This is the description with: a colon. """ ) # Test 19: Docstring with summary containing special characters test_and_print_result( "Docstring with summary containing special characters", """Summary: This summary has *special* characters like: [], (), {} Args: param1: Description with `code` and *formatting* """ ) # Test 20: Docstring with only a summary section test_and_print_result( "Docstring with only a summary section", """Summary: This is only a summary section without other sections. """ ) # Test 21: Docstring with summary containing multiple paragraphs test_and_print_result( "Docstring with summary containing multiple paragraphs", """Summary: This is a multi-paragraph summary. It has more than one paragraph. Description: This is the description. """ ) # Test 22: Docstring with extra spacing between sections test_and_print_result( "Docstring with extra spacing between sections", """Summary: This is the summary. Description: This is the description. Args: param1: First parameter """ ) # Test 23: Docstring with no content after section label test_and_print_result( "Docstring with no content after section label", """Summary: Description: Args: param1: This parameter has a description Returns: """ ) # Test 24: Docstring with inconsistent indentation test_and_print_result( "Docstring with inconsistent indentation", """Summary: This is a summary. Description: This description has inconsistent indentation. Args: param1: Indented 6 spaces param2: Indented differently """ ) # Test 25: Real-world complex docstring example test_and_print_result( "Real-world complex docstring example", '''""" Process and analyze data from multiple sources. This utility function combines data from different sources, performs advanced analytics, and returns a processed result. It handles various edge cases and data inconsistencies. 
Args: data_source (str or Path): Path to the main data source secondary_sources (List[str], optional): Additional data sources to include config (Dict[str, Any]): Configuration parameters with the following structure: { "preprocessing": { "normalize": bool, "fill_missing": str }, "analysis": { "method": str, "parameters": Dict[str, Any] } } callback (Callable, optional): Function to call with progress updates Returns: Dict[str, Any]: Processed results with the following structure: { "summary": { "total_records": int, "processed_records": int, "anomalies": int }, "detailed_results": List[Dict[str, Any]] } Raises: FileNotFoundError: If any data source cannot be found ValueError: If the configuration is invalid ProcessingError: If analysis fails during execution Examples: Basic usage: >>> result = process_data("data.csv", config={"preprocessing": {"normalize": True}}) >>> print(result["summary"]["total_records"]) 1000 Advanced usage with multiple sources: ```python sources = ["secondary1.csv", "secondary2.csv"] config = { "preprocessing": {"normalize": True, "fill_missing": "mean"}, "analysis": {"method": "advanced", "parameters": {"iterations": 100}} } def progress(percent): print(f"Processed: {percent}%") result = process_data("main.csv", sources, config, callback=progress) ``` """''' ) # New: Test the extract_docstring_component function specifically print("\n\n") print("*" * 100) print("TESTING extract_docstring_component FUNCTION") print("*" * 100) # Test Case 1: Standard docstring test_extract_component( """This is a standard docstring summary. This is the description. Args: param1: First parameter param2: Second parameter Returns: The return value """ ) # Test Case 2: Google-style docstring with explicit section markers test_extract_component( """Summary: This is a summary with explicit section marker. Description: This is a description. 
Args: param1: First parameter param2: Second parameter Returns: The return value """ ) # Test Case 3: Docstring with content on the same line as section headers test_extract_component( """Summary: This is a summary on the same line. Description: This is a description on the same line. Args: These are arguments on the same line. param1: First parameter param2: Second parameter Returns: This is the return value on the same line. """ ) # Test Case 4: Real-world docstring that might be causing issues test_extract_component( """Parses a Google-style docstring into its components. This function takes a docstring and extracts the summary, description, parameters, returns, raises, and examples sections. Args: docstring: The docstring to parse Returns: A dictionary containing the parsed components """ ) # Test Case 5: Empty docstring test_extract_component("") # Specific cases reported as problematic print("\n\n") print("*" * 100) print("TESTING SPECIFIC PROBLEM CASES") print("*" * 100) # Problem Case: Summary followed immediately by content test_extract_component( """Summary:This is a summary with no space after the colon. Description: This is a description. """ ) # Problem Case: Summary with line break before content test_extract_component( """Summary: This is a summary after a line break. Description: This is a description. """ ) if __name__ == "__main__": main() ================================================ FILE: tool/remove_docstrings.py ================================================ #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates """ Tool to remove docstrings from Python files in a repository. """ import os import ast import astor import argparse from typing import List, Tuple class DocstringRemover(ast.NodeTransformer): """ AST NodeTransformer that removes docstrings from classes, methods, and functions. 
    """

    def _strip_docstring(self, node):
        """Visit child nodes, then drop the node's leading docstring, if any."""
        # Process the body first so nested definitions are handled recursively.
        node = self.generic_visit(node)

        # A docstring is a bare string constant in the first statement.
        # ast.Str was deprecated in Python 3.8 and removed in 3.12, so check
        # for ast.Constant holding a str instead (requires Python 3.8+).
        if (node.body
                and isinstance(node.body[0], ast.Expr)
                and isinstance(node.body[0].value, ast.Constant)
                and isinstance(node.body[0].value.value, str)):
            # Keep the body non-empty so the generated source stays valid
            # when the docstring was the only statement.
            node.body = node.body[1:] or [ast.Pass()]

        return node

    def visit_ClassDef(self, node):
        """Remove docstrings from class definitions."""
        return self._strip_docstring(node)

    def visit_FunctionDef(self, node):
        """Remove docstrings from function/method definitions."""
        return self._strip_docstring(node)

    def visit_AsyncFunctionDef(self, node):
        """Remove docstrings from async function/method definitions."""
        return self._strip_docstring(node)


def find_python_files(directory: str) -> List[str]:
    """Find all Python files in the given directory and its subdirectories."""
    python_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith('.py'):
                python_files.append(os.path.join(root, file))
    return python_files


def remove_docstrings_from_file(file_path: str, dry_run: bool = False) -> Tuple[bool, str]:
    """
    Remove docstrings from a Python file.

    Args:
        file_path: Path to the Python file
        dry_run: If True, don't actually write changes to file

    Returns:
        Tuple of (success, message)
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()

        # Parse the source code into an AST
        tree = ast.parse(source)

        # Remove docstrings
        transformer = DocstringRemover()
        new_tree = transformer.visit(tree)

        # Generate the modified source code.
        # Note: regenerating source from the AST discards comments and the
        # original formatting, not just docstrings.
        new_source = astor.to_source(new_tree)

        if not dry_run:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(new_source)
            return True, f"Successfully removed docstrings from {file_path}"
        else:
            return True, f"Would remove docstrings from {file_path} (dry run)"
    except Exception as e:
        return False, f"Error processing {file_path}: {e}"


def main():
    parser = argparse.ArgumentParser(description="Remove docstrings from Python files in a repository")
    parser.add_argument("directory", help="Directory containing Python files to process")
    parser.add_argument("--dry-run", action="store_true",
                        help="Don't actually modify files, just show what would be done")

    args = parser.parse_args()

    # Find all Python files
    python_files = find_python_files(args.directory)
    print(f"Found {len(python_files)} Python files to process")

    # Process each file
    success_count = 0
    for file_path in python_files:
        success, message = remove_docstrings_from_file(file_path, args.dry_run)
        print(message)
        if success:
            success_count += 1

    # Summary
    print(f"\nProcessed {len(python_files)} files, {success_count} successful")


if __name__ == "__main__":
    main()

================================================
FILE: tool/remove_docstrings.sh
================================================
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates

# Shell script wrapper for the remove_docstrings.py tool

set -e

# Script directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Show usage and exit with the given status (defaults to 1 for usage errors)
function show_usage {
    echo "Usage: $(basename "$0") [options] DIRECTORY"
    echo ""
    echo "Options:"
    echo "  -h, --help     Show this help message"
    echo "  -d, --dry-run  Perform a dry run (no changes are made)"
    echo "  -b, --backup   Create backup files before making changes"
    echo ""
    echo "Example:"
    echo "  $(basename "$0") ~/my-python-project"
    echo "  $(basename "$0") --dry-run ~/my-python-project"
    exit "${1:-1}"
}

# Parse arguments
DRY_RUN=""
BACKUP=false
DIRECTORY=""

while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            # Requesting help is not an error, so exit with status 0
            show_usage 0
            ;;
        -d|--dry-run)
            DRY_RUN="--dry-run"
            shift
            ;;
        -b|--backup)
            BACKUP=true
            shift
            ;;
        *)
            if [[ -z "$DIRECTORY" ]]; then
                DIRECTORY="$1"
            else
                echo "Error: Too many arguments"
                show_usage
            fi
            shift
            ;;
    esac
done

# Check if directory is provided
if [[ -z "$DIRECTORY" ]]; then
    echo "Error: No directory specified"
    show_usage
fi

# Check if directory exists
if [[ ! -d "$DIRECTORY" ]]; then
    echo "Error: Directory does not exist: $DIRECTORY"
    exit 1
fi

# Create backups if requested
if [[ "$BACKUP" = true ]]; then
    echo "Creating backups of Python files..."
    find "$DIRECTORY" -name "*.py" -type f -exec cp {} {}.bak \;
    echo "Backups created with .bak extension"
fi

# Run the Python script.
# $DRY_RUN is intentionally unquoted so that it expands to nothing when empty.
python3 "$SCRIPT_DIR/remove_docstrings.py" $DRY_RUN "$DIRECTORY"

echo "Done!"

================================================
FILE: tool/serve_local_llm.sh
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates

# Note: --tensor-parallel-size must not exceed the number of GPUs visible to
# the process. CUDA_VISIBLE_DEVICES=0 exposes a single GPU, so adjust one of
# the two values to match your hardware before running.
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model Your-Model-Name \
    --tensor-parallel-size 8 \
    --quantization fp8 \
    --gpu-memory-utilization 0.9 \
    --dtype bfloat16 \
    --host 0.0.0.0 \
    --port 8000