Repository: google/langextract Branch: main Commit: f48cdb27c7f5 Files: 124 Total size: 1.1 MB Directory structure: gitextract_s1tifoud/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── 1-bug.md │ │ ├── 2-feature-request.md │ │ └── config.yml │ ├── PULL_REQUEST_TEMPLATE/ │ │ └── pull_request_template.md │ ├── scripts/ │ │ ├── add-new-checks.sh │ │ ├── add-size-labels.sh │ │ ├── revalidate-all-prs.sh │ │ └── zenodo_publish.py │ └── workflows/ │ ├── auto-update-pr.yaml │ ├── check-infrastructure-changes.yml │ ├── check-linked-issue.yml │ ├── check-pr-size.yml │ ├── check-pr-up-to-date.yaml │ ├── ci.yaml │ ├── publish.yml │ ├── revalidate-pr.yml │ ├── validate-community-providers.yaml │ ├── validate_pr_template.yaml │ └── zenodo-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── CITATION.cff ├── COMMUNITY_PROVIDERS.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── autoformat.sh ├── benchmarks/ │ ├── benchmark.py │ ├── config.py │ ├── plotting.py │ └── utils.py ├── docs/ │ └── examples/ │ ├── batch_api_example.md │ ├── japanese_extraction.md │ ├── longer_text_example.md │ └── medication_examples.md ├── examples/ │ ├── custom_provider_plugin/ │ │ ├── README.md │ │ ├── langextract_provider_example/ │ │ │ ├── __init__.py │ │ │ ├── provider.py │ │ │ └── schema.py │ │ ├── pyproject.toml │ │ └── test_example_provider.py │ ├── notebooks/ │ │ └── romeo_juliet_extraction.ipynb │ └── ollama/ │ ├── .dockerignore │ ├── Dockerfile │ ├── README.md │ ├── demo_ollama.py │ └── docker-compose.yml ├── langextract/ │ ├── __init__.py │ ├── _compat/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── exceptions.py │ │ ├── inference.py │ │ ├── registry.py │ │ └── schema.py │ ├── annotation.py │ ├── chunking.py │ ├── core/ │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── data.py │ │ ├── debug_utils.py │ │ ├── exceptions.py │ │ ├── format_handler.py │ │ ├── schema.py │ │ ├── tokenizer.py │ │ └── types.py │ ├── data.py │ ├── data_lib.py │ ├── exceptions.py │ ├── 
extraction.py │ ├── factory.py │ ├── inference.py │ ├── io.py │ ├── plugins.py │ ├── progress.py │ ├── prompt_validation.py │ ├── prompting.py │ ├── providers/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── builtin_registry.py │ │ ├── gemini.py │ │ ├── gemini_batch.py │ │ ├── ollama.py │ │ ├── openai.py │ │ ├── patterns.py │ │ ├── router.py │ │ └── schemas/ │ │ ├── __init__.py │ │ └── gemini.py │ ├── py.typed │ ├── registry.py │ ├── resolver.py │ ├── schema.py │ ├── tokenizer.py │ └── visualization.py ├── pyproject.toml ├── scripts/ │ ├── create_provider_plugin.py │ └── validate_community_providers.py ├── tests/ │ ├── .pylintrc │ ├── annotation_test.py │ ├── chunking_test.py │ ├── data_lib_test.py │ ├── extract_precedence_test.py │ ├── extract_schema_integration_test.py │ ├── factory_schema_test.py │ ├── factory_test.py │ ├── format_handler_test.py │ ├── inference_test.py │ ├── init_test.py │ ├── progress_test.py │ ├── prompt_validation_test.py │ ├── prompting_test.py │ ├── provider_plugin_test.py │ ├── provider_schema_test.py │ ├── registry_test.py │ ├── resolver_test.py │ ├── schema_test.py │ ├── test_gemini_batch_api.py │ ├── test_kwargs_passthrough.py │ ├── test_live_api.py │ ├── test_ollama_integration.py │ ├── tokenizer_test.py │ └── visualization_test.py └── tox.ini
================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ISSUE_TEMPLATE/1-bug.md
================================================
---
name: Bug Report
about: Create a bug report to help us improve
title: 'Bug: '
labels: bug, needs triage
assignees: ''

---

## Describe the overall issue and situation

Provide a clear summary of what the issue is about, the area of the project you found it in, and what you were trying to do.
## Expected behavior

Provide a clear and concise description of what you expected to happen.

## Actual behavior

Provide a clear and concise description of what actually happened.

## Steps to reproduce the issue

Provide a sequence of steps we can use to reproduce the issue.

1.
2.
3.

## Any additional context

Describe your environment or any other setup details that might help us reproduce the issue.

================================================
FILE: .github/ISSUE_TEMPLATE/2-feature-request.md
================================================
---
name: Feature Request
about: Suggest an idea or improvement
title: 'Request: '
labels: enhancement, needs triage
assignees: ''

---

## Describe the overall idea and motivation

Provide a clear summary of the idea and the use cases it addresses.

## Related to an issue?

Is this addressing a known / documented issue? If so, which one?

## Possible solutions and alternatives

Do you already have an idea of how the solution should work? If so, document that here. Also, if there are alternatives, please document those as well.

## Priority and timeline considerations

Is this time-sensitive? Is it a nice-to-have? Please describe what priority you feel this should have and why. We'll take this into account as we go through our internal prioritization process.

## Additional context

Is there anything else to consider that wasn't covered by the above? Would you like to contribute to the project and work on this request?

================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Allow users to create issues that don't follow the templates since they don't cover all use cases blank_issues_enabled: true # Redirect users to other channels for general support or security issues contact_links: - name: Community Support url: https://github.com/google/langextract/discussions about: Please ask and answer questions here. - name: Security Bug Reporting url: https://g.co/vulnz about: > To report a security issue, please use https://g.co/vulnz. The Google Security Team will respond within 5 working days of your report on https://g.co/vulnz. ================================================ FILE: .github/PULL_REQUEST_TEMPLATE/pull_request_template.md ================================================ # Description Replace this with a clear and concise change description Fixes/Related to #[issue number] Choose one: (Bug fix | Feature | Documentation | Testing | Code health | Other) # How Has This Been Tested? Replace this with a description of the tests that you ran to verify your changes. If executing the existing test suite without customization, simply paste the command line used. ``` $ python -m unittest discover ... ``` # Checklist: - [ ] I have read and acknowledged Google's Open Source [Code of conduct](https://opensource.google/conduct). 
- [ ] I have read the [Contributing](https://github.com/google/langextract/blob/main/CONTRIBUTING.md) page, and I either signed the Google [Individual CLA](https://cla.developers.google.com/about/google-individual) or am covered by my company's [Corporate CLA](https://cla.developers.google.com/about/google-corporate).
- [ ] I have discussed my proposed solution with code owners in the linked issue(s) and we have agreed upon the general approach.
- [ ] I have made any needed documentation changes, or noted in the linked issue(s) that documentation elsewhere needs updating.
- [ ] I have added tests, or I have ensured existing tests cover the changes
- [ ] I have followed [Google's Python Style Guide](https://google.github.io/styleguide/pyguide.html) and ran `pylint` over the affected code.

================================================
FILE: .github/scripts/add-new-checks.sh
================================================
#!/bin/bash
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script to add new required status checks to an existing branch protection rule.
# This preserves all your current settings and just adds the new checks

# Stop at the first failing command so the success summary below is only
# printed when the GitHub API call actually succeeded.
set -euo pipefail

echo "Adding new PR validation checks to existing branch protection..."

# Add the new checks to existing ones
echo "Adding new checks: enforce, size, and protect-infrastructure..."
gh api repos/:owner/:repo/branches/main/protection/required_status_checks/contexts \
  --method POST \
  --input - <<< '["enforce", "size", "protect-infrastructure"]'

echo ""
echo "✓ New checks added!"
echo ""
echo "Updated required status checks will include:"
echo "- test (3.10) [existing]"
echo "- test (3.11) [existing]"
echo "- test (3.12) [existing]"
echo "- Validate PR Template [existing]"
echo "- live-api-tests [existing]"
echo "- ollama-integration-test [existing]"
echo "- enforce [NEW - linked issue validation]"
echo "- size [NEW - PR size limit]"
echo "- protect-infrastructure [NEW - infrastructure file protection]"

================================================
FILE: .github/scripts/add-size-labels.sh
================================================
#!/bin/bash
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Add size labels to PRs based on their change count

echo "Adding size labels to PRs..."

# Get all open PRs with their additions and deletions. gh's built-in --jq
# prints one "<number> <additions> <deletions>" line per PR, so no separate
# jq binary is needed to pick the fields apart.
gh pr list --limit 50 --json number,additions,deletions \
  --jq '.[] | "\(.number) \(.additions) \(.deletions)"' |
  while read -r pr_number additions deletions; do
    total_changes=$((additions + deletions))

    # Determine size label (same thresholds as the check-pr-size workflow)
    if [ "$total_changes" -lt 50 ]; then
      size_label="size/XS"
    elif [ "$total_changes" -lt 150 ]; then
      size_label="size/S"
    elif [ "$total_changes" -lt 600 ]; then
      size_label="size/M"
    elif [ "$total_changes" -lt 1000 ]; then
      size_label="size/L"
    else
      size_label="size/XL"
    fi

    echo "PR #$pr_number: $total_changes lines -> $size_label"

    # Remove stale size labels first. grep yields one label per line, but
    # --remove-label expects a comma-separated list, so join them. The label
    # about to be applied is kept instead of being removed and re-added.
    existing_labels=$(gh pr view "$pr_number" --json labels --jq '.labels[].name' |
      grep '^size/' | grep -vxF "$size_label" | paste -sd, - || true)
    if [ -n "$existing_labels" ]; then
      echo " Removing existing label: $existing_labels"
      gh pr edit "$pr_number" --remove-label "$existing_labels"
    fi

    # Add the new size label
    gh pr edit "$pr_number" --add-label "$size_label"

    sleep 1 # Avoid rate limiting
  done

echo "Done adding size labels!"

================================================
FILE: .github/scripts/revalidate-all-prs.sh
================================================
#!/bin/bash
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Revalidate all open PRs

echo "Fetching all open PRs..."
PR_NUMBERS=$(gh pr list --limit 50 --json number --jq '.[].number') TOTAL=$(echo "$PR_NUMBERS" | wc -w | tr -d ' ') echo "Found $TOTAL open PRs" echo "Starting revalidation..." echo "" COUNT=0 for pr in $PR_NUMBERS; do COUNT=$((COUNT + 1)) echo "[$COUNT/$TOTAL] Triggering revalidation for PR #$pr..." gh workflow run revalidate-pr.yml -f pr_number=$pr # Small delay to avoid rate limiting sleep 2 done echo "" echo "All workflows triggered!" echo "" echo "To monitor progress:" echo " gh run list --workflow=revalidate-pr.yml --limit=$TOTAL" echo "" echo "To see results, check comments on each PR" ================================================ FILE: .github/scripts/zenodo_publish.py ================================================ #!/usr/bin/env python3 # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Publish a new version to Zenodo via REST API. This script reads project metadata from pyproject.toml to avoid duplication. For subsequent releases, it creates new versions from the existing Zenodo record, inheriting most metadata automatically. 
""" import glob import os import sys import tomllib import urllib.request import requests API = "https://zenodo.org/api" TOKEN = os.environ["ZENODO_TOKEN"] RECORD_ID = os.environ["ZENODO_RECORD_ID"] VERSION = os.environ["RELEASE_TAG"].lstrip("v") REPO = os.environ["GITHUB_REPOSITORY"] SERVER = os.environ.get("GITHUB_SERVER_URL", "https://github.com") HEADERS = { "Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json", } try: with open("pyproject.toml", "rb") as f: pyproject = tomllib.load(f) PROJECT_META = pyproject["project"] PROJECT = PROJECT_META["name"] except (KeyError, FileNotFoundError) as e: print(f"❌ Error loading project metadata: {e}", file=sys.stderr) sys.exit(1) def new_version_from_record(record_id: str): """Create a new draft that inherits metadata from the latest published record.""" r = requests.post( f"{API}/deposit/depositions/{record_id}/actions/newversion", headers=HEADERS, timeout=30, ) r.raise_for_status() # Zenodo returns a link to the draft, not the draft itself latest_draft_url = r.json()["links"]["latest_draft"] return requests.get(latest_draft_url, headers=HEADERS, timeout=30).json() def upload_file(bucket_url: str, path: str, dest_name: str = None): """Upload a file to the deposition bucket.""" dest = dest_name or os.path.basename(path) with open(path, "rb") as fp: r = requests.put( f"{bucket_url}/{dest}", data=fp, headers={"Authorization": f"Bearer {TOKEN}"}, timeout=60, ) r.raise_for_status() def main(): """Main workflow.""" try: draft = new_version_from_record(RECORD_ID) bucket = draft["links"]["bucket"] dep_id = draft["id"] # GitHub auto-generates source archives for tags tarball = f"/tmp/{PROJECT}-v{VERSION}.tar.gz" src_url = f"{SERVER}/{REPO}/archive/refs/tags/v{VERSION}.tar.gz" urllib.request.urlretrieve(src_url, tarball) upload_file(bucket, tarball, f"{PROJECT}-{VERSION}.tar.gz") for path in glob.glob("dist/*"): upload_file(bucket, path) # Update only version-specific metadata; rest is inherited meta = { 
"metadata": { "title": f"{PROJECT.replace('-', ' ').title()} v{VERSION}", "version": VERSION, "upload_type": "software", } } r = requests.put( f"{API}/deposit/depositions/{dep_id}", headers=HEADERS, json=meta, timeout=30, ) r.raise_for_status() # Publish to mint DOI r = requests.post( f"{API}/deposit/depositions/{dep_id}/actions/publish", headers=HEADERS, timeout=30, ) r.raise_for_status() record = r.json() doi = record.get("doi") record_id = record.get("record_id") print(f"✅ Published to Zenodo: https://doi.org/{doi}") if "GITHUB_OUTPUT" in os.environ: with open(os.environ["GITHUB_OUTPUT"], "a") as f: f.write(f"doi={doi}\n") f.write(f"record_id={record_id}\n") f.write(f"zenodo_url=https://zenodo.org/records/{record_id}\n") return 0 except Exception as e: print(f"❌ Error: {e}", file=sys.stderr) return 1 if __name__ == "__main__": sys.exit(main()) ================================================ FILE: .github/workflows/auto-update-pr.yaml ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 

name: Auto Update PR

on:
  push:
    branches: [main]
  schedule:
    # Run daily at 2 AM UTC to catch stale PRs
    - cron: '0 2 * * *'
  workflow_dispatch:
    inputs:
      pr_number:
        description: 'PR number to update (optional, updates all if not specified)'
        required: false
        type: string

permissions:
  contents: write  # Required for updateBranch API
  pull-requests: write
  issues: write

jobs:
  update-prs:
    runs-on: ubuntu-latest
    concurrency:
      group: auto-update-pr-${{ github.event_name }}
      cancel-in-progress: true
    steps:
      - name: Update PRs that are behind main
        uses: actions/github-script@v7
        with:
          script: |
            const prNumber = context.payload.inputs?.pr_number;

            // Get list of open PRs (or only the one passed to workflow_dispatch)
            const prs = prNumber
              ? [(await github.rest.pulls.get({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  pull_number: parseInt(prNumber, 10)
                })).data]
              : await github.paginate(github.rest.pulls.list, {
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  state: 'open',
                  sort: 'updated',
                  direction: 'desc'
                });

            console.log(`Found ${prs.length} open PRs to check`);

            // Constants for comment flood control
            const UPDATE_COMMENT_COOLDOWN_DAYS = 7;
            const COOLDOWN_MS = UPDATE_COMMENT_COOLDOWN_DAYS * 24 * 60 * 60 * 1000;

            // True when github-actions[bot] already posted a "Branch Update Required"
            // comment on the PR within the cooldown window.
            async function hasRecentUpdateComment(issueNumber) {
              const comments = await github.paginate(github.rest.issues.listComments, {
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: issueNumber,
                since: new Date(Date.now() - COOLDOWN_MS).toISOString(),
                per_page: 100
              });
              return comments.some(c =>
                c.body?.includes('Branch Update Required') &&
                c.user?.login === 'github-actions[bot]'
              );
            }

            // Posts manual-update instructions unless they were posted recently.
            // Shared by BOTH notify paths: this workflow runs on every push to main
            // and daily, so an un-throttled path re-posts the same request each run.
            async function requestManualUpdate(targetPr, body) {
              if (await hasRecentUpdateComment(targetPr.number)) {
                console.log(`PR #${targetPr.number} was already asked to update recently; not commenting again`);
                return;
              }
              await github.rest.issues.createComment({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: targetPr.number,
                body
              });
            }

            for (const pr of prs) {
              // Skip bot PRs and drafts
              if (pr.user.login.includes('[bot]')) {
                console.log(`Skipping bot PR #${pr.number} from ${pr.user.login}`);
                continue;
              }
              if (pr.draft) {
                console.log(`Skipping draft PR #${pr.number}`);
                continue;
              }
              // head.repo can be null (e.g. the fork was deleted); the compare
              // below needs its owner, so skip instead of throwing.
              if (!pr.head.repo) {
                console.log(`Skipping PR #${pr.number}: head repository no longer exists`);
                continue;
              }

              try {
                // Check if PR is behind main (base...head comparison)
                const { data: comparison } = await github.rest.repos.compareCommits({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  base: pr.base.ref, // main branch
                  head: `${pr.head.repo.owner.login}:${pr.head.ref}` // Fully qualified ref for forks
                });

                if (comparison.behind_by > 0) {
                  console.log(`PR #${pr.number} is ${comparison.behind_by} commits behind ${pr.base.ref}`);

                  // Check if the PR allows maintainer edits
                  if (pr.maintainer_can_modify) {
                    // Try to update the branch
                    try {
                      await github.rest.pulls.updateBranch({
                        owner: context.repo.owner,
                        repo: context.repo.repo,
                        pull_number: pr.number
                      });
                      console.log(`✅ Updated PR #${pr.number}`);

                      // Add a comment
                      await github.rest.issues.createComment({
                        owner: context.repo.owner,
                        repo: context.repo.repo,
                        issue_number: pr.number,
                        body: `🔄 **Branch Updated**\n\nYour branch was ${comparison.behind_by} commits behind \`${pr.base.ref}\` and has been automatically updated. CI checks will re-run shortly.`
                      });
                    } catch (updateError) {
                      const updateMessage = String(updateError.message || '');
                      console.log(`Could not auto-update PR #${pr.number}: ${updateError.message}`);

                      // Determine the reason for failure
                      let failureReason = '';
                      if (updateError.status === 409 || updateMessage.includes('merge conflict')) {
                        failureReason = '\n\n**Note:** Automatic update failed due to merge conflicts. Please resolve them manually.';
                      } else if (updateError.status === 422) {
                        failureReason = '\n\n**Note:** Cannot push to fork. Please update manually.';
                      }

                      // Notify the contributor to update manually (throttled)
                      await requestManualUpdate(
                        pr,
                        `⚠️ **Branch Update Required**\n\nYour branch is ${comparison.behind_by} commits behind \`${pr.base.ref}\`.${failureReason}\n\nPlease update your branch:\n\n\`\`\`bash\ngit fetch origin ${pr.base.ref}\ngit merge origin/${pr.base.ref}\ngit push\n\`\`\`\n\nOr use GitHub's "Update branch" button if available.`
                      );
                    }
                  } else {
                    // Can't modify, just notify (throttled)
                    console.log(`PR #${pr.number} doesn't allow maintainer edits`);
                    await requestManualUpdate(
                      pr,
                      `⚠️ **Branch Update Required**\n\nYour branch is ${comparison.behind_by} commits behind \`${pr.base.ref}\`. Please update your branch to ensure CI checks run with the latest code:\n\n\`\`\`bash\ngit fetch origin ${pr.base.ref}\ngit merge origin/${pr.base.ref}\ngit push\n\`\`\`\n\nNote: Enable "Allow edits by maintainers" to allow automatic updates.`
                    );
                  }
                } else {
                  console.log(`PR #${pr.number} is up to date`);
                }
              } catch (error) {
                console.error(`Error processing PR #${pr.number}:`, error.message);
              }
            }

            // Log rate limit status
            const { data: rateLimit } = await github.rest.rateLimit.get();
            console.log(`API rate limit remaining: ${rateLimit.rate.remaining}/${rateLimit.rate.limit}`);

================================================
FILE: .github/workflows/check-infrastructure-changes.yml
================================================
name: Protect Infrastructure Files

on:
  pull_request_target:
    types: [opened, synchronize, reopened]
  workflow_dispatch:

permissions:
  contents: read
  pull-requests: write

jobs:
  protect-infrastructure:
    if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    steps:
      - name: Check for infrastructure file changes
        if: github.event_name == 'pull_request_target'
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            // Get the PR author and check if they're a maintainer
            const prAuthor = context.payload.pull_request.user.login;
            let authorPerm = 'none';
            try {
              const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({
                owner: context.repo.owner,
                repo: context.repo.repo,
                username: prAuthor
              });
              authorPerm = authorPermission.permission || 'none';
            } catch (_) {
              // User might not have any permissions; treat as non-maintainer
            }
            const isMaintainer = ['admin', 'maintain'].includes(authorPerm);

            // Get the FULL list of files changed in the PR. pulls.listFiles returns
            // only 30 files per page by default, so without pagination protected
            // files beyond the first page would go unnoticed.
            const files = await github.paginate(github.rest.pulls.listFiles, {
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: context.payload.pull_request.number,
              per_page: 100
            });

            // Check for infrastructure file
changes const infrastructureFiles = files.filter(file => file.filename.startsWith('.github/') || file.filename === 'pyproject.toml' || file.filename === 'tox.ini' || file.filename === '.pre-commit-config.yaml' || file.filename === '.pylintrc' || file.filename === 'Dockerfile' || file.filename === 'autoformat.sh' || file.filename === '.gitignore' || file.filename === 'CONTRIBUTING.md' || file.filename === 'LICENSE' || file.filename === 'CITATION.cff' ); if (infrastructureFiles.length > 0 && !isMaintainer) { // Check if changes are only formatting/whitespace let hasStructuralChanges = false; for (const file of infrastructureFiles) { const additions = file.additions || 0; const deletions = file.deletions || 0; const changes = file.changes || 0; // If file has significant changes (not just whitespace), consider it structural if (additions > 5 || deletions > 5 || changes > 10) { hasStructuralChanges = true; break; } } const fileList = infrastructureFiles.map(f => ` - ${f.filename} (${f.changes} changes)`).join('\n'); // Post a comment explaining the issue await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.payload.pull_request.number, body: `❌ **Infrastructure File Protection**\n\n` + `This PR modifies protected infrastructure files:\n\n${fileList}\n\n` + `Only repository maintainers are allowed to modify infrastructure files (including \`.github/\`, build configuration, and repository documentation).\n\n` + `**Note**: If these are only formatting changes, please:\n` + `1. Revert changes to \`.github/\` files\n` + `2. Use \`./autoformat.sh\` to format only source code directories\n` + `3. Avoid running formatters on infrastructure files\n\n` + `If structural changes are necessary:\n` + `1. Open an issue describing the needed infrastructure changes\n` + `2. 
A maintainer will review and implement the changes if approved\n\n` + `For more information, see our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md).` }); core.setFailed( `This PR modifies ${infrastructureFiles.length} protected infrastructure file(s). ` + `Only maintainers can modify these files. ` + `Use ./autoformat.sh to format code without touching infrastructure.` ); } else if (infrastructureFiles.length > 0 && isMaintainer) { core.info(`PR modifies ${infrastructureFiles.length} infrastructure file(s) - allowed for maintainer ${prAuthor}`); } else { core.info('No infrastructure files modified'); } ================================================ FILE: .github/workflows/check-linked-issue.yml ================================================ name: Require linked issue with community support on: pull_request_target: types: [opened, edited, synchronize, reopened, ready_for_review] permissions: contents: read issues: write pull-requests: write concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: enforce: if: github.event_name == 'pull_request_target' && !github.event.pull_request.draft runs-on: ubuntu-latest steps: - name: Check linked issue and community support uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | // Strip code blocks and inline code to avoid false matches const stripCode = txt => txt.replace(/```[\s\S]*?```/g, '').replace(/`[^`]*`/g, ''); // Combine title + body for comprehensive search const prText = stripCode(`${context.payload.pull_request.title || ''}\n${context.payload.pull_request.body || ''}`); // Issue reference pattern: #123, org/repo#123, or full URL (with http/https and optional www) const issueRef = String.raw`(?:#(?\d+)|(?[\w.-]+)\/(?[\w.-]+)#(?\d+)|https?:\/\/(?:www\.)?github\.com\/(?[\w.-]+)\/(?[\w.-]+)\/issues\/(?\d+))`; // Keywords - supporting common variants const closingRe = 
new RegExp(String.raw`\b(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)\b\s*:?\s+${issueRef}`, 'gi'); const referenceRe = new RegExp(String.raw`\b(?:related\s+to|relates\s+to|refs?|part\s+of|addresses|see(?:\s+also)?|depends\s+on|blocked\s+by|supersedes)\b\s*:?\s+${issueRef}`, 'gi'); // Gather all matches const closings = [...prText.matchAll(closingRe)]; const references = [...prText.matchAll(referenceRe)]; const first = closings[0] || references[0]; // Check for draft PRs and bots const pr = context.payload.pull_request; const isDraft = !!pr.draft; const login = pr.user.login; const isBot = pr.user.type === 'Bot' || /\[bot\]$/.test(login); if (isDraft || isBot) { core.info('Draft or bot PR – skipping enforcement'); return; } // Check if PR author is a maintainer let authorPerm = 'none'; try { const { data } = await github.rest.repos.getCollaboratorPermissionLevel({ owner: context.repo.owner, repo: context.repo.repo, username: pr.user.login, }); authorPerm = data.permission || 'none'; } catch (_) { // User might not have any permissions } core.info(`Author permission: ${authorPerm}`); const isMaintainer = ['admin', 'maintain'].includes(authorPerm); // Removed 'write' for stricter maintainer definition // Maintainers bypass entirely if (isMaintainer) { core.info(`Maintainer ${pr.user.login} - bypassing linked issue requirement`); return; } if (!first) { // Check for existing comment to avoid duplicates const MARKER = ''; const existing = await github.paginate(github.rest.issues.listComments, { owner: context.repo.owner, repo: context.repo.repo, issue_number: context.payload.pull_request.number, per_page: 100, }); const alreadyLeft = existing.some(c => c.body && c.body.includes(MARKER)); if (!alreadyLeft) { const contribUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/blob/main/CONTRIBUTING.md#pull-request-guidelines`; const commentBody = [ 'No linked issues found. 
Please link an issue in your pull request description or title.', '', `Per our [Contributing Guidelines](${contribUrl}), all PRs must:`, '- Reference an issue with one of:', ' - **Closing keywords**: `Fixes #123`, `Closes #123`, `Resolves #123` (auto-closes on merge in the same repository)', ' - **Reference keywords**: `Related to #123`, `Refs #123`, `Part of #123`, `See #123` (links without closing)', '- The linked issue should have 5+ 👍 reactions from unique users (excluding bots and the PR author)', '- Include discussion demonstrating the importance of the change', '', 'You can also use cross-repo references like `owner/repo#123` or full URLs.', '', MARKER ].join('\n'); await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.payload.pull_request.number, body: commentBody }); } core.setFailed('No linked issue found. Use "Fixes #123" to close an issue or "Related to #123" to reference it.'); return; } // Resolve owner/repo/number, defaulting to the current repo const groups = first.groups || {}; const owner = groups.o1 || groups.o2 || context.repo.owner; const repo = groups.r1 || groups.r2 || context.repo.repo; const issue_number = Number(groups.num || groups.n1 || groups.n2); // Validate issue number if (!Number.isInteger(issue_number) || issue_number <= 0) { core.setFailed( 'Found a potential issue link but no valid number. ' + 'Use "Fixes #123" or "Related to owner/repo#123".' 
); return; } core.info(`Found linked issue: ${owner}/${repo}#${issue_number}`); // Count unique users who reacted with 👍 on the linked issue (excluding bots and PR author) try { const reactions = await github.paginate(github.rest.reactions.listForIssue, { owner, repo, issue_number, per_page: 100, }); const prAuthorId = pr.user.id; const uniqueThumbs = new Set( reactions .filter(r => r.content === '+1' && r.user && r.user.id !== prAuthorId && r.user.type !== 'Bot' && !String(r.user.login || '').endsWith('[bot]') ) .map(r => r.user.id) ).size; core.info(`Issue ${owner}/${repo}#${issue_number} has ${uniqueThumbs} unique 👍 reactions`); const REQUIRED_THUMBS_UP = 5; if (uniqueThumbs < REQUIRED_THUMBS_UP) { core.setFailed(`Linked issue ${owner}/${repo}#${issue_number} has only ${uniqueThumbs} 👍 (need ${REQUIRED_THUMBS_UP}).`); return; } } catch (error) { const isSameRepo = owner === context.repo.owner && repo === context.repo.repo; if (error.status === 404 || error.status === 403) { if (!isSameRepo) { core.setFailed( `Linked issue ${owner}/${repo}#${issue_number} is not accessible. ` + `Please link to an issue in ${context.repo.owner}/${context.repo.repo} or a public repo.` ); } else { core.info(`Cannot access reactions for ${owner}/${repo}#${issue_number}; skipping enforcement for same-repo issue.`); } return; } // Any other error should fail to prevent accidental bypass const msg = (error && error.message) ? String(error.message).toLowerCase() : ''; const isRateLimit = msg.includes('rate limit') || error?.headers?.['x-ratelimit-remaining'] === '0'; if (isRateLimit) { core.setFailed(`Rate limit while checking reactions for ${owner}/${repo}#${issue_number}. 
Please retry the workflow.`); } else { core.setFailed(`Unexpected error checking reactions for ${owner}/${repo}#${issue_number}: ${error?.message || error}`); } }

================================================
FILE: .github/workflows/check-pr-size.yml
================================================
# Labels every non-draft, non-bot PR with a size/* label and fails PRs over
# MAX_LINES changed lines unless the author is a maintainer or the PR carries
# the bypass:size-limit label.
name: Check PR size

on:
  pull_request_target:
    types: [opened, synchronize, reopened]
  workflow_dispatch:
    inputs:
      pr_number:
        description: 'PR number to check (optional)'
        required: false
        type: string

permissions:
  contents: read
  pull-requests: write
  issues: write

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.run_id }}
  cancel-in-progress: true

jobs:
  size:
    runs-on: ubuntu-latest
    steps:
      - name: Get PR data for manual trigger
        if: github.event_name == 'workflow_dispatch' && github.event.inputs.pr_number
        id: get_pr
        uses: actions/github-script@v7
        env:
          # The input is read from the environment rather than spliced into the
          # script source, so arbitrary input text can never run as JavaScript.
          PR_NUMBER_INPUT: ${{ github.event.inputs.pr_number }}
        with:
          result-encoding: string
          script: |
            const raw = String(process.env.PR_NUMBER_INPUT || '').trim();
            const pullNumber = Number(raw);
            if (!/^\d+$/.test(raw) || !Number.isSafeInteger(pullNumber) || pullNumber <= 0) {
              core.setFailed(`Invalid pr_number input: "${raw}". Expected a positive integer.`);
              return '';
            }
            const { data } = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: pullNumber
            });
            return JSON.stringify(data);

      - name: Evaluate PR size
        if: github.event_name == 'pull_request_target' || (github.event_name == 'workflow_dispatch' && github.event.inputs.pr_number)
        uses: actions/github-script@v7
        env:
          PR_JSON: ${{ steps.get_pr.outputs.result }}
        with:
          script: |
            const pr = context.payload.pull_request || JSON.parse(process.env.PR_JSON || '{}');
            if (!pr || !pr.number) {
              core.setFailed('Unable to resolve PR data. For workflow_dispatch, pass a valid pr_number.');
              return;
            }

            // Check for draft PRs and bots
            const isDraft = !!pr.draft;
            const login = pr.user.login;
            const isBot = pr.user.type === 'Bot' || /\[bot\]$/.test(login);
            if (isDraft || isBot) {
              core.info('Draft or bot PR – skipping size enforcement');
              return;
            }

            const totalChanges = pr.additions + pr.deletions;
            core.info(`PR contains ${pr.additions} additions and ${pr.deletions} deletions (${totalChanges} total)`);

            const sizeLabel =
              totalChanges < 50 ? 'size/XS' :
              totalChanges < 150 ? 'size/S' :
              totalChanges < 600 ? 'size/M' :
              totalChanges < 1000 ? 'size/L' : 'size/XL';

            // Re-fetch labels to avoid acting on stale payload data
            const { data: freshIssue } = await github.rest.issues.get({ ...context.repo, issue_number: pr.number });
            const currentLabels = (freshIssue.labels || []).map(l => l.name);

            // Remove old size labels before adding new one
            const allSizeLabels = ['size/XS', 'size/S', 'size/M', 'size/L', 'size/XL'];
            const toRemove = currentLabels.filter(name => allSizeLabels.includes(name) && name !== sizeLabel);
            for (const name of toRemove) {
              try {
                await github.rest.issues.removeLabel({ ...context.repo, issue_number: pr.number, name });
              } catch (error) {
                // 404 means the label is already gone (e.g. removed by a
                // concurrent run). Anything else is unexpected: surface it, but
                // keep labelling best-effort rather than failing the job.
                if (error.status !== 404) {
                  core.warning(`Failed to remove label ${name}: ${error.message}`);
                }
              }
            }

            await github.rest.issues.addLabels({ ...context.repo, issue_number: pr.number, labels: [sizeLabel] });

            // Check if PR author is a maintainer
            let authorPerm = 'none';
            try {
              const { data } = await github.rest.repos.getCollaboratorPermissionLevel({
                owner: context.repo.owner,
                repo: context.repo.repo,
                username: pr.user.login,
              });
              authorPerm = data.permission || 'none';
            } catch (_) {
              // User might not have any permissions
            }

            core.info(`Author permission: ${authorPerm}`);
            const isMaintainer = ['admin', 'maintain'].includes(authorPerm); // Stricter maintainer definition

            // Check for bypass label (using fresh labels)
            const hasBypass = currentLabels.includes('bypass:size-limit');

            const MAX_LINES = 1000;
            if (totalChanges > MAX_LINES) {
              if (isMaintainer || hasBypass) {
                core.info(`${isMaintainer ? 'Maintainer' : 'Bypass label'} - allowing large PR with ${totalChanges} lines`);
              } else {
                core.setFailed(
                  `This PR contains ${totalChanges} lines of changes, which exceeds the maximum of ${MAX_LINES} lines. ` +
                  `Please split this into smaller, focused pull requests.`
                );
              }
            }

================================================
FILE: .github/workflows/check-pr-up-to-date.yaml
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Check PR Up-to-Date

on:
  pull_request:
    types: [opened, synchronize]

permissions:
  contents: read
  pull-requests: write

jobs:
  check-up-to-date:
    runs-on: ubuntu-latest
    # Skip for bot PRs
    if: ${{ !contains(github.actor, '[bot]') }}
    concurrency:
      group: check-pr-${{ github.event.pull_request.number }}
      cancel-in-progress: true
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          # On pull_request events the default ref is the synthetic merge commit
          # (refs/pull/N/merge), which already contains main, so
          # "HEAD..origin/main" would always count 0. Check out the PR head
          # itself, with full history so rev-list counts are exact.
          ref: ${{ github.event.pull_request.head.sha }}
          fetch-depth: 0

      - name: Check if PR is up-to-date with main
        id: check
        run: |
          # Fetch the latest main branch
          git fetch origin main

          # Count commits on main that the PR head does not contain
          BEHIND=$(git rev-list --count HEAD..origin/main)
          echo "commits_behind=$BEHIND" >> "$GITHUB_OUTPUT"

          if [ "$BEHIND" -gt 0 ]; then
            echo "::warning::PR is $BEHIND commits behind main"
            exit 0  # Don't fail the check, just warn
          else
            echo "PR is up-to-date with main"
          fi

      - name: Comment if PR needs update
        if: ${{ steps.check.outputs.commits_behind != '0' }}
        uses: actions/github-script@v7
        env:
          COMMITS_BEHIND: ${{ steps.check.outputs.commits_behind }}
        with:
          script: |
            const behind = Number(process.env.COMMITS_BEHIND);
            const COMMENT_COOLDOWN_HOURS = 24;
            const COOLDOWN_MS = COMMENT_COOLDOWN_HOURS * 60 * 60 * 1000;
            const cutoff = new Date(Date.now() - COOLDOWN_MS);

            // listComments returns the oldest comments first, so a single page
            // of 10 misses recent comments on busy PRs. Restrict the query to
            // the cooldown window and paginate through all of it.
            const comments = await github.paginate(github.rest.issues.listComments, {
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: context.payload.pull_request.number,
              since: cutoff.toISOString(),
              per_page: 100
            });

            const hasRecentComment = comments.some(c =>
              c.body?.includes('commits behind `main`') &&
              c.user?.login === 'github-actions[bot]' &&
              new Date(c.created_at) > cutoff
            );

            if (!hasRecentComment) {
              try {
                await github.rest.issues.createComment({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  issue_number: context.payload.pull_request.number,
                  body: `📊 **PR Status**: ${behind} commits behind \`main\`\n\nConsider updating your branch for the most accurate CI results:\n\n**Option 1**: Use GitHub's "Update branch" button (if available)\n\n**Option 2**: Update locally:\n\`\`\`bash\ngit fetch origin main\ngit merge origin/main\ngit push\n\`\`\`\n\n*Note: If you use a different remote name (e.g., upstream), adjust the commands accordingly.*\n\nThis ensures your changes are tested against the latest code.`
                });
              } catch (error) {
                // Fork PRs get a read-only token on pull_request events, so the
                // comment can be rejected. This reminder is advisory only, so
                // warn instead of failing the workflow.
                core.warning(`Could not post update reminder: ${error.message}`);
              }
            }

================================================
FILE: .github/workflows/ci.yaml
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: CI

on:
  workflow_dispatch:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]
  pull_request_target:
    types: [labeled]

concurrency:
  # The event name keeps pull_request and pull_request_target(labeled) runs
  # for the same PR in separate groups. Otherwise adding any label would
  # cancel the in-progress PR CI run, and vice versa.
  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  format-check:
    runs-on: ubuntu-latest
    if: github.event_name == 'pull_request'
    permissions:
      contents: read
      issues: write
    steps:
      - name: Checkout PR branch
        uses: actions/checkout@v4
        with:
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          ref: ${{ github.event.pull_request.head.ref }}
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install format tools
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Check formatting
        id: format-check
        env:
          GITHUB_TOKEN: ""
        run: |
          set -euo pipefail
          pyink --check --diff .
          isort --check-only --diff .

      - name: Check import structure
        id: import-check
        env:
          GITHUB_TOKEN: ""
        run: |
          set -euo pipefail
          lint-imports --config pyproject.toml

      - name: Comment on PR if formatting fails
        if: failure() && steps.format-check.outcome == 'failure'
        uses: actions/github-script@v7
        continue-on-error: true
        with:
          script: |
            // Awaited so the request completes before the step ends; failures
            // (e.g. read-only tokens on fork PRs) are logged, not fatal.
            await github.rest.issues.createComment({
              issue_number: context.payload.pull_request.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '❌ **Formatting Check Failed**\n\nYour PR has formatting issues. Please run the following command locally and push the changes:\n\n```bash\n./autoformat.sh\n```\n\nThis will automatically fix all formatting issues using pyink (Google\'s Python formatter) and isort.'
            }).catch(err => {
              console.log('Comment posting failed:', err.message);
            });

  test:
    # pull_request_target runs check out the base branch, so this matrix would
    # only re-test main. Fork PRs are exercised by test-fork-pr instead.
    if: github.event_name != 'pull_request_target'
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12"]
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install tox
          pip install -e ".[dev,test]"

      - name: Run unit tests and linting
        run: |
          PY_VERSION=$(echo "${{ matrix.python-version }}" | tr -d '.')
          # Format check is handled by separate job for better isolation
          tox -e py${PY_VERSION},lint-src,lint-tests

  live-api-tests:
    needs: test
    runs-on: ubuntu-latest
    if: |
      github.event_name == 'push' ||
      (github.event_name == 'pull_request' &&
       github.event.pull_request.head.repo.full_name == github.repository)
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Python 3.11
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install tox
          pip install -e ".[dev,test]"

      - name: Run live API tests
        env:
          GITHUB_TOKEN: ""
          # Secrets are exposed via the environment rather than spliced into
          # the script text, where quote characters in a value would break the
          # shell syntax.
          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
          LANGEXTRACT_API_KEY: ${{ secrets.GEMINI_API_KEY }}
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: |
          set -euo pipefail
          if [[ -z "${GEMINI_API_KEY}" && -z "${OPENAI_API_KEY}" ]]; then
            echo "::notice::Live API tests skipped - API keys not configured"
            exit 0
          fi
          tox -e live-api

  plugin-integration-test:
    needs: test
    runs-on: ubuntu-latest
    if: github.event_name == 'pull_request'
    permissions:
      contents: read
      pull-requests: read
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
          fetch-depth: 0

      - name: Detect provider-related changes
        id: provider-changes
        uses: tj-actions/changed-files@v46
        with:
          files: |
            langextract/providers/**
            langextract/factory.py
            langextract/inference.py
            tests/provider_plugin_test.py
            pyproject.toml
            .github/workflows/ci.yaml

      # "exit 0" only ends this step. Every later step is gated on any_changed
      # explicitly so the job really is skipped when nothing relevant changed.
      - name: Skip if no provider changes
        if: steps.provider-changes.outputs.any_changed == 'false'
        run: |
          echo "No provider-related changes detected – skipping plugin integration test."
          exit 0

      - name: Set up Python 3.11
        if: steps.provider-changes.outputs.any_changed != 'false'
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install dependencies
        if: steps.provider-changes.outputs.any_changed != 'false'
        run: |
          python -m pip install --upgrade pip
          pip install tox

      - name: Run plugin smoke test
        if: steps.provider-changes.outputs.any_changed != 'false'
        run: tox -e plugin-smoke

      - name: Run plugin integration test
        if: steps.provider-changes.outputs.any_changed != 'false'
        run: tox -e plugin-integration

  ollama-integration-test:
    needs: test
    runs-on: ubuntu-latest
    if: github.event_name == 'pull_request'
    permissions:
      contents: read
      pull-requests: read
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
          fetch-depth: 0

      - name: Detect file changes
        id: changes
        uses: tj-actions/changed-files@v46
        with:
          # The Ollama provider lives in langextract/providers/ollama.py, so it
          # must be listed for provider edits to trigger these tests.
          files: |
            langextract/inference.py
            langextract/providers/ollama.py
            examples/ollama/**
            tests/test_ollama_integration.py
            .github/workflows/ci.yaml

      # "exit 0" only ends this step; later steps are gated on any_changed.
      - name: Skip if no Ollama changes
        if: steps.changes.outputs.any_changed == 'false'
        run: |
          echo "No Ollama-related changes detected – skipping job."
          exit 0

      - name: Set up Python 3.11
        if: steps.changes.outputs.any_changed != 'false'
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Launch Ollama container
        if: steps.changes.outputs.any_changed != 'false'
        run: |
          docker run -d --name ollama \
            -p 127.0.0.1:11434:11434 \
            -v ollama:/root/.ollama \
            ollama/ollama:0.5.4
          for i in {1..20}; do
            curl -fs http://localhost:11434/api/version && break
            sleep 3
          done

      - name: Pull gemma2 model
        if: steps.changes.outputs.any_changed != 'false'
        run: docker exec ollama ollama pull gemma2:2b || true

      - name: Install tox
        if: steps.changes.outputs.any_changed != 'false'
        run: |
          python -m pip install --upgrade pip
          pip install tox

      - name: Run Ollama integration tests
        if: steps.changes.outputs.any_changed != 'false'
        run: tox -e ollama-integration

  test-fork-pr:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    environment:
      name: live-keys
    # Triggered when a maintainer adds 'ready-to-merge' label to fork PRs only
    if: |
      github.event_name == 'pull_request_target' &&
      github.event.action == 'labeled' &&
      github.event.label.name == 'ready-to-merge' &&
      github.event.pull_request.head.repo.full_name != github.repository
    permissions:
      contents: read
      issues: write
    steps:
      - name: Check if user is maintainer
        uses: actions/github-script@v7
        with:
          script: |
            const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({
              owner: context.repo.owner,
              repo: context.repo.repo,
              username: context.actor
            });
            const isMaintainer = ['admin', 'maintain'].includes(permission.permission);
            if (!isMaintainer) {
              throw new Error(`User ${context.actor} does not have maintainer permissions.`);
            }

      - name: Pin commit SHA for security
        id: sha-pin
        run: |
          SHA_TO_TEST="${{ github.event.pull_request.head.sha }}"
          echo "SHA_TO_TEST=${SHA_TO_TEST}" >> $GITHUB_OUTPUT
          echo "::notice title=Security::Pinned commit SHA for testing: ${SHA_TO_TEST}"

      - name: Checkout base repo
        uses: actions/checkout@v4
        with:
          ref: main
          fetch-depth: 0
          persist-credentials: false

      - name: Fetch and verify exact PR commit
        run: |
          set -euo pipefail
          EXPECTED_SHA="${STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST}"
          echo "Fetching exact commit: $EXPECTED_SHA"

          # Fetch the specific commit SHA
          git fetch --no-tags --prune --no-recurse-submodules origin "$EXPECTED_SHA" || {
            echo "::error::Failed to fetch PR commit $EXPECTED_SHA. The commit may have been deleted."
            exit 1
          }
          git checkout -b pr-to-test "$EXPECTED_SHA"

          # Verify checkout
          ACTUAL_SHA="$(git rev-parse HEAD)"
          if [ "$ACTUAL_SHA" != "$EXPECTED_SHA" ]; then
            echo "::error::SHA verification failed! Expected $EXPECTED_SHA but got $ACTUAL_SHA"
            exit 1
          fi
          echo "::notice title=Security::Successfully verified commit SHA: $ACTUAL_SHA"
        env:
          STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST: ${{ steps.sha-pin.outputs.SHA_TO_TEST }}

      - name: Set up Python 3.11
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install format tools
        run: |
          python -m pip install --upgrade pip
          # Install formatter tools with pinned versions
          pip install pyink==24.3.0 isort==5.13.2 lint-imports==0.3.1

      - name: Validate PR formatting
        run: |
          set -euo pipefail
          echo "Validating code formatting..."
          pyink --check --diff . || {
            echo "::error::Code formatting (pyink) does not meet project standards. Please run ./autoformat.sh locally and push the changes."
            exit 1
          }
          isort --check-only --diff . || {
            echo "::error::Import sorting (isort) does not meet project standards. Please run ./autoformat.sh locally and push the changes."
            exit 1
          }

      - name: Checkout main branch
        uses: actions/checkout@v4
        with:
          ref: main
          fetch-depth: 0
          persist-credentials: false

      - name: Merge verified PR commit
        run: |
          set -euo pipefail
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          SHA_TO_MERGE="${STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST}"
          echo "Merging verified commit: $SHA_TO_MERGE"
          git fetch --no-tags --prune --no-recurse-submodules origin "$SHA_TO_MERGE"
          git merge --no-ff --no-edit "$SHA_TO_MERGE" || {
            echo "::error::Failed to merge commit $SHA_TO_MERGE"
            exit 1
          }
          echo "::notice title=Security::Successfully merged verified commit"
        env:
          STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST: ${{ steps.sha-pin.outputs.SHA_TO_TEST }}

      - name: Add status comment
        uses: actions/github-script@v7
        with:
          script: |
            await github.rest.issues.createComment({
              issue_number: context.payload.pull_request.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: 'Preparing to run live API tests (pending environment approval and API key availability)...'
            });

      # Installing the PR's package runs its build backend, so this happens in a
      # separate step that never receives the API keys.
      - name: Install dependencies
        env:
          GITHUB_TOKEN: ""
        run: |
          set -euo pipefail
          python -m pip install --upgrade pip
          pip install tox
          pip install -e ".[dev,test]"

      - name: Run live API tests
        env:
          GITHUB_TOKEN: ""
          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
          LANGEXTRACT_API_KEY: ${{ secrets.GEMINI_API_KEY }}
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: |
          set -euo pipefail
          if [[ -z "${GEMINI_API_KEY}" && -z "${OPENAI_API_KEY}" ]]; then
            echo "::notice::Live API tests skipped - API keys not configured"
            exit 0
          fi
          tox -e live-api

      - name: Report success
        if: success()
        uses: actions/github-script@v7
        with:
          script: |
            await github.rest.issues.createComment({
              issue_number: context.payload.pull_request.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '✅ Live API tests passed! All endpoints are working correctly.'
            });

      - name: Report failure
        if: failure()
        uses: actions/github-script@v7
        with:
          script: |
            await github.rest.issues.createComment({
              issue_number: context.payload.pull_request.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '❌ Live API tests failed. Please check the workflow logs for details.'
            });

================================================
FILE: .github/workflows/publish.yml
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Publish to PyPI

on:
  release:
    types: [published]

permissions:
  contents: read
  id-token: write

jobs:
  pypi-publish:
    name: Publish to PyPI
    runs-on: ubuntu-latest
    environment: pypi
    permissions:
      # Job-level permissions replace the workflow-level block entirely, so
      # contents: read is restated here for actions/checkout.
      contents: read
      # Required for PyPI trusted publishing (OIDC).
      id-token: write
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install build dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build

      - name: Build package
        run: python -m build

      - name: Verify build artifacts
        run: |
          ls -la dist/
          pip install twine
          twine check dist/*

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

================================================
FILE: .github/workflows/revalidate-pr.yml
================================================
# Manually re-runs the PR size / template / linked-issue checks for one PR and
# reports the result as a commit status plus a PR comment. All step outputs
# and inputs are read through the environment, never spliced into script text.
name: Revalidate PR

on:
  workflow_dispatch:
    inputs:
      pr_number:
        description: 'PR number to validate'
        required: true
        type: string

permissions:
  contents: read
  pull-requests: write
  issues: write
  checks: write
  statuses: write

jobs:
  revalidate:
    runs-on: ubuntu-latest
    steps:
      - name: Get PR data
        id: pr_data
        uses: actions/github-script@v7
        env:
          PR_NUMBER_INPUT: ${{ inputs.pr_number }}
        with:
          script: |
            const raw = String(process.env.PR_NUMBER_INPUT || '').trim();
            const pullNumber = Number(raw);
            if (!/^\d+$/.test(raw) || !Number.isSafeInteger(pullNumber) || pullNumber <= 0) {
              core.setFailed(`Invalid pr_number input: "${raw}". Expected a positive integer.`);
              return;
            }

            const { data: pr } = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: pullNumber
            });

            core.info(`Validating PR #${pr.number}: ${pr.title}`);
            core.info(`Author: ${pr.user.login}`);
            core.info(`Changes: +${pr.additions} -${pr.deletions}`);

            // Store head SHA for creating status
            core.setOutput('head_sha', pr.head.sha);

            return pr;

      - name: Create pending status
        uses: actions/github-script@v7
        env:
          HEAD_SHA: ${{ steps.pr_data.outputs.head_sha }}
        with:
          script: |
            await github.rest.repos.createCommitStatus({
              owner: context.repo.owner,
              repo: context.repo.repo,
              sha: process.env.HEAD_SHA,
              state: 'pending',
              context: 'Manual Validation',
              description: 'Running validation checks...'
            });

      - name: Validate PR
        id: validate
        uses: actions/github-script@v7
        env:
          PR_JSON: ${{ steps.pr_data.outputs.result }}
        with:
          script: |
            const pr = JSON.parse(process.env.PR_JSON);
            const errors = [];
            let passed = true;

            // Check size
            const totalChanges = pr.additions + pr.deletions;
            const MAX_LINES = 1000;
            if (totalChanges > MAX_LINES) {
              errors.push(`PR size (${totalChanges} lines) exceeds ${MAX_LINES} line limit`);
              passed = false;
            }

            // Check template
            const body = pr.body || '';
            const requiredSections = ["# Description", "Fixes #", "# How Has This Been Tested?", "# Checklist"];
            const missingSections = requiredSections.filter(section => !body.includes(section));
            if (missingSections.length > 0) {
              errors.push(`Missing PR template sections: ${missingSections.join(', ')}`);
              passed = false;
            }

            if (body.match(/Replace this with|Choose one:|Fixes #\[issue number\]/i)) {
              errors.push('PR template contains unmodified placeholders');
              passed = false;
            }

            // Check linked issue
            const issueMatch = body.match(/(?:Fixes|Closes|Resolves)\s+#(\d+)/i);
            if (!issueMatch) {
              errors.push('No linked issue found');
              passed = false;
            }

            // Store results
            core.setOutput('passed', passed);
            core.setOutput('errors', errors.join('; '));
            core.setOutput('totalChanges', totalChanges);
            core.setOutput('hasTemplate', missingSections.length === 0);
            core.setOutput('hasIssue', !!issueMatch);

            if (!passed) {
              core.setFailed(errors.join('; '));
            }

      - name: Update commit status
        # Still runs when validation failed, but not when there is no PR data
        # to report on (invalid input) or validation never ran.
        if: always() && steps.pr_data.outcome == 'success' && steps.validate.outcome != 'skipped'
        uses: actions/github-script@v7
        env:
          HEAD_SHA: ${{ steps.pr_data.outputs.head_sha }}
          VALIDATION_PASSED: ${{ steps.validate.outputs.passed }}
          VALIDATION_ERRORS: ${{ steps.validate.outputs.errors }}
        with:
          script: |
            const passed = process.env.VALIDATION_PASSED === 'true';
            const errors = process.env.VALIDATION_ERRORS || 'Validation failed';

            await github.rest.repos.createCommitStatus({
              owner: context.repo.owner,
              repo: context.repo.repo,
              sha: process.env.HEAD_SHA,
              state: passed ? 'success' : 'failure',
              context: 'Manual Validation',
              // GitHub limits commit status descriptions to 140 characters.
              description: passed ? 'All validation checks passed' : errors.substring(0, 140),
              target_url: `https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`
            });

      - name: Add validation comment
        if: always() && steps.pr_data.outcome == 'success' && steps.validate.outcome != 'skipped'
        uses: actions/github-script@v7
        env:
          PR_JSON: ${{ steps.pr_data.outputs.result }}
          VALIDATION_PASSED: ${{ steps.validate.outputs.passed }}
          VALIDATION_TOTAL_CHANGES: ${{ steps.validate.outputs.totalChanges }}
          VALIDATION_HAS_TEMPLATE: ${{ steps.validate.outputs.hasTemplate }}
          VALIDATION_HAS_ISSUE: ${{ steps.validate.outputs.hasIssue }}
          VALIDATION_ERRORS: ${{ steps.validate.outputs.errors }}
        with:
          script: |
            const MAX_LINES = 1000;
            const pr = JSON.parse(process.env.PR_JSON);
            const passed = process.env.VALIDATION_PASSED === 'true';
            const totalChanges = Number(process.env.VALIDATION_TOTAL_CHANGES || 0);
            const hasTemplate = process.env.VALIDATION_HAS_TEMPLATE === 'true';
            const hasIssue = process.env.VALIDATION_HAS_ISSUE === 'true';
            const errors = (process.env.VALIDATION_ERRORS || '').split('; ').filter(e => e);

            let body = `### Manual Validation Results\n\n`;
            body += `**Status**: ${passed ? '✅ Passed' : '❌ Failed'}\n\n`;
            body += `| Check | Status | Details |\n`;
            body += `|-------|--------|----------|\n`;
            body += `| PR Size | ${totalChanges <= MAX_LINES ? '✅' : '❌'} | ${totalChanges} lines ${totalChanges > MAX_LINES ? `(exceeds ${MAX_LINES} limit)` : ''} |\n`;
            body += `| Template | ${hasTemplate ? '✅' : '❌'} | ${hasTemplate ? 'Complete' : 'Missing required sections'} |\n`;
            body += `| Linked Issue | ${hasIssue ? '✅' : '❌'} | ${hasIssue ? 'Found' : 'Missing Fixes/Closes #XXX'} |\n`;

            if (errors.length > 0) {
              body += `\n**Errors:**\n`;
              errors.forEach(error => {
                body += `- ❌ ${error}\n`;
              });
            }

            body += `\n[View workflow run](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId})`;

            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: pr.number,
              body: body
            });

================================================
FILE: .github/workflows/validate-community-providers.yaml
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: Validate Community Providers

on:
  pull_request:
    paths:
      - 'COMMUNITY_PROVIDERS.md'
      - 'scripts/validate_community_providers.py'
      # Re-run when the workflow itself changes so edits to it are exercised.
      - '.github/workflows/validate-community-providers.yaml'

permissions:
  contents: read
  pull-requests: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  validate:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          # The validator only reads files; no git credentials are needed.
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Validate table format
        run: |
          python scripts/validate_community_providers.py COMMUNITY_PROVIDERS.md

================================================
FILE: .github/workflows/validate_pr_template.yaml
================================================
name: Validate PR template

on:
  pull_request_target:
    types: [opened, edited, synchronize, reopened]
  # NOTE: every step below is gated on pull_request_target, so a manual
  # dispatch currently performs no validation.
  workflow_dispatch:

permissions:
  contents: read
  pull-requests: read

jobs:
  check:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR author permissions
        id: check
        if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const pr = context.payload.pull_request;
            const {owner, repo} = context.repo;
            const actor = pr.user.login;
            const authorType = pr.user.type;

            // Check if PR author is a bot (e.g., Dependabot)
            if (authorType === 'Bot') {
              core.setOutput('skip_validation', 'true');
              console.log(`Skipping validation for bot-authored PR: ${actor}`);
              return;
            }

            // Check if this is a community provider PR (only modifies COMMUNITY_PROVIDERS.md)
            const { data: files } = await github.rest.pulls.listFiles({
              owner,
              repo,
              pull_number: pr.number
            });

            const isCommunityProviderPR = files.length === 1 && files[0].filename === 'COMMUNITY_PROVIDERS.md';

            if (isCommunityProviderPR) {
              core.setOutput('is_community_provider', 'true');
              console.log('Community provider PR detected - relaxed validation will apply');
            } else {
              core.setOutput('is_community_provider', 'false');
            }

            // Get permission level
            try {
              const { data } = await github.rest.repos.getCollaboratorPermissionLevel({
                owner,
                repo,
                username: actor
              });
              const permission = data.permission; // admin|maintain|write|triage|read|none
              console.log(`Actor ${actor} has permission level: ${permission}`);

              // Check if user has write+ permissions
              if (['admin', 'maintain', 'write'].includes(permission)) {
                core.setOutput('skip_validation', 'true');
                console.log(`Skipping validation for maintainer: ${actor} (${permission})`);
              } else {
                core.setOutput('skip_validation', 'false');
                console.log(`Validation required for: ${actor} (${permission})`);
              }
            } catch (e) {
              // If we can't determine permissions, require validation
              core.setOutput('skip_validation', 'false');
              core.warning(`Permission lookup failed: ${e.message}`);
            }

      - name: Validate PR template
        if: |
          github.event_name == 'pull_request_target' &&
          github.event.pull_request.draft == false &&
          steps.check.outputs.skip_validation != 'true'
        env:
          # The PR body is untrusted input; it is only ever read from the
          # environment, never interpolated into the script.
          PR_BODY: ${{ github.event.pull_request.body }}
          IS_COMMUNITY_PROVIDER: ${{ steps.check.outputs.is_community_provider }}
        run: |
          printf '%s\n' "$PR_BODY" | tr -d '\r' > body.txt

          # Required sections from the template
          required=(
            "# Description"
            "# How Has This Been Tested?"
            "# Checklist"
          )

          err=0

          # Check for required sections
          for h in "${required[@]}"; do
            grep -Fq "$h" body.txt || { echo "::error::$h missing"; err=1; }
          done

          # Check for issue reference - relaxed for community provider PRs
          if [ "$IS_COMMUNITY_PROVIDER" = "true" ]; then
            # For community provider PRs, accept either "Fixes #" or "Related to #" (case-insensitive)
            if ! grep -Eiq '(Fixes #[0-9]+|Related to #[0-9]+)' body.txt; then
              echo "::error::Issue reference missing (need 'Fixes #NNN' or 'Related to #NNN')"
              err=1
            fi
          else
            # For other PRs, require "Fixes #" with a number. Matched
            # case-insensitively, like GitHub's own closing keywords.
            if ! grep -Eiq 'Fixes #[0-9]+' body.txt; then
              echo "::error::Missing 'Fixes #NNN' reference"
              err=1
            fi
          fi

          # Check for placeholder text that should be replaced
          grep -Eiq 'Replace this with|Choose one:' body.txt && { echo "::error::Template placeholders still present"; err=1; }

          # Also check for the unmodified issue number placeholder
          grep -Fq 'Fixes #[issue number]' body.txt && { echo "::error::Issue number placeholder not updated"; err=1; }

          exit $err

      - name: Log skip reason
        if: |
          github.event_name == 'pull_request_target' &&
          (github.event.pull_request.draft == true || steps.check.outputs.skip_validation == 'true')
        run: |
          echo "Skipping PR template validation. Draft: ${{ github.event.pull_request.draft }}; skip_validation: ${{ steps.check.outputs.skip_validation || 'N/A' }}"

================================================
FILE: .github/workflows/zenodo-publish.yml
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Publish to Zenodo

on:
  release:
    types: [published]

concurrency:
  group: zenodo-${{ github.ref }}
  cancel-in-progress: false

jobs:
  zenodo:
    # Only run on releases from the main repository, not forks
    # Skip pre-releases to avoid creating DOIs for test releases
    if: ${{ !github.event.release.prerelease && github.repository == 'google/langextract' }}
    runs-on: ubuntu-latest
    timeout-minutes: 15
    permissions:
      contents: read
    env:
      ZENODO_RECORD_ID: ${{ secrets.ZENODO_RECORD_ID }}
      RELEASE_TAG: ${{ github.ref_name }}
      GITHUB_REPOSITORY: ${{ github.repository }}
    steps:
      - uses: actions/checkout@v4
        with:
          # Nothing below pushes to git, so don't leave a token in .git/config.
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Build distributions
        run: |
          python -m pip install --upgrade pip build
          python -m build

      - name: Install dependencies
        run: python -m pip install requests

      - name: Publish new Zenodo version
        # The API token is scoped to the only step that talks to Zenodo, so the
        # build and install steps above never see it.
        env:
          ZENODO_TOKEN: ${{ secrets.ZENODO_TOKEN }}
        run: python .github/scripts/zenodo_publish.py

================================================
FILE: .gitignore
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Byte-compiled / Cache files
__pycache__/
*.py[cod]
*$py.class

# Distribution / Packaging
build/
dist/
*.egg-info/
.eggs/
eggs/

# Virtual Environments
.env
.venv
env/
venv/
ENV/
*_env/

# Test & Coverage Reports
.pytest_cache/
.tox/
htmlcov/
.coverage
.coverage.*

# Generated Output & Data
# LangExtract outputs are defaulted to test_output/
/test_output/

# Sphinx documentation build output
docs/_build/

# IDE / Editor specific
.idea/
.vscode/
*.swp
*.swo
*~
.*.swp
.*.swo

# OS-specific
.DS_Store
Thumbs.db
ehthumbs.db
Desktop.ini
$RECYCLE.BIN/
*.cab
*.msi
*.msm
*.msp
*.lnk

# Development tools & environments
.python-version
.pytype/
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.ruff_cache/
*.sage.py
.hypothesis/
.scrapy

# Jupyter Notebooks
.ipynb_checkpoints
*/.ipynb_checkpoints/*
profile_default/
ipython_config.py

# Logs and databases
*.log
*.sql
*.sqlite
*.sqlite3
db.sqlite3
db.sqlite3-journal
logs/
*.pid

# Security and secrets
*.key
*.pem
*.crt
*.csr
.env.local
.env.production
.env.*.local
secrets/
credentials/

# AI tooling
CLAUDE.md
.claude/settings.local.json
.aider.chat.history.*
.aider.input.history
.gemini/
GEMINI.md

# Package managers
pip-log.txt
pip-delete-this-directory.txt
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
package-lock.json
yarn.lock
pnpm-lock.yaml

# Local development
local_settings.py
instance/
.webassets-cache
.sass-cache/
*.css.map
*.js.map
.dev/

# Temporary files
tmp/
temp/
cache/
*.tmp
*.bak
*.backup
*.orig
.~lock.*#

# Archives
*.tar
*.tar.gz
*.zip
*.rar
*.7z
*.dmg
*.iso
*.jar

# Media files
*.mp4
*.avi
*.mov
*.wmv
*.flv
*.mp3
*.wav
*.ogg

# Benchmark results and local environment
langextract_env/
benchmarks/benchmark_results

# Benchmark results in root
benchmark_results/**/*.json
benchmark_results/**/*.jsonl
benchmark_results/**/*.html

================================================
FILE: .pre-commit-config.yaml
================================================
# Copyright 2025 Google LLC.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Pre-commit hooks for LangExtract # Install with: pre-commit install # Run manually: pre-commit run --all-files repos: - repo: https://github.com/PyCQA/isort rev: 6.0.0 hooks: - id: isort name: isort (import sorting) # Configuration is in pyproject.toml - repo: https://github.com/google/pyink rev: 24.3.0 hooks: - id: pyink name: pyink (Google's Black fork) args: ["--config", "pyproject.toml"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - id: end-of-file-fixer exclude: \.gif$|\.svg$ - id: trailing-whitespace - id: check-yaml - id: check-added-large-files args: ['--maxkb=1000'] - id: check-merge-conflict - id: check-case-conflict - id: mixed-line-ending args: ['--fix=lf'] ================================================ FILE: .pylintrc ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [MASTER] # Use multiple processes to speed up Pylint. 
Specifying 0 will auto-detect the # number of processors available to use. jobs=0 # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. # Note: These plugins require Pylint >= 3.0 load-plugins= pylint.extensions.docparams, pylint.extensions.typing # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no [MESSAGES CONTROL] # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time. enable= useless-suppression # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). disable= abstract-method, # Protocol/ABC classes often have abstract methods too-few-public-methods, # Valid for data classes with minimal interface fixme, # TODO/FIXME comments are useful for tracking work # --- Code style and formatting --- line-too-long, # Handled by pyink formatter bad-indentation, # Pyink uses 2-space indentation # --- Design complexity --- too-many-positional-arguments, too-many-locals, too-many-arguments, too-many-branches, too-many-statements, too-many-nested-blocks, # --- Style preferences --- no-else-return, no-else-raise, # --- Documentation --- missing-function-docstring, missing-class-docstring, missing-raises-doc, # --- Gradual improvements --- deprecated-typing-alias, # For typing.Type etc. unspecified-encoding [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. 
output-format=text # Tells whether to display a full report or only the messages reports=no # Activate the evaluation score. score=no [REFACTORING] # Maximum number of nested blocks for function / method body max-nested-blocks=5 # Complete name of functions that never returns. When checking for # inconsistent-return-statements if a never returning function is called then # it will be considered as an explicit return statement and no message will be # printed. never-returning-functions=sys.exit [BASIC] # Naming style matching correct argument names. argument-naming-style=snake_case # Naming style matching correct attribute names. attr-naming-style=snake_case # Bad variable names which should always be refused, separated by a comma. bad-names=foo,bar,baz,toto,tutu,tata # Naming style matching correct class attribute names. class-attribute-naming-style=any # Naming style matching correct class names. class-naming-style=PascalCase # Naming style matching correct constant names. const-naming-style=UPPER_CASE # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=-1 # Naming style matching correct function names. function-naming-style=snake_case # Good variable names which should always be accepted, separated by a comma. good-names=i,j,k,ex,Run,_,id,ok # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted good-names-rgxs=^T[A-Z][a-zA-Z]*$ # Include a hint for the correct naming format with invalid-name. include-naming-hint=no # Naming style matching correct inline iteration names. inlinevar-naming-style=any # Naming style matching correct method names. method-naming-style=snake_case # Naming style matching correct module names. module-naming-style=snake_case # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. 
name-group= # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. # These decorators are taken in consideration only for invalid-name. property-classes=abc.abstractproperty # Naming style matching correct variable names. variable-naming-style=snake_case [FORMAT] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format=LF # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )?<?https?://\S+>?$ # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=2 # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1 # tab). indent-string="  " # Maximum number of characters on a single line. max-line-length=80 # Maximum number of lines in a module. max-module-lines=2000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no [LOGGING] # The type of string formatting that logging methods do. `old` means using % # formatting, `new` is for `{}` formatting. logging-format-style=old # Logging modules to check that the string format arguments are in logging # function parameter format. logging-modules=logging [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME,XXX,TODO [SIMILARITIES] # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no # Minimum lines number of a similarity. min-similarity-lines=6 [SPELLING] # Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4 # Spelling dictionary name. Available dictionaries: none. To make it work, # install the python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. ignore-none=yes # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference # can return multiple potential results while evaluating a Python object, but # some branches might not be evaluated, which results in partial inference. In # that case, it might be useful to still emit no-member and other checks for # the rest of the inferred objects. ignore-on-opaque-inference=yes # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,dataclasses.InitVar,typing.Any # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules=dotenv,absl,more_itertools,pandas,requests,pydantic,yaml,IPython.display, tqdm,numpy,google,langfun,typing_extensions # Show a hint with possible names when a member name was not found. The aspect # of finding the hint is based on edit distance. missing-member-hint=yes # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance=1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. missing-member-max-choices=1 # List of decorators that change the signature of a decorated function. signature-mutators= [VARIABLES] # List of additional names supposed to be defined in builtins. Remember that # you should avoid defining new builtins when possible. additional-builtins= # Tells whether unused global variables should be treated as a violation. allow-global-unused-variables=yes # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb # A regular expression matching the name of dummy variables (i.e. expected to # not be used). dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ # Argument names that match this expression will be ignored. Default to name # with leading underscore. ignored-argument-names=_.*|^ignored_|^unused_ # Tells whether we should check for unused import in __init__ files. init-import=no # List of qualified module names which can have objects that can redefine # builtins. 
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp, __post_init__ # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict, _fields, _replace, _source, _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=cls [DESIGN] # Maximum number of arguments for function / method. max-args=7 # Maximum number of attributes for a class (see R0902). max-attributes=10 # Maximum number of boolean expressions in an if statement. max-bool-expr=5 # Maximum number of branch for function / method body. max-branches=12 # Maximum number of locals for function / method body. max-locals=15 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of return / yield for function / method body. max-returns=6 # Maximum number of statements in function / method body. max-statements=50 # Minimum number of public methods for a class (see R0903). min-public-methods=0 [IMPORTS] # Allow wildcard imports from modules that define __all__. allow-wildcard-with-all=yes # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no # Deprecated modules which should not be used, separated by a comma. deprecated-modules=optparse,tkinter.tix # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled). ext-import-graph= # Create a graph of every (i.e. 
internal and external) dependencies in the # given file (report RP0402 must not be disabled). import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled). int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant,numpy,pandas,torch,langfun,pyglove # Couples of modules and preferred modules, separated by a comma. preferred-modules= [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "builtins.BaseException, builtins.Exception". Names must be fully qualified # (module.Class); Pylint >= 2.16 deprecates bare names such as "Exception". overgeneral-exceptions=builtins.BaseException, builtins.Exception ================================================ FILE: CITATION.cff ================================================ # SPDX-FileCopyrightText: 2025 Google LLC # SPDX-License-Identifier: Apache-2.0 # # This file contains citation metadata for LangExtract. # For more information visit: https://citation-file-format.github.io/ cff-version: 1.2.0 title: "LangExtract" message: "If you use this software, please cite it as below."
type: software authors: - given-names: Akshay family-names: Goel email: goelak@google.com affiliation: Google LLC repository-code: "https://github.com/google/langextract" url: "https://github.com/google/langextract" repository: "https://github.com/google/langextract" abstract: "LangExtract: LLM-powered structured information extraction from text with source grounding" keywords: - language-models - structured-data-extraction - nlp - machine-learning - python license: Apache-2.0 version: 1.1.1 date-released: 2025-11-27 doi: "10.5281/zenodo.17015089" identifiers: - type: doi value: "10.5281/zenodo.17015089" description: "Concept DOI for LangExtract" ================================================ FILE: COMMUNITY_PROVIDERS.md ================================================ # Community Provider Plugins Community-developed provider plugins that extend LangExtract with additional model backends. **Supporting the Community:** Star plugin repositories you find useful and add 👍 reactions to their tracking issues to support maintainers' efforts. **⚠️ Important:** These are community-maintained packages. Please review the [safety guidelines](#safety-disclaimer) before use. 
## Plugin Registry | Plugin Name | PyPI Package | Maintainer | GitHub Repo | Description | Issue Link | |-------------|--------------|------------|-------------|-------------|------------| | AWS Bedrock | `langextract-bedrock` | [@andyxhadji](https://github.com/andyxhadji) | [andyxhadji/langextract-bedrock](https://github.com/andyxhadji/langextract-bedrock) | AWS Bedrock provider for LangExtract, supports all models & inference profiles | [#148](https://github.com/google/langextract/issues/148) | | LiteLLM | `langextract-litellm` | [@JustStas](https://github.com/JustStas) | [JustStas/langextract-litellm](https://github.com/JustStas/langextract-litellm) | LiteLLM provider for LangExtract, supports all models covered in LiteLLM, including OpenAI, Azure, Anthropic, etc., See [LiteLLM's supported models](https://docs.litellm.ai/docs/providers) | [#187](https://github.com/google/langextract/issues/187) | | Llama.cpp | `langextract-llamacpp` | [@fgarnadi](https://github.com/fgarnadi) | [fgarnadi/langextract-llamacpp](https://github.com/fgarnadi/langextract-llamacpp) | Llama.cpp provider for LangExtract, supports GGUF models from HuggingFace and local files | [#199](https://github.com/google/langextract/issues/199) | | Outlines | `langextract-outlines` | [@RobinPicard](https://github.com/RobinPicard) | [dottxt-ai/langextract-outlines](https://github.com/dottxt-ai/langextract-outlines) | Outlines provider for LangExtract, supports structured generation for various local and API-based models | [#101](https://github.com/google/langextract/issues/101) | | vLLM | `langextract-vllm` | [@wuli666](https://github.com/wuli666) | [wuli666/langextract-vllm](https://github.com/wuli666/langextract-vllm) | vLLM provider for LangExtract, supports local and distributed model serving | [#236](https://github.com/google/langextract/issues/236) | ## How to Add Your Plugin (PR Checklist) Copy this row template, replace placeholders, and insert **above** the marker line: ```markdown | Your 
Plugin | `langextract-provider-yourname` | [@yourhandle](https://github.com/yourhandle) | [yourorg/yourrepo](https://github.com/yourorg/yourrepo) | Brief description (min 10 chars) | [#456](https://github.com/google/langextract/issues/456) | ``` **Before submitting your PR:** - [ ] PyPI package name starts with `langextract-` (recommended: `langextract-provider-`) - [ ] PyPI package is published (or will be soon) and listed in backticks - [ ] Maintainer(s) listed as GitHub profile links (comma-separated if multiple) - [ ] Repository link points to public GitHub repo - [ ] Description clearly explains what your provider does - [ ] Issue Link points to a tracking issue in the LangExtract repository for integration and usage feedback (plugin-specific features and discussions can optionally happen in the plugin's repository) - [ ] Entries are sorted alphabetically by Plugin Name ## Documentation For detailed plugin development instructions, see the [Custom Provider Plugin Example](examples/custom_provider_plugin/README.md). ## Safety Disclaimer Community plugins are independently developed and maintained. While we encourage community contributions, the LangExtract team cannot guarantee the safety, security, or functionality of third-party packages. **Before installing any plugin, we recommend:** - **Review the code** - Examine the source code and dependencies on GitHub - **Check community feedback** - Read issues and discussions for user experiences - **Verify the maintainer** - Look for active maintenance and responsive support - **Test safely** - Try plugins in isolated environments before production use - **Assess security needs** - Consider your specific security requirements Community plugins are used at your own discretion. When in doubt, reach out to the community through the plugin's issue tracker or the main LangExtract discussions. 
================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute We would love to accept your patches and contributions to this project. ## Before you begin ### Sign our Contributor License Agreement Contributions to this project must be accompanied by a [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. If you or your current employer have already signed the Google CLA (even if it was for a different project), you probably don't need to do it again. Visit <https://cla.developers.google.com/> to see your current agreements or to sign a new one. ### Review our Community Guidelines This project follows HAI-DEF's [Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines). ## Reporting Issues If you encounter a bug or have a feature request, please open an issue on GitHub. We have templates to help guide you: - **[Bug Report](.github/ISSUE_TEMPLATE/1-bug.md)**: For reporting bugs or unexpected behavior - **[Feature Request](.github/ISSUE_TEMPLATE/2-feature-request.md)**: For suggesting new features or improvements When creating an issue, GitHub will prompt you to choose the appropriate template. Please provide as much detail as possible to help us understand and address your concern. ## Contribution Process ### 1. Development Setup To get started, clone the repository and install the necessary dependencies for development and testing. Detailed instructions can be found in the [Installation from Source](https://github.com/google/langextract#from-source) section of the `README.md`. **Windows Users**: The formatting scripts use bash. Please use one of: - Git Bash (comes with Git for Windows) - WSL (Windows Subsystem for Linux) - PowerShell with bash-compatible commands ### 2.
Code Style and Formatting This project uses automated tools to maintain a consistent code style. Before submitting a pull request, please format your code: ```bash # Run the auto-formatter ./autoformat.sh ``` This script uses: - `isort` to organize imports with Google style (single-line imports) - `pyink` (Google's fork of Black) to format code according to Google's Python Style Guide You can also run the formatters manually: ```bash isort langextract tests pyink langextract tests --config pyproject.toml ``` Note: The formatters target only `langextract` and `tests` directories by default to avoid formatting virtual environments or other non-source directories. ### 3. Pre-commit Hooks (Recommended) For automatic formatting checks before each commit: ```bash # Install pre-commit pip install pre-commit # Install the git hooks pre-commit install # Run manually on all files pre-commit run --all-files ``` ### 4. Linting and Testing All contributions must pass linting checks and unit tests. Please run these locally before submitting your changes: ```bash # Run linting with Pylint 3.x pylint --rcfile=.pylintrc langextract tests # Run tests pytest tests ``` **Note on Pylint Configuration**: We use a modern, minimal configuration that: - Only disables truly noisy checks (not entire categories) - Keeps critical error detection enabled - Uses plugins for enhanced docstring and type checking - Aligns with our pyink formatter (80-char lines, 2-space indents) For full testing across Python versions: ```bash tox # runs pylint + pytest on Python 3.10 and 3.11 ``` ### 5. Adding Custom Model Providers If you want to add support for a new LLM provider, please refer to the [Provider System Documentation](langextract/providers/README.md). The recommended approach is to create an external plugin package rather than modifying the core library. 
This allows for: - Independent versioning and releases - Faster iteration without core review cycles - Custom dependencies without affecting core users ### 6. Submit Your Pull Request All submissions, including submissions by project members, require review. We use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) for this purpose. When you create a pull request, GitHub will automatically populate it with our [pull request template](.github/PULL_REQUEST_TEMPLATE/pull_request_template.md). Please fill out all sections of the template to help reviewers understand your changes. #### Pull Request Guidelines - **Keep PRs focused and small**: Each PR should address a single issue and contain one cohesive change. PRs are automatically labeled by size to help reviewers: - **size/XS**: < 50 lines — Small fixes and documentation updates - **size/S**: 50-150 lines — Typical features or bug fixes - **size/M**: 150-600 lines — Larger features that remain well-scoped - **size/L**: 600-1000 lines — Consider splitting into smaller PRs if possible - **size/XL**: > 1000 lines — Requires strong justification and may need special review - **Reference related issues**: All PRs must include "Fixes #123" or "Closes #123" in the description. The linked issue should have at least 5 👍 reactions from the community and include discussion that demonstrates the importance and need for the change. - **No infrastructure changes**: Contributors cannot modify infrastructure files, build configuration, and core documentation. These files are protected and can only be changed by maintainers. Use `./autoformat.sh` to format code without affecting infrastructure files. In special circumstances, build configuration updates may be considered if they include discussion and evidence of robust testing, ideally with community support. - **Single-change commits**: A PR should typically comprise a single git commit. Squash multiple commits before submitting. 
- **Clear description**: Explain what your change does and why it's needed. - **Ensure all tests pass**: Check that both formatting and tests are green before requesting review. - **Respond to feedback promptly**: Address reviewer comments in a timely manner. If your change is large or complex, consider: - Opening an issue first to discuss the approach - Breaking it into multiple smaller PRs - Clearly explaining in the PR description why a larger change is necessary For more details, read HAI-DEF's [Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing) ================================================ FILE: Dockerfile ================================================ # Production Dockerfile for LangExtract FROM python:3.10-slim # Set working directory WORKDIR /app # Install LangExtract from PyPI RUN pip install --no-cache-dir langextract # Set default command CMD ["python"] ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

LangExtract Logo

# LangExtract [![PyPI version](https://img.shields.io/pypi/v/langextract.svg)](https://pypi.org/project/langextract/) [![GitHub stars](https://img.shields.io/github/stars/google/langextract.svg?style=social&label=Star)](https://github.com/google/langextract) ![Tests](https://github.com/google/langextract/actions/workflows/ci.yaml/badge.svg) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17015089.svg)](https://doi.org/10.5281/zenodo.17015089) ## Table of Contents - [Introduction](#introduction) - [Why LangExtract?](#why-langextract) - [Quick Start](#quick-start) - [Installation](#installation) - [API Key Setup for Cloud Models](#api-key-setup-for-cloud-models) - [Adding Custom Model Providers](#adding-custom-model-providers) - [Using OpenAI Models](#using-openai-models) - [Using Local LLMs with Ollama](#using-local-llms-with-ollama) - [More Examples](#more-examples) - [*Romeo and Juliet* Full Text Extraction](#romeo-and-juliet-full-text-extraction) - [Medication Extraction](#medication-extraction) - [Radiology Report Structuring: RadExtract](#radiology-report-structuring-radextract) - [Community Providers](#community-providers) - [Contributing](#contributing) - [Testing](#testing) - [Disclaimer](#disclaimer) ## Introduction LangExtract is a Python library that uses LLMs to extract structured information from unstructured text documents based on user-defined instructions. It processes materials such as clinical notes or reports, identifying and organizing key details while ensuring the extracted data corresponds to the source text. ## Why LangExtract? 1. **Precise Source Grounding:** Maps every extraction to its exact location in the source text, enabling visual highlighting for easy traceability and verification. 2. **Reliable Structured Outputs:** Enforces a consistent output schema based on your few-shot examples, leveraging controlled generation in supported models like Gemini to guarantee robust, structured results. 3. 
**Optimized for Long Documents:** Overcomes the "needle-in-a-haystack" challenge of large document extraction by using an optimized strategy of text chunking, parallel processing, and multiple passes for higher recall. 4. **Interactive Visualization:** Instantly generates a self-contained, interactive HTML file to visualize and review thousands of extracted entities in their original context. 5. **Flexible LLM Support:** Supports your preferred models, from cloud-based LLMs like the Google Gemini family to local open-source models via the built-in Ollama interface. 6. **Adaptable to Any Domain:** Define extraction tasks for any domain using just a few examples. LangExtract adapts to your needs without requiring any model fine-tuning. 7. **Leverages LLM World Knowledge:** Utilize precise prompt wording and few-shot examples to influence how the extraction task may utilize LLM knowledge. The accuracy of any inferred information and its adherence to the task specification are contingent upon the selected LLM, the complexity of the task, the clarity of the prompt instructions, and the nature of the prompt examples. ## Quick Start > **Note:** Using cloud-hosted models like Gemini requires an API key. See the [API Key Setup](#api-key-setup-for-cloud-models) section for instructions on how to get and configure your key. Extract structured information with just a few lines of code. ### 1. Define Your Extraction Task First, create a prompt that clearly describes what you want to extract. Then, provide a high-quality example to guide the model. ```python import langextract as lx import textwrap # 1. Define the prompt and extraction rules prompt = textwrap.dedent("""\ Extract characters, emotions, and relationships in order of appearance. Use exact text for extractions. Do not paraphrase or overlap entities. Provide meaningful attributes for each entity to add context.""") # 2. Provide a high-quality example to guide the model examples = [ lx.data.ExampleData( text="ROMEO. 
But soft! What light through yonder window breaks? It is the east, and Juliet is the sun.", extractions=[ lx.data.Extraction( extraction_class="character", extraction_text="ROMEO", attributes={"emotional_state": "wonder"} ), lx.data.Extraction( extraction_class="emotion", extraction_text="But soft!", attributes={"feeling": "gentle awe"} ), lx.data.Extraction( extraction_class="relationship", extraction_text="Juliet is the sun", attributes={"type": "metaphor"} ), ] ) ] ``` > **Note:** Examples drive model behavior. Each `extraction_text` should ideally be verbatim from the example's `text` (no paraphrasing), listed in order of appearance. LangExtract raises `Prompt alignment` warnings by default if examples don't follow this pattern—resolve these for best results. ### 2. Run the Extraction Provide your input text and the prompt materials to the `lx.extract` function. ```python # The input text to be processed input_text = "Lady Juliet gazed longingly at the stars, her heart aching for Romeo" # Run the extraction result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id="gemini-2.5-flash", ) ``` > **Model Selection**: `gemini-2.5-flash` is the recommended default, offering an excellent balance of speed, cost, and quality. For highly complex tasks requiring deeper reasoning, `gemini-2.5-pro` may provide superior results. For large-scale or production use, a Tier 2 Gemini quota is suggested to increase throughput and avoid rate limits. See the [rate-limit documentation](https://ai.google.dev/gemini-api/docs/rate-limits#tier-2) for details. > > **Model Lifecycle**: Note that Gemini models have a lifecycle with defined retirement dates. Users should consult the [official model version documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions) to stay informed about the latest stable and legacy versions. ### 3. 
Visualize the Results The extractions can be saved to a `.jsonl` file, a popular format for working with language model data. LangExtract can then generate an interactive HTML visualization from this file to review the entities in context. ```python # Save the results to a JSONL file lx.io.save_annotated_documents([result], output_name="extraction_results.jsonl", output_dir=".") # Generate the visualization from the file html_content = lx.visualize("extraction_results.jsonl") with open("visualization.html", "w", encoding="utf-8") as f: if hasattr(html_content, 'data'): f.write(html_content.data) # For Jupyter/Colab else: f.write(html_content) ``` This creates an animated and interactive HTML file: ![Romeo and Juliet Basic Visualization](https://raw.githubusercontent.com/google/langextract/main/docs/_static/romeo_juliet_basic.gif) > **Note on LLM Knowledge Utilization:** This example demonstrates extractions that stay close to the text evidence - extracting "longing" for Lady Juliet's emotional state and identifying "yearning" from "gazed longingly at the stars." The task could be modified to generate attributes that draw more heavily from the LLM's world knowledge (e.g., adding `"identity": "Capulet family daughter"` or `"literary_context": "tragic heroine"`). The balance between text-evidence and knowledge-inference is controlled by your prompt instructions and example attributes.
### Scaling to Longer Documents For larger texts, you can process entire documents directly from URLs with parallel processing and enhanced sensitivity: ```python # Process Romeo & Juliet directly from Project Gutenberg result = lx.extract( text_or_documents="https://www.gutenberg.org/files/1513/1513-0.txt", prompt_description=prompt, examples=examples, model_id="gemini-2.5-flash", extraction_passes=3, # Improves recall through multiple passes max_workers=20, # Parallel processing for speed max_char_buffer=1000 # Smaller contexts for better accuracy ) ``` This approach can extract hundreds of entities from full novels while maintaining high accuracy. The interactive visualization seamlessly handles large result sets, making it easy to explore hundreds of entities from the output JSONL file. **[See the full *Romeo and Juliet* extraction example →](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** for detailed results and performance insights. ### Vertex AI Batch Processing Save costs on large-scale tasks by enabling Vertex AI Batch API: `language_model_params={"vertexai": True, "batch": {"enabled": True}}`. See an example of the Vertex AI Batch API usage in [this example](docs/examples/batch_api_example.md). ## Installation ### From PyPI ```bash pip install langextract ``` *Recommended for most users. For isolated environments, consider using a virtual environment:* ```bash python -m venv langextract_env source langextract_env/bin/activate # On Windows: langextract_env\Scripts\activate pip install langextract ``` ### From Source LangExtract uses modern Python packaging with `pyproject.toml` for dependency management: *Installing with `-e` puts the package in development mode, allowing you to modify the code without reinstalling.* ```bash git clone https://github.com/google/langextract.git cd langextract # For basic installation: pip install -e . 
# For development (includes linting tools): pip install -e ".[dev]" # For testing (includes pytest): pip install -e ".[test]" ``` ### Docker ```bash docker build -t langextract . docker run --rm -e LANGEXTRACT_API_KEY="your-api-key" langextract python your_script.py ``` ## API Key Setup for Cloud Models When using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to set up an API key. On-device models don't require an API key. For developers using local LLMs, LangExtract offers built-in support for Ollama and can be extended to other third-party APIs by updating the inference endpoints. ### API Key Sources Get API keys from: * [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models * [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use * [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models ### Setting up API key in your environment **Option 1: Environment Variable** ```bash export LANGEXTRACT_API_KEY="your-api-key-here" ``` **Option 2: .env File (Recommended)** Add your API key to a `.env` file: ```bash # Add API key to .env file cat >> .env << 'EOF' LANGEXTRACT_API_KEY=your-api-key-here EOF # Keep your API key secure echo '.env' >> .gitignore ``` In your Python code: ```python import langextract as lx result = lx.extract( text_or_documents=input_text, prompt_description="Extract information...", examples=[...], model_id="gemini-2.5-flash" ) ``` **Option 3: Direct API Key (Not Recommended for Production)** You can also provide the API key directly in your code, though this is not recommended for production use: ```python result = lx.extract( text_or_documents=input_text, prompt_description="Extract information...", examples=[...], model_id="gemini-2.5-flash", api_key="your-api-key-here" # Only use this for testing/development ) ``` **Option 4: Vertex AI (Service Accounts)** Use [Vertex 
AI](https://cloud.google.com/vertex-ai/docs/start/introduction-unified-platform) for authentication with service accounts: ```python result = lx.extract( text_or_documents=input_text, prompt_description="Extract information...", examples=[...], model_id="gemini-2.5-flash", language_model_params={ "vertexai": True, "project": "your-project-id", "location": "global" # or regional endpoint } ) ``` ## Adding Custom Model Providers LangExtract supports custom LLM providers via a lightweight plugin system. You can add support for new models without changing core code. - Add new model support independently of the core library - Distribute your provider as a separate Python package - Keep custom dependencies isolated - Override or extend built-in providers via priority-based resolution See the detailed guide in [Provider System Documentation](langextract/providers/README.md) to learn how to: - Register a provider with `@registry.register(...)` - Publish an entry point for discovery - Optionally provide a schema with `get_schema_class()` for structured output - Integrate with the factory via `create_model(...)` ## Using OpenAI Models LangExtract supports OpenAI models (requires optional dependency: `pip install "langextract[openai]"`): ```python import os import langextract as lx result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id="gpt-4o", # Automatically selects OpenAI provider api_key=os.environ.get('OPENAI_API_KEY'), fence_output=True, use_schema_constraints=False ) ``` Note: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet.
## Using Local LLMs with Ollama LangExtract supports local inference using Ollama, allowing you to run models without API keys: ```python import langextract as lx result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id="gemma2:2b", # Automatically selects Ollama provider model_url="http://localhost:11434", fence_output=False, use_schema_constraints=False ) ``` **Quick setup:** Install Ollama from [ollama.com](https://ollama.com/), start the server with `ollama serve` (skip this if the Ollama app or service is already running), then download the model with `ollama pull gemma2:2b`. For detailed installation, Docker setup, and examples, see [`examples/ollama/`](examples/ollama/). ## More Examples Additional examples of LangExtract in action: ### *Romeo and Juliet* Full Text Extraction LangExtract can process complete documents directly from URLs. This example demonstrates extraction from the full text of *Romeo and Juliet* from Project Gutenberg (147,843 characters), showing parallel processing, sequential extraction passes, and performance optimization for long document processing. **[View *Romeo and Juliet* Full Text Example →](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** ### Medication Extraction > **Disclaimer:** This demonstration is for illustrative purposes of LangExtract's baseline capability only. It does not represent a finished or approved product, is not intended to diagnose or suggest treatment of any disease or condition, and should not be used for medical advice. LangExtract excels at extracting structured medical information from clinical text. These examples demonstrate both basic entity recognition (medication names, dosages, routes) and relationship extraction (connecting medications to their attributes), showing LangExtract's effectiveness for healthcare applications.
**[View Medication Examples →](https://github.com/google/langextract/blob/main/docs/examples/medication_examples.md)** ### Radiology Report Structuring: RadExtract Explore RadExtract, a live interactive demo on HuggingFace Spaces that shows how LangExtract can automatically structure radiology reports. Try it directly in your browser with no setup required. **[View RadExtract Demo →](https://huggingface.co/spaces/google/radextract)** ## Community Providers Extend LangExtract with custom model providers! Check out our [Community Provider Plugins](COMMUNITY_PROVIDERS.md) registry to discover providers created by the community or add your own. For detailed instructions on creating a provider plugin, see the [Custom Provider Plugin Example](examples/custom_provider_plugin/). ## Contributing Contributions are welcome! See [CONTRIBUTING.md](https://github.com/google/langextract/blob/main/CONTRIBUTING.md) to get started with development, testing, and pull requests. You must sign a [Contributor License Agreement](https://cla.developers.google.com/about) before submitting patches. ## Testing To run tests locally from the source: ```bash # Clone the repository git clone https://github.com/google/langextract.git cd langextract # Install with test dependencies pip install -e ".[test]" # Run all tests pytest tests ``` Or reproduce the full CI matrix locally with tox: ```bash tox # runs pylint + pytest on Python 3.10 and 3.11 ``` ### Ollama Integration Testing If you have Ollama installed locally, you can run integration tests: ```bash # Test Ollama integration (requires Ollama running with gemma2:2b model) tox -e ollama-integration ``` This test will automatically detect if Ollama is available and run real inference tests. 
## Development ### Code Formatting This project uses automated formatting tools to maintain consistent code style: ```bash # Auto-format all code ./autoformat.sh # Or run formatters separately isort langextract tests --profile google --line-length 80 pyink langextract tests --config pyproject.toml ``` ### Pre-commit Hooks For automatic formatting checks: ```bash pre-commit install # One-time setup pre-commit run --all-files # Manual run ``` ### Linting Run linting before submitting PRs: ```bash pylint --rcfile=.pylintrc langextract tests ``` See [CONTRIBUTING.md](CONTRIBUTING.md) for full development guidelines. ## Disclaimer This is not an officially supported Google product. If you use LangExtract in production or publications, please cite accordingly and acknowledge usage. Use is subject to the [Apache 2.0 License](https://github.com/google/langextract/blob/main/LICENSE). For health-related applications, use of LangExtract is also subject to the [Health AI Developer Foundations Terms of Use](https://developers.google.com/health-ai-developer-foundations/terms). --- **Happy Extracting!** ================================================ FILE: autoformat.sh ================================================
#!/bin/bash
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Autoformat LangExtract codebase
#
# Usage: ./autoformat.sh [target_directory ...]
# If no target is specified, formats the default directories: langextract tests
#
# This script runs:
# 1. isort for import sorting (on the given targets)
# 2. pyink (Google's Black fork) for code formatting (on the given targets)
# 3. pre-commit hooks for additional formatting (trailing whitespace,
#    end-of-file, etc.); these run with --all-files, not only on the targets.
#
# Targets and tool options are kept in bash arrays and expanded as
# "${array[@]}" so that paths containing spaces (including the directory this
# script lives in) reach each tool as single arguments.

set -e

echo "LangExtract Auto-formatter"
echo "=========================="
echo

# Print the command-line help text.
show_usage() {
  echo "Usage: $0 [target_directory ...]"
  echo
  echo "Formats Python code using isort and pyink according to Google style."
  echo
  echo "Arguments:"
  echo " target_directory One or more directories to format (default: langextract tests)"
  echo
  echo "Examples:"
  echo " $0 # Format langextract and tests directories"
  echo " $0 langextract # Format only langextract directory"
  echo " $0 src tests # Format multiple specific directories"
}

# Handle the help flag before the tool checks so --help works even when the
# formatters are not installed yet.
if [ "$1" = "-h" ] || [ "$1" = "--help" ]; then
  show_usage
  exit 0
fi

# Exit with an install hint if the given command is not on PATH.
check_tool() {
  if ! command -v "$1" &> /dev/null; then
    echo "Error: $1 not found. Please install with: pip install $1"
    exit 1
  fi
}

check_tool "isort"
check_tool "pyink"
check_tool "pre-commit"

# Determine target directories (one array element per argument).
if [ $# -eq 0 ]; then
  TARGETS=(langextract tests)
  echo "No target specified. Formatting default directories: langextract tests"
else
  TARGETS=("$@")
  echo "Formatting targets: ${TARGETS[*]}"
fi

# Find pyproject.toml relative to script location
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
CONFIG_FILE="${SCRIPT_DIR}/pyproject.toml"

# pyink options; the config path stays a single element even if it has spaces.
if [ ! -f "$CONFIG_FILE" ]; then
  echo "Warning: pyproject.toml not found at ${CONFIG_FILE}"
  echo "Using default configuration."
  CONFIG_ARGS=()
else
  CONFIG_ARGS=(--config "$CONFIG_FILE")
fi

echo

# Run isort
echo "Running isort to organize imports..."
if isort "${TARGETS[@]}"; then
  echo "Import sorting complete"
else
  echo "Import sorting failed"
  exit 1
fi

echo

# Run pyink
echo "Running pyink to format code (Google style, 80 chars)..."
if pyink "${TARGETS[@]}" "${CONFIG_ARGS[@]}"; then
  echo "Code formatting complete"
else
  echo "Code formatting failed"
  exit 1
fi

echo

# Run pre-commit hooks for additional formatting
echo "Running pre-commit hooks for additional formatting..."
if pre-commit run --all-files; then
  echo "Pre-commit hooks passed"
else
  echo "Pre-commit hooks made changes - please review"
  # Exit with success since formatting was applied
  exit 0
fi

echo
echo "All formatting complete!"
echo
echo "Next steps:"
echo " - Run: pylint --rcfile=${SCRIPT_DIR}/.pylintrc ${TARGETS[*]}"
echo " - Commit your changes"
================================================ FILE: benchmarks/benchmark.py ================================================
#!/usr/bin/env python3
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LangExtract benchmark suite for performance and quality testing.

Measures tokenization speed and extraction quality across multiple languages
and text types. Automatically downloads test texts from Project Gutenberg and
generates comparative visualizations.

Usage: # Run diverse text type benchmark (default) python benchmarks/benchmark.py # Test with specific model python benchmarks/benchmark.py --model gemini-2.5-flash python benchmarks/benchmark.py --model gemma2:2b # Local model via Ollama # Generate comparison plots from existing results python benchmarks/benchmark.py --compare Requirements: - Set GEMINI_API_KEY for cloud models - Install Ollama for local model testing - Results saved to benchmark_results/ """ import argparse from datetime import datetime import json import os from pathlib import Path import time from typing import Any import urllib.error import dotenv from benchmarks import config from benchmarks import plotting from benchmarks import utils import langextract from langextract import core from langextract import data from langextract import visualize import langextract.io as lio # Load API key from environment dotenv.load_dotenv(override=True) GEMINI_API_KEY = os.environ.get( "GEMINI_API_KEY", os.environ.get("LANGEXTRACT_API_KEY") ) class BenchmarkRunner: """Orchestrates benchmark execution and result collection.""" def __init__(self): """Initialize runner with timestamp and git metadata.""" self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.git_info = utils.get_git_info() self.tokenizer = core.tokenizer.RegexTokenizer() def set_tokenizer(self, tokenizer_type: str): """Set the tokenizer to use.""" if tokenizer_type.lower() == "unicode": self.tokenizer = core.tokenizer.UnicodeTokenizer() print("Using UnicodeTokenizer") else: self.tokenizer = core.tokenizer.RegexTokenizer() print("Using RegexTokenizer (default)") def print_header(self): """Print benchmark header.""" print("=" * config.DISPLAY.separator_width) print("LANGEXTRACT BENCHMARK") print("=" * config.DISPLAY.separator_width) print( f"Branch: {self.git_info['branch']} | Commit: {self.git_info['commit']}" ) print("-" * config.DISPLAY.separator_width) def benchmark_tokenization(self) -> list[dict[str, Any]]: """Measure tokenization 
throughput at different text sizes. Returns: List of dicts with words, tokens, timing, and throughput metrics. """ print("\nTokenization Performance") print("-" * config.DISPLAY.subseparator_width) results = [] for word_count in config.TOKENIZATION.default_text_sizes: text = " ".join(["word"] * word_count) _ = self.tokenizer.tokenize(text) times = [] for _ in range(config.TOKENIZATION.benchmark_iterations): start = time.perf_counter() tokenized = self.tokenizer.tokenize(text) elapsed = time.perf_counter() - start times.append(elapsed) avg_time = sum(times) / len(times) avg_ms = avg_time * 1000 num_tokens = len(tokenized.tokens) tokens_per_sec = num_tokens / avg_time if avg_time > 0 else 0 word_str = ( f"{word_count//1000:,}k" if word_count >= 1000 else f"{word_count:,}" ) print( f"{word_str:>6} words: {avg_ms:7.2f}ms " f"({tokens_per_sec/1e6:.1f}M tokens/sec)" ) results.append({ "words": word_count, "tokens": num_tokens, "avg_ms": avg_ms, "tokens_per_sec": tokens_per_sec, }) return results def test_single_extraction( self, model_id: str = config.MODELS.default_model, text_type: config.TextTypes = config.TextTypes.ENGLISH, ) -> dict[str, Any]: """Execute extraction test. Args: model_id: Model identifier (e.g., 'gemini-2.5-flash', 'gemma2:2b'). text_type: Language/text type to test. Returns: Dict with success status, timing, entity counts, and metrics. 
""" print("\nExtraction Test") print("-" * config.DISPLAY.subseparator_width) try: # Get test text test_text = utils.get_text_from_gutenberg(text_type) test_text = utils.get_optimal_text_size(test_text, model_id) print(f" Text: {len(test_text):,} characters ({text_type.value})") print(f" Model: {model_id}") # Analyze tokenization tokenization_analysis = utils.analyze_tokenization( test_text, self.tokenizer ) print( " Tokenization:" f" {utils.format_tokenization_summary(tokenization_analysis)}" ) # Get extraction config for text type extraction_config = utils.get_extraction_example(text_type) example = data.ExampleData( text="MACBETH speaks to LADY MACBETH about Duncan.", extractions=[ data.Extraction( extraction_text="Macbeth", extraction_class="Character" ), data.Extraction( extraction_text="Lady Macbeth", extraction_class="Character" ), data.Extraction( extraction_text="Duncan", extraction_class="Character" ), ], ) max_retries = 5 retry_delay = 3.0 # Retry logic for transient network/API failures for attempt in range(max_retries): try: start_time = time.time() result = langextract.extract( text_or_documents=test_text, model_id=model_id, api_key=GEMINI_API_KEY, prompt_description=extraction_config["prompt"], examples=[example], max_workers=config.MODELS.default_max_workers, temperature=config.MODELS.default_temperature, extraction_passes=config.MODELS.default_extraction_passes, tokenizer=self.tokenizer, ) elapsed = time.time() - start_time break except (ConnectionError, TimeoutError): if attempt < max_retries - 1: print(f" Retrying in {retry_delay}s...") time.sleep(retry_delay) retry_delay *= 1.5 continue raise print(f"Extraction completed in {elapsed:.1f}s") grounded_entities = [] ungrounded_entities = [] if result.extractions: for extraction in result.extractions: is_grounded = ( extraction.char_interval and extraction.char_interval.start_pos is not None and extraction.char_interval.end_pos is not None ) entity_text = extraction.extraction_text if entity_text: 
if is_grounded: grounded_entities.append(entity_text) else: ungrounded_entities.append(entity_text) unique_grounded = list(set(grounded_entities)) unique_ungrounded = list(set(ungrounded_entities)) print(f"Found {len(unique_grounded)} grounded entities") if unique_ungrounded: print(f" ({len(unique_ungrounded)} ungrounded entities ignored)") if unique_grounded: sample = unique_grounded[:5] sample_str = ", ".join(sample) + ( "..." if len(unique_grounded) > 5 else "" ) print(f" Sample: {sample_str}") return { "success": True, "model": model_id, "text_type": text_type.value, "time_seconds": elapsed, "entity_count": len(unique_grounded), "ungrounded_count": len(unique_ungrounded), "sample_entities": unique_grounded[:10], "tokenization": tokenization_analysis, config.EXTRACTION_RESULT_KEY: result, } except (urllib.error.URLError, RuntimeError) as e: # Handle expected text download failures. print(f"Failed: {e}") return { "success": False, "model": model_id, "text_type": text_type.value, "error": str(e), } def test_diverse_text_types( self, models: list[str] | None = None ) -> list[dict[str, Any]]: """Test extraction with diverse text types.""" print("\n" + "=" * config.DISPLAY.separator_width) print("DIVERSE TEXT TYPE MODE") print("=" * config.DISPLAY.separator_width) if models is None: models = [config.MODELS.default_model] results = [] test_count = 0 for model_id in models: print(f"\nTesting {model_id}") print("-" * 30) for text_type in config.TextTypes: print(f"\n Testing {text_type.value} text...") result = self.test_single_extraction(model_id, text_type) results.append(result) if result.get("success"): test_count += 1 if test_count % 3 == 0: print( " Rate limit delay" f" ({config.MODELS.gemini_rate_limit_delay}s)..." 
) time.sleep(config.MODELS.gemini_rate_limit_delay) print(f"\nCompleted {test_count} successful tests") return results def save_results(self, results: dict[str, Any]): """Save results and create plots.""" results["timestamp"] = self.timestamp results["git"] = self.git_info json_path = config.PATHS.get_result_path(self.timestamp, "").with_suffix( ".json" ) viz_dir = json_path.parent / "visualizations" / self.timestamp viz_dir.mkdir(parents=True, exist_ok=True) if config.RESULTS_KEY in results: print(f"\nGenerating visualizations in: {viz_dir}") for result in results[config.RESULTS_KEY]: if result.get("success") and config.EXTRACTION_RESULT_KEY in result: model_name = result["model"].replace("/", "_").replace(":", "_") text_type = result["text_type"] viz_name = f"{model_name}_{text_type}" jsonl_path = viz_dir / f"{viz_name}.jsonl" lio.save_annotated_documents( [result[config.EXTRACTION_RESULT_KEY]], output_name=jsonl_path.name, output_dir=str(viz_dir), ) html_content = visualize(str(jsonl_path)) html_path = viz_dir / f"{viz_name}.html" with open(html_path, "w") as f: f.write(getattr(html_content, "data", html_content)) # Remove extraction result objects before saving JSON for result in results.get(config.RESULTS_KEY, []): result.pop(config.EXTRACTION_RESULT_KEY, None) with open(json_path, "w") as f: json.dump(results, f, indent=2, default=str) print(f"\nResults saved to: {json_path}") plot_created = plotting.create_diverse_plots(results, json_path) if plot_created: print(f"Plot saved to: {json_path.with_suffix('.png')}") else: print(f"Warning: Failed to create plot for {json_path.name}") def run_diverse_benchmark(self, models: list[str] | None = None): """Run benchmark.""" self.print_header() tokenization_results = self.benchmark_tokenization() diverse_results = self.test_diverse_text_types(models) results = { "tokenization": tokenization_results, config.RESULTS_KEY: diverse_results, } self.save_results(results) def main(): """Main entry point.""" parser = 
argparse.ArgumentParser(description="LangExtract Benchmark Suite") parser.add_argument( "--model", type=str, default=None, help=f"Model to use (default: {config.MODELS.default_model})", ) parser.add_argument( "--tokenizer", type=str, choices=["regex", "unicode"], default="regex", help="Tokenizer to use (default: regex)", ) parser.add_argument( "--compare", action="store_true", help="Generate comparison plots from existing benchmark results", ) args = parser.parse_args() # Handle comparison mode if args.compare: results_dir = Path("benchmark_results") json_files = sorted(results_dir.glob("benchmark_*.json")) if len(json_files) < 2: print( "Need at least 2 benchmark results for comparison, found" f" {len(json_files)}" ) return print(f"Found {len(json_files)} benchmark results to compare") # Use last 10 results or all if less than 10 files_to_compare = json_files[-10:] comparison_path = ( results_dir / f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" ) plotting.create_comparison_plots(files_to_compare, comparison_path) print(f"\nComparison plot saved to: {comparison_path}") return model_to_test = args.model or config.MODELS.default_model if "gemini" in model_to_test.lower() and not GEMINI_API_KEY: print( f"Error: {model_to_test} requires GEMINI_API_KEY or LANGEXTRACT_API_KEY" ) return runner = BenchmarkRunner() runner.set_tokenizer(args.tokenizer) runner.run_diverse_benchmark([args.model] if args.model else None) if __name__ == "__main__": main() ================================================ FILE: benchmarks/config.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Benchmark configuration settings and constants. Centralized configuration for tokenization tests, model parameters, display formatting, and test text sources. """ from dataclasses import dataclass import enum from pathlib import Path # Result dictionary keys RESULTS_KEY = "results" EXTRACTION_KEY = "extraction" EXTRACTION_RESULT_KEY = "extraction_result" TOKENIZATION_KEY = "tokenization" @dataclass(frozen=True) class TokenizationConfig: """Settings for tokenization performance tests.""" default_text_sizes: tuple[int, ...] = (100, 1000, 10000) # Word counts benchmark_iterations: int = 10 # Iterations per size for averaging @dataclass(frozen=True) class ModelConfig: """Model and API configuration.""" default_model: str = "gemini-2.5-flash" # Cloud model default local_model: str = "gemma2:9b" # Ollama model default default_temperature: float = 0.0 # Deterministic output default_max_workers: int = 10 # Parallel processing threads default_extraction_passes: int = 1 # Single pass extraction gemini_rate_limit_delay: float = 8.0 # Seconds between batches class TextTypes(str, enum.Enum): """Supported languages for extraction testing.""" ENGLISH = "english" JAPANESE = "japanese" FRENCH = "french" SPANISH = "spanish" # Test texts from Project Gutenberg (similar genres for fair comparison) GUTENBERG_TEXTS = { TextTypes.ENGLISH: ( "https://www.gutenberg.org/files/11/11-0.txt" ), # Alice's Adventures TextTypes.JAPANESE: ( "https://www.gutenberg.org/files/1982/1982-0.txt" ), # Rashomon TextTypes.FRENCH: ( "https://www.gutenberg.org/files/55456/55456-0.txt" ), # Alice (French) 
TextTypes.SPANISH: ( "https://www.gutenberg.org/files/67248/67248-0.txt" ), # El clavo } @dataclass(frozen=True) class DisplayConfig: """Display configuration.""" separator_width: int = 50 subseparator_width: int = 40 figure_size_single: tuple[int, int] = (12, 5) figure_size_multi: tuple[int, int] = (14, 10) plot_style: str = "seaborn-v0_8-darkgrid" @dataclass(frozen=True) class PathConfig: """Path configuration.""" results_dir: Path = Path("benchmark_results") def get_result_path(self, timestamp: str, suffix: str = "") -> Path: """Get result file path.""" if not self.results_dir.exists(): self.results_dir.mkdir(parents=True) filename = f"benchmark{suffix}_{timestamp}" return self.results_dir / filename # Global config instances TOKENIZATION = TokenizationConfig() MODELS = ModelConfig() DISPLAY = DisplayConfig() PATHS = PathConfig() ================================================ FILE: benchmarks/plotting.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Visualization generation for benchmark results. Creates multi-panel plots showing tokenization performance, extraction metrics, and cross-language comparisons. 
""" from datetime import datetime import json from pathlib import Path from typing import Any import matplotlib import matplotlib.pyplot as plt import numpy as np from benchmarks import config matplotlib.use("Agg") plt.style.use(config.DISPLAY.plot_style) def create_diverse_plots(results: dict[str, Any], filepath: Path) -> bool: """Generate comprehensive benchmark visualization. Args: results: Benchmark results dictionary with tokenization and extraction data. filepath: Output path for PNG file. Returns: True if plot created successfully, False on error. """ try: fig = plt.figure(figsize=(15, 10)) # Create 2x3 grid: tokenization metrics (top), extraction metrics (bottom) gs = fig.add_gridspec(2, 3, hspace=0.25, wspace=0.25) ax1 = fig.add_subplot(gs[0, 0]) # Tokenization throughput ax2 = fig.add_subplot(gs[0, 1]) # Token density by language ax3 = fig.add_subplot(gs[0, 2]) # Entity extraction counts ax4 = fig.add_subplot(gs[1, 0]) # Processing speed ax5 = fig.add_subplot(gs[1, 1]) # Summary metrics ax6 = fig.add_subplot(gs[1, 2]) # Unused fig.suptitle( f"LangExtract Benchmark - {results['timestamp']}", fontsize=14, y=0.98 ) _plot_tokenization_throughput(ax1, results) _plot_tokenization_rate(ax2, results) _plot_extraction_density(ax3, results) _plot_processing_speed(ax4, results) _plot_summary_table(ax5, results) ax6.axis("off") plt.tight_layout(rect=[0, 0.02, 1, 0.96]) plot_path = filepath.with_suffix(".png") plt.savefig(plot_path, dpi=100, bbox_inches="tight") plt.close() print(f"Plot saved to: {plot_path}") return True except (IOError, OSError) as e: print(f"Warning: Could not create benchmark plot: {e}") return False def _plot_tokenization_throughput(ax, results): """Plot tokenization throughput (tokens per second) on log scale.""" if ( config.TOKENIZATION_KEY not in results or not results[config.TOKENIZATION_KEY] ): ax.text(0.5, 0.5, "No tokenization data", ha="center", va="center") ax.set_title("Tokenization Throughput") return sizes = [r["words"] for r in 
results[config.TOKENIZATION_KEY]] speeds = [r["tokens_per_sec"] for r in results[config.TOKENIZATION_KEY]] ax.semilogx(sizes, speeds, "b-o", linewidth=2, markersize=8) ax.set_xlabel("Number of Words (log scale)") ax.set_ylabel("Tokens per Second") ax.set_title("Tokenization Throughput") ax.grid(True, alpha=0.3) max_speed = max(speeds) ax.set_ylim(0, max_speed * 1.15) y_ticks = [0, 100000, 200000, 300000, 400000] ax.set_yticks(y_ticks) ax.set_yticklabels([f"{int(y/1000)}K" if y > 0 else "0" for y in y_ticks]) for x, y in zip(sizes, speeds): label = f"{y/1000:.0f}K" ax.annotate( label, xy=(x, y), xytext=(0, 5), textcoords="offset points", ha="center", fontsize=9, ) ax.set_xticks([100, 1000, 10000]) ax.set_xticklabels(["10²", "10³", "10⁴"]) def _plot_tokenization_rate(ax, results): """Plot tokenization rate by text type.""" if config.RESULTS_KEY not in results: ax.text(0.5, 0.5, "No data", ha="center", va="center") ax.set_title("Tokenization Rate") return text_types = [] tok_per_char = [] for result in results[config.RESULTS_KEY]: if config.TOKENIZATION_KEY in result and result.get("success", False): text_type = result.get("text_type", "unknown") if text_type not in text_types: text_types.append(text_type) tpc = result[config.TOKENIZATION_KEY]["tokens_per_char"] tok_per_char.append(tpc) if not text_types: ax.text(0.5, 0.5, "No tokenization data", ha="center", va="center") ax.set_title("Tokenization Rate") return x = np.arange(len(text_types)) bars = ax.bar(x, tok_per_char, color="#2196f3", alpha=0.7) for bar_rect, val in zip(bars, tok_per_char): ax.text( bar_rect.get_x() + bar_rect.get_width() / 2, val + 0.005, f"{val:.3f}", ha="center", va="bottom", fontsize=9, ) ax.set_xlabel("Text Type") ax.set_ylabel("Tokens per Character") ax.set_title("Tokenization Rate") ax.set_xticks(x) ax.set_xticklabels([t.capitalize() for t in text_types]) ax.grid(True, alpha=0.3, axis="y") ax.set_ylim(0, max(0.30, max(tok_per_char) * 1.2) if tok_per_char else 0.30) def 
_plot_extraction_density(ax, results): """Plot entity extraction density.""" if config.RESULTS_KEY not in results: ax.text(0.5, 0.5, "No data", ha="center", va="center") ax.set_title("Extraction Density") return text_types = [] densities = [] for result in results[config.RESULTS_KEY]: if result.get("success", False): text_type = result.get("text_type", "unknown") if text_type not in text_types: text_types.append(text_type) char_count = 1000 if config.TOKENIZATION_KEY in result: char_count = result[config.TOKENIZATION_KEY].get("num_chars", 1000) entity_count = result.get("entity_count", 0) density = (entity_count * 1000) / char_count densities.append(density) if not text_types: ax.text(0.5, 0.5, "No successful extractions", ha="center", va="center") ax.set_title("Extraction Density") return x = np.arange(len(text_types)) bars = ax.bar(x, densities, color="#4caf50", alpha=0.7) for bar_rect, val in zip(bars, densities): ax.text( bar_rect.get_x() + bar_rect.get_width() / 2, val, f"{val:.1f}", ha="center", va="bottom", fontsize=9, ) ax.set_xlabel("Text Type") ax.set_ylabel("Entities per 1K Characters") ax.set_title("Extraction Density") ax.set_xticks(x) ax.set_xticklabels([t.capitalize() for t in text_types]) ax.grid(True, alpha=0.3, axis="y") def _plot_processing_speed(ax, results): """Plot processing speed normalized by text size.""" if config.RESULTS_KEY not in results: ax.text(0.5, 0.5, "No data", ha="center", va="center") ax.set_title("Processing Speed") return text_types = [] speeds = [] for result in results[config.RESULTS_KEY]: if result.get("success", False): text_type = result.get("text_type", "unknown") if text_type not in text_types: text_types.append(text_type) char_count = 1000 if config.TOKENIZATION_KEY in result: char_count = result[config.TOKENIZATION_KEY].get("num_chars", 1000) time_seconds = result.get("time_seconds", 0) speed = (time_seconds * 1000) / char_count speeds.append(speed) if not text_types: ax.text(0.5, 0.5, "No timing data", ha="center", 
va="center") ax.set_title("Processing Speed") return x = np.arange(len(text_types)) bars = ax.bar(x, speeds, color="#ff9800", alpha=0.7) for bar_rect, val in zip(bars, speeds): ax.text( bar_rect.get_x() + bar_rect.get_width() / 2, val, f"{val:.1f}s", ha="center", va="bottom", fontsize=9, ) ax.set_xlabel("Text Type") ax.set_ylabel("Seconds per 1K Characters") ax.set_title("Processing Speed") ax.set_xticks(x) ax.set_xticklabels([t.capitalize() for t in text_types]) ax.grid(True, alpha=0.3, axis="y") def _plot_summary_table(ax, results): """Create a summary of key findings.""" ax.axis("off") if config.RESULTS_KEY not in results: ax.text(0.5, 0.5, "No data", ha="center", va="center") ax.set_title("Key Metrics") return summary_lines = [] summary_lines.append("Key Metrics") summary_lines.append("-" * 20) summary_lines.append("") success_count = sum( 1 for r in results.get(config.RESULTS_KEY, []) if r.get("success") ) total_count = len(results.get(config.RESULTS_KEY, [])) if total_count > 0: summary_lines.append("Tests Run:") summary_lines.append(f" {success_count} successful") summary_lines.append(f" {total_count - success_count} failed") summary_lines.append("") if success_count > 0: avg_time = ( sum( r.get("time_seconds", 0) for r in results.get(config.RESULTS_KEY, []) if r.get("success") ) / success_count ) summary_lines.append(f"Avg Time: {avg_time:.1f}s") summary_text = "\n".join(summary_lines) ax.text( 0.5, 0.5, summary_text, ha="center", va="center", fontsize=10, family="monospace", ) ax.set_title("Key Metrics", fontweight="bold", y=0.9) def create_comparison_plots(json_files: list[Path], output_path: Path) -> None: """Create comparison plots from multiple benchmark JSON files. Args: json_files: List of paths to benchmark JSON files to compare. output_path: Path where the comparison plot should be saved. 
""" if len(json_files) < 2: print("Need at least 2 JSON files for comparison") return all_results = [] for json_file in json_files: try: with open(json_file, "r") as f: data = json.load(f) data["filename"] = json_file.stem all_results.append(data) except (IOError, OSError, json.JSONDecodeError) as e: print(f"Error loading {json_file}: {e}") continue if len(all_results) < 2: print("Could not load enough valid JSON files for comparison") return plt.figure(figsize=(18, 12)) ax1 = plt.subplot(2, 3, (1, 2)) _plot_tokenization_comparison(ax1, all_results) ax2 = plt.subplot(2, 3, 3) _plot_entity_comparison(ax2, all_results) ax3 = plt.subplot(2, 3, 4) _plot_time_comparison(ax3, all_results) ax4 = plt.subplot(2, 3, 5) _plot_success_rate_comparison(ax4, all_results) ax5 = plt.subplot(2, 3, 6) _plot_timeline(ax5, all_results) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") plt.suptitle( f"LangExtract Benchmark Comparison - {timestamp}", fontsize=14, fontweight="bold", ) plt.tight_layout(rect=[0, 0.01, 1, 0.95]) plt.subplots_adjust(hspace=0.45, wspace=0.35, top=0.93) plt.savefig(output_path, dpi=100, bbox_inches="tight") plt.close() print(f"Comparison plot saved to: {output_path}") def _plot_entity_comparison(ax, all_results): """Plot entity count comparison across runs.""" runs = [] languages = ["english", "french", "spanish", "japanese"] language_data = [] for result in all_results: run_name = result["filename"].replace("benchmark_", "")[:10] runs.append(run_name) run_counts = {lang: 0 for lang in languages} if config.RESULTS_KEY in result: for res in result[config.RESULTS_KEY]: lang = res.get("text_type", "") if lang in languages and res.get("success"): run_counts[lang] = res.get("entity_count", 0) language_data.append(run_counts) x = np.arange(len(runs)) width = 0.2 for i, lang in enumerate(languages): counts = [data[lang] for data in language_data] bars = ax.bar(x + i * width, counts, width, label=lang.capitalize()) for bar_rect, count in zip(bars, counts): if count 
> 0: ax.text( bar_rect.get_x() + bar_rect.get_width() / 2, bar_rect.get_height() + 0.5, str(count), ha="center", fontsize=7, ) ax.set_xlabel("Run") ax.set_ylabel("Entity Count") title = "Entities Extracted by Language\n" subtitle = "Number of unique character names found per language" ax.set_title(title, fontweight="bold", fontsize=10) ax.text( 0.5, 1.01, subtitle, transform=ax.transAxes, ha="center", fontsize=7, style="italic", color="#666666", va="bottom", ) ax.set_xticks(x + width * 1.5) ax.set_xticklabels(runs, rotation=45, ha="right") ax.legend(loc="upper left", fontsize=8) ax.grid(True, alpha=0.3) ax.set_ylim(0, ax.get_ylim()[1] * 1.1) def _plot_time_comparison(ax, all_results): """Plot processing time comparison.""" runs = [] avg_times = [] for result in all_results: run_name = result["filename"].replace("benchmark_", "")[:10] runs.append(run_name) if config.RESULTS_KEY in result: times = [ r.get("time_seconds", 0) for r in result[config.RESULTS_KEY] if r.get("success") ] avg_time = sum(times) / len(times) if times else 0 avg_times.append(avg_time) else: avg_times.append(0) x_pos = np.arange(len(runs)) bars = ax.bar(x_pos, avg_times, color="skyblue", edgecolor="navy", alpha=0.7) ax.set_xlabel("Run") ax.set_ylabel("Average Time (seconds)") title = "Average Processing Time\n" subtitle = "Mean extraction time across all language tests" ax.set_title(title, fontweight="bold", fontsize=10) ax.text( 0.5, 1.01, subtitle, transform=ax.transAxes, ha="center", fontsize=7, style="italic", color="#666666", va="bottom", ) ax.set_xticks(x_pos) ax.set_xticklabels(runs, rotation=45, ha="right") ax.grid(True, alpha=0.3) for bar_rect, time in zip(bars, avg_times): if time > 0: ax.text( bar_rect.get_x() + bar_rect.get_width() / 2, bar_rect.get_height() + 0.1, f"{time:.1f}s", ha="center", fontsize=8, ) if max(avg_times) > 0: ax.set_ylim(0, max(avg_times) * 1.2) def _plot_tokenization_comparison(ax, all_results): """Plot tokenization throughput comparison as line graphs.""" for 
i, result in enumerate(all_results): run_name = result["filename"].replace("benchmark_", "")[:10] if config.TOKENIZATION_KEY in result and result[config.TOKENIZATION_KEY]: sizes = [r["words"] for r in result[config.TOKENIZATION_KEY]] speeds = [r["tokens_per_sec"] for r in result[config.TOKENIZATION_KEY]] ax.semilogx( sizes, speeds, "o-", linewidth=2, markersize=6, label=run_name, alpha=0.8, ) for x, y in zip(sizes, speeds): if i == 0: # Only label first run to avoid overlap label = f"{y/1000:.0f}K" ax.annotate( label, xy=(x, y), xytext=(0, 5), textcoords="offset points", ha="center", fontsize=7, ) ax.set_xlabel("Number of Words (log scale)") ax.set_ylabel("Tokens per Second") title = "Tokenization Throughput Comparison\n" subtitle = "Speed of text tokenization at different document sizes" ax.set_title(title, fontweight="bold", fontsize=10) ax.text( 0.5, 1.01, subtitle, transform=ax.transAxes, ha="center", fontsize=7, style="italic", color="#666666", va="bottom", ) ax.grid(True, alpha=0.3) ax.legend(loc="best", fontsize=8) ax.set_xticks([100, 1000, 10000]) ax.set_xticklabels(["10²", "10³", "10⁴"]) _, ymax = ax.get_ylim() ax.set_ylim(0, ymax * 1.1) def _plot_success_rate_comparison(ax, all_results): """Plot success rate comparison.""" runs = [] success_rates = [] for result in all_results: run_name = result["filename"].replace("benchmark_", "")[:10] runs.append(run_name) if config.RESULTS_KEY in result: total = len(result[config.RESULTS_KEY]) success = sum(1 for r in result[config.RESULTS_KEY] if r.get("success")) rate = (success / total * 100) if total > 0 else 0 success_rates.append(rate) else: success_rates.append(0) x_pos = np.arange(len(runs)) colors = [ "green" if rate == 100 else "orange" if rate >= 75 else "red" for rate in success_rates ] bars = ax.bar(x_pos, success_rates, color=colors, alpha=0.7) ax.set_xlabel("Run") ax.set_ylabel("Success Rate (%)") title = "Extraction Success Rate\n" subtitle = "Percentage of language tests completed without errors" 
ax.set_title(title, fontweight="bold", fontsize=10) ax.text( 0.5, 1.01, subtitle, transform=ax.transAxes, ha="center", fontsize=7, style="italic", color="#666666", va="bottom", ) ax.set_ylim(0, 105) ax.set_xticks(x_pos) ax.set_xticklabels(runs, rotation=45, ha="right") ax.axhline(y=100, color="green", linestyle="--", alpha=0.3) ax.grid(True, alpha=0.3) for bar_rect, rate in zip(bars, success_rates): ax.text( bar_rect.get_x() + bar_rect.get_width() / 2, bar_rect.get_height() + 1, f"{rate:.0f}%", ha="center", fontsize=8, ) def _plot_token_rate_by_language(ax, all_results): """Plot tokenization rates by language.""" languages = ["english", "french", "spanish", "japanese"] latest_result = all_results[-1] token_rates = [] colors = [] if config.RESULTS_KEY in latest_result: for lang in languages: lang_results = [ r for r in latest_result[config.RESULTS_KEY] if r.get("text_type") == lang and r.get("success") ] if lang_results and config.TOKENIZATION_KEY in lang_results[0]: rate = lang_results[0][config.TOKENIZATION_KEY].get( "tokens_per_char", 0 ) token_rates.append(rate) colors.append( "red" if rate < 0.1 else "orange" if rate < 0.2 else "green" ) else: token_rates.append(0) colors.append("gray") ax.bar(languages, token_rates, color=colors, alpha=0.7) ax.set_xlabel("Language") ax.set_ylabel("Tokens per Character") ax.set_title("Tokenization Density (Latest Run)") ax.set_xticks(range(len(languages))) ax.set_xticklabels([l.capitalize() for l in languages]) ax.grid(True, alpha=0.3) for i, (lang, rate) in enumerate(zip(languages, token_rates)): ax.text(i, rate + 0.01, f"{rate:.3f}", ha="center", fontsize=8) def _plot_timeline(ax, all_results): """Plot metrics over time if timestamps available.""" timestamps = [] entity_totals = [] for result in all_results: filename = result["filename"] if "timestamp" in result: timestamps.append(result["timestamp"]) else: # Try to parse from filename (format: benchmark_YYYYMMDD_HHMMSS) parts = filename.split("_") if len(parts) >= 3: 
timestamps.append(f"{parts[-2]}_{parts[-1]}") else: timestamps.append(filename[:10]) if config.RESULTS_KEY in result: total_entities = sum( r.get("entity_count", 0) for r in result[config.RESULTS_KEY] if r.get("success") ) entity_totals.append(total_entities) else: entity_totals.append(0) x_pos = np.arange(len(timestamps)) ax.plot(x_pos, entity_totals, "o-", color="blue", linewidth=2, markersize=8) ax.set_xlabel("Run") ax.set_ylabel("Total Entities") title = "Total Entities Over Time\n" subtitle = "Sum of all entities extracted across all languages" ax.set_title(title, fontweight="bold", fontsize=10) ax.text( 0.5, 1.01, subtitle, transform=ax.transAxes, ha="center", fontsize=7, style="italic", color="#666666", va="bottom", ) ax.set_xticks(x_pos) ax.set_xticklabels([t[-6:] for t in timestamps], rotation=45, ha="right") ax.grid(True, alpha=0.3) for i, total in enumerate(entity_totals): ax.text(i, total + 1, str(total), ha="center", fontsize=8) if entity_totals: min_val = min(0, min(entity_totals) - 5) max_val = max(entity_totals) + 5 ax.set_ylim(min_val, max_val) ================================================ FILE: benchmarks/utils.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Helper functions for benchmark text retrieval and analysis.""" import subprocess from typing import Any import urllib.error import urllib.request from benchmarks import config from langextract.core import tokenizer def download_text(url: str) -> str: """Download text from URL. Args: url: URL to download from. Returns: Downloaded text content. """ try: with urllib.request.urlopen(url) as response: return response.read().decode("utf-8") except (urllib.error.URLError, urllib.error.HTTPError) as e: raise RuntimeError(f"Could not download from {url}: {e}") from e def extract_text_content(full_text: str) -> str: """Extract main content from Gutenberg text. Skips headers and footers by taking middle 60% of text. Args: full_text: Full text including Gutenberg headers. Returns: Extracted main content. """ start_marker = "*** START OF" end_marker = "*** END OF" start_idx = full_text.upper().find(start_marker) end_idx = full_text.upper().find(end_marker) if start_idx != -1 and end_idx != -1: content_start = full_text.find("\n", start_idx) + 1 # Handle markers with trailing asterisks (e.g., "*** START ... ***"). line_end = full_text.find("***", start_idx + 3) if ( line_end != -1 and line_end < content_start + 100 ): # Ensure marker is on same line. content_start = full_text.find("\n", line_end) + 1 return full_text[content_start:end_idx].strip() text_length = len(full_text) start = int(text_length * 0.2) end = int(text_length * 0.8) return full_text[start:end].strip() def get_text_from_gutenberg(text_type: config.TextTypes) -> str: """Get text from Project Gutenberg for given language. Args: text_type: Type of text (language). Returns: Text sample from Gutenberg. 
""" url = config.GUTENBERG_TEXTS[text_type] full_text = download_text(url) content = extract_text_content(full_text) mid_point = len(content) // 2 start_chunk = max(0, mid_point - 2500) return content[start_chunk : start_chunk + 5000].strip() def get_optimal_text_size(text: str, model_id: str) -> str: """Get optimal text size for model. Args: text: Original text. model_id: Model identifier. Returns: Text truncated to optimal size. """ if ( ":" in model_id or "gemma" in model_id.lower() or "llama" in model_id.lower() ): max_chars = 500 # Smaller context for local models. else: max_chars = 5000 return text[:max_chars] def get_extraction_example(text_type: config.TextTypes) -> dict[str, str]: # pylint: disable=unused-argument """Get extraction example configuration. Args: text_type: Type of text. Returns: Dictionary with prompt configuration. """ return { "prompt": "Extract all character names from this text", } def get_git_info() -> dict[str, str]: """Get current git branch and commit info. Returns: Dictionary with branch and commit info. """ try: branch = subprocess.run( ["git", "branch", "--show-current"], capture_output=True, text=True, check=True, ).stdout.strip() commit = subprocess.run( ["git", "rev-parse", "--short", "HEAD"], capture_output=True, text=True, check=True, ).stdout.strip() status = subprocess.run( ["git", "status", "--porcelain"], capture_output=True, text=True, check=True, ).stdout.strip() if status: commit += "-dirty" return {"branch": branch, "commit": commit} except subprocess.CalledProcessError: return {"branch": "unknown", "commit": "unknown"} def analyze_tokenization( text: str, tokenizer_inst: tokenizer.Tokenizer | None = None ) -> dict[str, Any]: """Analyze tokenization of given text. Args: text: Text to analyze. tokenizer_inst: Tokenizer instance to use (default: RegexTokenizer). Returns: Dictionary with tokenization metrics. 
""" if tokenizer_inst: tokenized = tokenizer_inst.tokenize(text) else: tokenized = tokenizer.tokenize(text) num_tokens = len(tokenized.tokens) num_chars = len(text) tokens_per_char = num_tokens / num_chars if num_chars > 0 else 0 return { "num_tokens": num_tokens, "num_chars": num_chars, "tokens_per_char": tokens_per_char, } def format_tokenization_summary(analysis: dict[str, Any]) -> str: """Format tokenization analysis as summary string. Args: analysis: Tokenization analysis dict. Returns: Formatted summary string. """ return ( f"{analysis['num_tokens']} tokens, " f"{analysis['tokens_per_char']:.3f} tok/char" ) ================================================ FILE: docs/examples/batch_api_example.md ================================================ # Vertex AI Batch Processing Guide The Vertex AI Batch API offers significant cost savings (~50%) for large, non-time-critical workloads. `langextract` seamlessly integrates this with automatic routing, caching, and fault tolerance. **[Vertex AI Batch Prediction Documentation →](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini)** **[Quotas & Limits →](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/quotas#batch-prediction-quotas)** ## Real-World Example: Processing Shakespeare This example demonstrates how to process a large text (the first ~20 pages of *Romeo and Juliet*) using the Batch API. We use a small chunk size (`max_char_buffer=500`) to generate enough chunks to trigger batch processing. ```python import requests import textwrap import langextract as lx import logging # Configure logging to see progress (both in console and file) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("batch_process.log"), logging.StreamHandler() ] ) # 1. 
Download Text (Shakespeare's Romeo and Juliet)
url = "https://www.gutenberg.org/files/1513/1513-0.txt"
print(f"Downloading {url}...")
text = requests.get(url).text

# Process first ~20 pages (approx. 60k characters).
text_subset = text[:60000]
print(f"Processing first {len(text_subset)} characters...")

# 2. Define Prompt & Examples
prompt = textwrap.dedent("""\
    Extract characters and emotions from the text.
    Use exact text from the input for extraction_text.""")

examples = [
    lx.data.ExampleData(
        text="ROMEO. But soft! What light through yonder window breaks?",
        extractions=[
            lx.data.Extraction(extraction_class="character", extraction_text="ROMEO"),
            lx.data.Extraction(extraction_class="emotion", extraction_text="But soft!"),
        ]
    )
]

# 3. Configure Batch Settings
batch_config = {
    "enabled": True,
    "threshold": 10,
    "poll_interval": 30,
    "timeout": 3600,
    # Set to True to cache results in GCS. Add timestamp to prompt to force re-run.
    "enable_caching": True,
    # Retention policy for GCS bucket (days). None for permanent.
    "retention_days": 30,
}

# 4. Run Extraction
# langextract will automatically chunk the text and submit a batch job.
results = lx.extract(
    text_or_documents=text_subset,
    prompt_description=prompt,
    examples=examples,
    model_id="gemini-2.5-flash",
    max_char_buffer=500,
    batch_length=1000,
    language_model_params={
        "vertexai": True,
        "project": "your-gcp-project",  # TODO: Replace with your Project ID.
        "location": "us-central1",
        "batch": batch_config
    }
)
```

## GCS File Structure

The library automatically creates and manages a GCS bucket for you, named:
`langextract-{project}-{location}-batch`

Inside this bucket, data is organized as follows:
- **Input**: `batch-input/{job_name}.jsonl`
- **Output**: `batch-input/{job_name}/dest/prediction-model-{timestamp}/predictions.jsonl`
- **Cache**: `cache/{hash}.json` (Individual cached results)

## Cost Optimization & Caching

LangExtract's batch processing is designed to minimize costs:

1. **Cost Efficiency**: Vertex AI Batch predictions are typically ~50% cheaper than online predictions.
2. **Smart Caching**:
   - Results are cached in your GCS bucket (`cache/` directory).
   - **Instant Retrieval**: Re-running identical prompts fetches results directly from storage, bypassing model inference.
   - **Reduced Inference**: You avoid paying for redundant model calls on previously processed data.
   - **Lifecycle Management**: Use `retention_days` (e.g., 30) to automatically clean up old data and manage storage usage.

## Analyze Results

```python
print(f"Extracted {len(results.extractions)} entities.")
print("First 5 extractions:")
for extraction in results.extractions[:5]:
    print(f"- {extraction.extraction_class}: {extraction.extraction_text}")
```

## Sample Output

```text
Extracted 767 entities.
First 5 extractions:
- character: ESCALUS
- character: MERCUTIO
- character: PARIS
- character: Page to Paris
- character: MONTAGUE
```

> **Note on `batch_length`**: The `batch_length` parameter controls how many chunks are submitted in a single batch job. For optimal performance with the Batch API, set this to a high value (e.g., `1000`) to process all chunks in a single job rather than multiple sequential jobs.

## Key Features

### 1. Automatic Routing
`langextract` automatically switches between real-time and batch APIs based on your `threshold`.
- **< Threshold**: Uses real-time API for immediate results.
- **>= Threshold**: Uses Batch API for cost savings.

### 2. Fault Tolerance & Caching
Built-in GCS caching (`enable_caching=True`) allows you to resume interrupted jobs without re-processing completed items, saving time and cost.

### 3. Automated Storage
`langextract` handles all GCS operations automatically using a dedicated bucket (`gs://langextract-{project}-{location}-batch`). Note that input/output files are retained for debugging (subject to any `retention_days` lifecycle policy you configure).
## Tracking Job Status To monitor progress, you can watch the log file from a separate terminal: ```bash tail -f batch_process.log ``` When running a batch job, `langextract` provides clear log feedback with a direct link to the Google Cloud Console: ```text INFO - Batch job created successfully: projects/123456789/locations/us-central1/batchPredictionJobs/987654321 INFO - Job State: JobState.JOB_STATE_PENDING INFO - Job Console URL: https://console.cloud.google.com/vertex-ai/jobs/batch-predictions/987654321?project=123456789 INFO - Batch job is running... (State: JOB_STATE_PENDING) INFO - Batch job is running... (State: JOB_STATE_RUNNING) ``` - **Completion**: Once the job succeeds, `langextract` automatically downloads, parses, and aligns the results. ================================================ FILE: docs/examples/japanese_extraction.md ================================================ # Japanese Information Extraction This example demonstrates how to use LangExtract to extract structured information from Japanese text. > **Note:** For non-spaced languages like Japanese, use `UnicodeTokenizer` to ensure correct character-based segmentation and alignment. ## Full Pipeline Example ```python import langextract as lx from langextract.core import tokenizer # Japanese text with entities (Person, Location, Organization) # "Mr. Tanaka from Tokyo works at Google." input_text = "東京出身の田中さんはGoogleで働いています。" # Define extraction prompt prompt_description = "Extract named entities including Person, Location, and Organization." # Define example data (few-shot examples help the model understand the task) examples = [ lx.data.ExampleData( text="大阪の山田さんはソニーに入社しました。", # Mr. Yamada from Osaka joined Sony. extractions=[ lx.data.Extraction(extraction_class="Location", extraction_text="大阪"), lx.data.Extraction(extraction_class="Person", extraction_text="山田"), lx.data.Extraction(extraction_class="Organization", extraction_text="ソニー"), ] ) ] # 1. 
Initialize the UnicodeTokenizer # Essential for Japanese to ensure correct grapheme segmentation. unicode_tokenizer = tokenizer.UnicodeTokenizer() # 2. Run Extraction with the Custom Tokenizer result = lx.extract( text_or_documents=input_text, prompt_description=prompt_description, examples=examples, model_id="gemini-2.5-flash", tokenizer=unicode_tokenizer, # <--- Pass the tokenizer here api_key="your-api-key-here" # Optional if env var is set ) # 3. Display Results print(f"Input: {input_text}\n") print("Extracted Entities:") for entity in result.extractions: position_info = "" if entity.char_interval: start, end = entity.char_interval.start_pos, entity.char_interval.end_pos position_info = f" (pos: {start}-{end})" print(f"• {entity.extraction_class}: {entity.extraction_text}{position_info}") # Expected Output: # Input: 東京出身の田中さんはGoogleで働いています。 # # Extracted Entities: # • Location: 東京 (pos: 0-2) # • Person: 田中 (pos: 5-7) # • Organization: Google (pos: 10-16) ``` ================================================ FILE: docs/examples/longer_text_example.md ================================================ # *Romeo and Juliet* Full Text Extraction LangExtract can process entire documents directly from URLs, handling large texts with high accuracy through parallel processing and enhanced sensitivity features. This example demonstrates extraction from the complete text of *Romeo and Juliet* from Project Gutenberg. ## Example code The following code uses a comprehensive prompt and examples optimized for large, complex literary texts. For large complex inputs, using more detailed examples is suggested to increase extraction robustness. > **Warning:** Running this example processes a large document (~44 000 tokens) and will incur costs. For large-scale use, a Tier 2 Gemini quota is suggested to avoid rate-limit issues ([details](https://ai.google.dev/gemini-api/docs/rate-limits#tier-2)). 
Please review the [Gemini API pricing](https://ai.google.dev/gemini-api/docs/pricing) before proceeding. ```python import langextract as lx import textwrap from collections import Counter, defaultdict # Define comprehensive prompt and examples for complex literary text prompt = textwrap.dedent("""\ Extract characters, emotions, and relationships from the given text. Provide meaningful attributes for every entity to add context and depth. Important: Use exact text from the input for extraction_text. Do not paraphrase. Extract entities in order of appearance with no overlapping text spans. Note: In play scripts, speaker names appear in ALL-CAPS followed by a period.""") examples = [ lx.data.ExampleData( text=textwrap.dedent("""\ ROMEO. But soft! What light through yonder window breaks? It is the east, and Juliet is the sun. JULIET. O Romeo, Romeo! Wherefore art thou Romeo?"""), extractions=[ lx.data.Extraction( extraction_class="character", extraction_text="ROMEO", attributes={"emotional_state": "wonder"} ), lx.data.Extraction( extraction_class="emotion", extraction_text="But soft!", attributes={"feeling": "gentle awe", "character": "Romeo"} ), lx.data.Extraction( extraction_class="relationship", extraction_text="Juliet is the sun", attributes={"type": "metaphor", "character_1": "Romeo", "character_2": "Juliet"} ), lx.data.Extraction( extraction_class="character", extraction_text="JULIET", attributes={"emotional_state": "yearning"} ), lx.data.Extraction( extraction_class="emotion", extraction_text="Wherefore art thou Romeo?", attributes={"feeling": "longing question", "character": "Juliet"} ), ] ) ] # Process Romeo & Juliet directly from Project Gutenberg print("Downloading and processing Romeo and Juliet from Project Gutenberg...") result = lx.extract( text_or_documents="https://www.gutenberg.org/files/1513/1513-0.txt", prompt_description=prompt, examples=examples, model_id="gemini-2.5-flash", extraction_passes=3, # Multiple passes for improved recall 
max_workers=20, # Parallel processing for speed max_char_buffer=1000 # Smaller contexts for better accuracy ) print(f"Extracted {len(result.extractions)} entities from {len(result.text):,} characters") # Save and visualize the results lx.io.save_annotated_documents([result], output_name="romeo_juliet_extractions.jsonl", output_dir=".") # Generate the interactive visualization html_content = lx.visualize("romeo_juliet_extractions.jsonl") with open("romeo_juliet_visualization.html", "w") as f: if hasattr(html_content, 'data'): f.write(html_content.data) # For Jupyter/Colab else: f.write(html_content) print("Interactive visualization saved to romeo_juliet_visualization.html") ``` This creates an interactive HTML visualization for exploring the extracted entities: ![Romeo and Juliet Full Visualization](../_static/romeo_juliet_full.gif) ```python # Analyze character mentions characters = {} for e in result.extractions: if e.extraction_class == "character": char_name = e.extraction_text if char_name not in characters: characters[char_name] = {"count": 0, "attributes": set()} characters[char_name]["count"] += 1 if e.attributes: for attr_key, attr_val in e.attributes.items(): characters[char_name]["attributes"].add(f"{attr_key}: {attr_val}") # Print character summary print(f"\nCHARACTER SUMMARY ({len(characters)} unique characters)") print("=" * 60) sorted_chars = sorted(characters.items(), key=lambda x: x[1]["count"], reverse=True) for char_name, char_data in sorted_chars[:10]: # Top 10 characters attrs_preview = list(char_data["attributes"])[:3] attrs_str = f" ({', '.join(attrs_preview)})" if attrs_preview else "" print(f"{char_name}: {char_data['count']} mentions{attrs_str}") # Entity type breakdown entity_counts = Counter(e.extraction_class for e in result.extractions) print(f"\nENTITY TYPE BREAKDOWN") print("=" * 60) for entity_type, count in entity_counts.most_common(): percentage = (count / len(result.extractions)) * 100 print(f"{entity_type}: {count} 
({percentage:.1f}%)") ``` ## Sample output ``` Downloading and processing Romeo and Juliet from Project Gutenberg... Downloaded 147,843 characters (25,976 words) from 1513-0.txt Extracted 4,088 entities from 147,843 characters Interactive visualization saved to romeo_juliet_visualization.html CHARACTER SUMMARY (153 unique characters) ============================================================ ROMEO: 287 mentions (emotional_state: excitement, emotional_state: eager to please) JULIET: 204 mentions (emotional_state: fond, emotional_state: resilient) NURSE: 168 mentions (emotional_state: reporting, emotional_state: teasing and evasive) MERCUTIO: 107 mentions (emotional_state: approving, emotional_state: responsive) BENVOLIO: 82 mentions (emotional_state: cautious, emotional_state: teasing) ENTITY TYPE BREAKDOWN ============================================================ character: 1,685 (41.2%) emotion: 1,524 (37.3%) relationship: 879 (21.5%) ``` ## Key benefits for long documents ### Sequential extraction passes Multiple extraction passes improve recall by performing independent extractions and merging non-overlapping results. Each pass uses identical parameters and processing—they are independent runs of the same extraction task. The number of passes is controlled by the `extraction_passes` parameter (e.g., `extraction_passes=3`). **How it works**: Each pass processes the full text independently using the same prompt and examples. Results are then merged using a "first-pass wins" strategy for overlapping entities, while adding unique non-overlapping entities from later passes. This approach captures entities that might be missed in any single run due to the stochastic nature of language model generation. ### Portable and Interoperable Data with JSONL LangExtract uses JSONL, a human-readable format ideal for language model data. Each line is a self-contained JSON object, making outputs easy to parse, share, and integrate with other tools. 
You can save results with `lx.io.save_annotated_documents` and reload them for later analysis, ensuring your data is both portable and persistent. ### Optimal long context management While single-inference approaches can be powerful, their accuracy may be affected by distant context. LangExtract uses smart chunking strategies that respect text delimiters (such as paragraph breaks) to keep context intact and well-formed for the model. Users can configure context sizes (`max_char_buffer`) combined with parallel processing (`max_workers`) to maintain extraction quality across large documents. Multiple sequential extraction passes further enhance sensitivity by capturing entities that might be missed in any single run due to the stochastic nature of language model generation. ### Enhanced accuracy through chunking The chunked processing approach can improve extraction quality over a single inference pass on a large document because each chunk uses a smaller, more manageable context size. This helps the model focus on the most relevant information and prevents interference from distant context. While the overall latency and time required remain similar due to parallelization, the extraction quality can be substantially higher with better entity coverage and more accurate attribute assignment across the entire document.¹ ### Interactive visualization at scale Seamlessly explore hundreds or thousands of entities through interactive HTML visualizations generated directly from JSONL output files. The generated visualizations handle large result sets efficiently, providing navigation and detailed entity inspection capabilities for comprehensive analysis of complex documents. ### Schema-guided knowledge extraction LangExtract combines precise text positioning with world knowledge enrichment, enabling extraction of information not explicitly stated in the text (like character identities and traits). 
Under the hood, the library implements [Controlled Generation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/control-generated-output) with supported models to ensure extracted data adheres to your specified schema while maintaining robust extractions across large inputs. --- ¹ Models like Gemini 1.5 Pro show strong performance on many benchmarks, but [needle-in-a-haystack tests](https://cloud.google.com/blog/products/ai-machine-learning/the-needle-in-the-haystack-test-and-how-gemini-pro-solves-it) across million-token contexts indicate that performance can vary in multi-fact retrieval scenarios. This motivates LangExtract's smaller-context-window approach, which aims to maintain consistently high quality across entire documents by avoiding the complexity and potential degradation of massive single-context processing. ================================================ FILE: docs/examples/medication_examples.md ================================================ # Medication Extraction Examples LangExtract excels at extracting structured medical information from clinical text, making it particularly useful for healthcare applications. The methodology originated from research in medical information extraction, where early versions of the techniques were demonstrated to accelerate annotation tasks significantly. > **Disclaimer:** This demonstration is only for illustrative purposes of LangExtract's baseline capability. It does not represent a finished or approved product, is not intended to diagnose or suggest treatment of any disease or condition, and should not be used for medical advice. --- **Medical Information Extraction Research:** The concepts and methods underlying LangExtract were first demonstrated in: Goel, A., Lehman, E., Gulati, A., Chen, R., Nori, H., Hager, G. D., & Durr, N. J. (2023). "LLMs Accelerate Annotation for Medical Information Extraction." *Machine Learning for Health (ML4H), PMLR, 2023*.
[arXiv:2312.02296](https://arxiv.org/abs/2312.02296) --- ## Basic Named Entity Recognition (NER) In this basic medical example, LangExtract extracts structured medication information: ```python import langextract as lx # Text with a medication mention input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." # Define extraction prompt prompt_description = "Extract medication information including medication name, dosage, route, frequency, and duration in the order they appear in the text." # Define example data with entities in order of appearance examples = [ lx.data.ExampleData( text="Patient was given 250 mg IV Cefazolin TID for one week.", extractions=[ lx.data.Extraction(extraction_class="dosage", extraction_text="250 mg"), lx.data.Extraction(extraction_class="route", extraction_text="IV"), lx.data.Extraction(extraction_class="medication", extraction_text="Cefazolin"), lx.data.Extraction(extraction_class="frequency", extraction_text="TID"), # TID = three times a day lx.data.Extraction(extraction_class="duration", extraction_text="for one week") ] ) ] result = lx.extract( text_or_documents=input_text, prompt_description=prompt_description, examples=examples, model_id="gemini-2.5-pro", api_key="your-api-key-here" # Optional if LANGEXTRACT_API_KEY environment variable is set ) # Display entities with positions print(f"Input: {input_text}\n") print("Extracted entities:") for entity in result.extractions: position_info = "" if entity.char_interval: start, end = entity.char_interval.start_pos, entity.char_interval.end_pos position_info = f" (pos: {start}-{end})" print(f"• {entity.extraction_class.capitalize()}: {entity.extraction_text}{position_info}") # Save and visualize the results lx.io.save_annotated_documents([result], output_name="medical_ner_extraction.jsonl", output_dir=".") # Generate the interactive visualization html_content = lx.visualize("medical_ner_extraction.jsonl") with open("medical_ner_visualization.html", "w") as f: if 
hasattr(html_content, 'data'): f.write(html_content.data) # For Jupyter/Colab else: f.write(html_content) print("Interactive visualization saved to medical_ner_visualization.html") ``` ![Medical NER Visualization](../_static/medication_entity.gif) This will produce an output similar to: ``` Input: Patient took 400 mg PO Ibuprofen q4h for two days. Extracted entities: • Dosage: 400 mg (pos: 13-19) • Route: PO (pos: 20-22) • Medication: Ibuprofen (pos: 23-32) • Frequency: q4h (pos: 33-36) • Duration: for two days (pos: 37-49) Interactive visualization saved to medical_ner_visualization.html ``` The interactive HTML visualization allows you to explore the extracted entities visually, with each entity type color-coded and clickable for detailed inspection. ## Relationship Extraction (RE) For more complex extractions that involve relationships between entities, LangExtract can also extract structured relationships. This example shows how to extract medications and their associated attributes: ```python import langextract as lx # Text with interleaved medication mentions input_text = """ The patient was prescribed Lisinopril and Metformin last month. He takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes. """ # Define extraction prompt prompt_description = """ Extract medications with their details, using attributes to group related information: 1. Extract entities in the order they appear in the text 2. Each entity must have a 'medication_group' attribute linking it to its medication 3. 
All details about a medication should share the same medication_group value """ # Define example data with medication groups examples = [ lx.data.ExampleData( text="Patient takes Aspirin 100mg daily for heart health and Simvastatin 20mg at bedtime.", extractions=[ # First medication group lx.data.Extraction( extraction_class="medication", extraction_text="Aspirin", attributes={"medication_group": "Aspirin"} # Group identifier ), lx.data.Extraction( extraction_class="dosage", extraction_text="100mg", attributes={"medication_group": "Aspirin"} ), lx.data.Extraction( extraction_class="frequency", extraction_text="daily", attributes={"medication_group": "Aspirin"} ), lx.data.Extraction( extraction_class="condition", extraction_text="heart health", attributes={"medication_group": "Aspirin"} ), # Second medication group lx.data.Extraction( extraction_class="medication", extraction_text="Simvastatin", attributes={"medication_group": "Simvastatin"} ), lx.data.Extraction( extraction_class="dosage", extraction_text="20mg", attributes={"medication_group": "Simvastatin"} ), lx.data.Extraction( extraction_class="frequency", extraction_text="at bedtime", attributes={"medication_group": "Simvastatin"} ) ] ) ] result = lx.extract( text_or_documents=input_text, prompt_description=prompt_description, examples=examples, model_id="gemini-2.5-pro", api_key="your-api-key-here" # Optional if LANGEXTRACT_API_KEY environment variable is set ) # Display grouped medications print(f"Input text: {input_text.strip()}\n") print("Extracted Medications:") # Group by medication medication_groups = {} for extraction in result.extractions: if not extraction.attributes or "medication_group" not in extraction.attributes: print(f"Warning: Missing medication_group for {extraction.extraction_text}") continue group_name = extraction.attributes["medication_group"] medication_groups.setdefault(group_name, []).append(extraction) # Print each medication group for med_name, extractions in 
medication_groups.items(): print(f"\n* {med_name}") for extraction in extractions: position_info = "" if extraction.char_interval: start, end = extraction.char_interval.start_pos, extraction.char_interval.end_pos position_info = f" (pos: {start}-{end})" print(f" • {extraction.extraction_class.capitalize()}: {extraction.extraction_text}{position_info}") # Save and visualize the results lx.io.save_annotated_documents( [result], output_name="medical_relationship_extraction.jsonl", output_dir="." ) # Generate the interactive visualization html_content = lx.visualize("medical_relationship_extraction.jsonl") with open("medical_relationship_visualization.html", "w") as f: if hasattr(html_content, 'data'): f.write(html_content.data) # For Jupyter/Colab else: f.write(html_content) print("Interactive visualization saved to medical_relationship_visualization.html") ``` ![Medical Relationship Visualization](../_static/medication_entity_re.gif) This will produce output similar to: ``` Input text: The patient was prescribed Lisinopril and Metformin last month. He takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes. Extracted Medications: * Lisinopril • Medication: Lisinopril (pos: 28-38) • Dosage: 10mg (pos: 89-93) • Frequency: daily (pos: 94-99) • Condition: hypertension (pos: 104-116) * Metformin • Medication: Metformin (pos: 43-52) • Dosage: 500mg (pos: 149-154) • Frequency: twice daily (pos: 182-193) • Condition: diabetes (pos: 198-206) Interactive visualization saved to medical_relationship_visualization.html ``` The visualization highlights how the `medication_group` attributes connect related entities, making it easy to see which dosages, frequencies, and conditions belong to each medication. Each medication group is visually distinguished in the interactive display. 
**Understanding Relationship Extraction:** This example demonstrates how attributes enable efficient relationship extraction. Using the `medication_group` attribute as a linking key, related entities are grouped together logically. This approach simplifies extracting connected information and eliminates the need for additional processing steps, while preserving the precise alignment between extracted text and its original location in the document. The interactive visualization makes these relationships immediately apparent, with connected entities sharing visual groupings and color coding. ## Key Features Demonstrated - **Named Entity Recognition**: Extracts entities with their types (medication, dosage, route, etc.) - **Relationship Extraction**: Groups related entities using attributes - **Position Tracking**: Records exact positions of extracted entities in the source text - **Structured Output**: Organizes information in a format suitable for healthcare applications - **Interactive Visualization**: Generates HTML visualizations for exploring complex medical extractions with entity groupings and relationships clearly displayed ================================================ FILE: examples/custom_provider_plugin/README.md ================================================ # Custom Provider Plugin Example This example demonstrates how to create a custom provider plugin that extends LangExtract with your own model backend. **Note**: This is an example included in the LangExtract repository for reference. It is not part of the LangExtract package and won't be installed when you `pip install langextract`. **Automated Creation**: Instead of manually copying this example, use the [provider plugin generator script](../../scripts/create_provider_plugin.py): ```bash python scripts/create_provider_plugin.py MyProvider --with-schema ``` This will create a complete plugin structure with all boilerplate code ready for customization. 
## Structure ``` custom_provider_plugin/ ├── pyproject.toml # Package configuration and metadata ├── README.md # This file ├── langextract_provider_example/ # Package directory │ ├── __init__.py # Package initialization │ ├── provider.py # Custom provider implementation │ └── schema.py # Custom schema implementation (optional) └── test_example_provider.py # Test script ``` ## Key Components ### Provider Implementation (`provider.py`) ```python @lx.providers.registry.register( r'^gemini', # Pattern for model IDs this provider handles ) class CustomGeminiProvider(lx.inference.BaseLanguageModel): def __init__(self, model_id: str, **kwargs): # Initialize your backend client def infer(self, batch_prompts, **kwargs): # Call your backend API and return results ``` ### Package Configuration (`pyproject.toml`) ```toml [project.entry-points."langextract.providers"] custom_gemini = "langextract_provider_example:CustomGeminiProvider" ``` This entry point allows LangExtract to automatically discover your provider. 
### Custom Schema Support (`schema.py`) Providers can optionally implement custom schemas for structured output: **Flow:** Examples → `from_examples()` → `to_provider_config()` → Provider kwargs → Inference ```python class CustomProviderSchema(lx.schema.BaseSchema): @classmethod def from_examples(cls, examples_data, attribute_suffix="_attributes"): # Analyze examples to find patterns # Build schema based on extraction classes and attributes seen return cls(schema_dict) def to_provider_config(self): # Convert schema to provider kwargs return { "response_schema": self._schema_dict, "enable_structured_output": True } @property def supports_strict_mode(self): # True = valid JSON output, no markdown fences needed return True ``` Then in your provider: ```python class CustomProvider(lx.inference.BaseLanguageModel): @classmethod def get_schema_class(cls): return CustomProviderSchema # Tell LangExtract about your schema def __init__(self, **kwargs): # Receive schema config in kwargs when use_schema_constraints=True self.response_schema = kwargs.get('response_schema') def infer(self, batch_prompts, **kwargs): # Use schema during API calls if self.response_schema: config['response_schema'] = self.response_schema ``` ## Installation ```bash # Navigate to this example directory first cd examples/custom_provider_plugin # Install in development mode pip install -e . # Test the provider (must be run from this directory) python test_example_provider.py ``` ## Usage Since this example registers the same pattern as the default Gemini provider, you must explicitly specify it: ```python import langextract as lx # Create a configured model with explicit provider selection config = lx.factory.ModelConfig( model_id="gemini-2.5-flash", provider="CustomGeminiProvider", provider_kwargs={"api_key": "your-api-key"} ) model = lx.factory.create_model(config) # Note: Passing model directly to extract() is coming soon. 
# For now, use the model's infer() method directly or pass parameters individually: result = lx.extract( text_or_documents="Your text here", model_id="gemini-2.5-flash", api_key="your-api-key", prompt_description="Extract key information", examples=[...] ) # Coming soon: Direct model passing # result = lx.extract( # text_or_documents="Your text here", # model=model, # Planned feature # prompt_description="Extract key information" # ) ``` ## Creating Your Own Provider - Step by Step ### 1. Copy and Rename ```bash # Copy this example directory cp -r examples/custom_provider_plugin/ ~/langextract-myprovider/ # Rename the package directory cd ~/langextract-myprovider/ mv langextract_provider_example langextract_myprovider ``` ### 2. Update Package Configuration Edit `pyproject.toml`: - Change `name = "langextract-myprovider"` - Update description and author information - Change entry point: `myprovider = "langextract_myprovider:MyProvider"` ### 3. Modify Provider Implementation Edit `provider.py`: - Change class name from `CustomGeminiProvider` to `MyProvider` - Update `@register()` patterns to match your model IDs - Replace Gemini API calls with your backend - Add any provider-specific parameters - Update every import that still references `langextract_provider_example` (in both `provider.py` and `__init__.py`) to use `langextract_myprovider`, and make `__init__.py` import and export `MyProvider` instead of `CustomGeminiProvider` ### 4. Add Schema Support (Optional) Edit `schema.py`: - Rename the `CustomProviderSchema` class to `MyProviderSchema`, and update `get_schema_class()` in `provider.py` to return it - Customize `from_examples()` for your extraction format - Update `to_provider_config()` for your API requirements - Set `supports_strict_mode` based on your capabilities ### 5. Install and Test ```bash # Install in development mode pip install -e . 
# Test your provider python -c " import langextract as lx lx.providers.load_plugins_once() print('Provider registered:', any('myprovider' in str(e) for e in lx.providers.registry.list_entries())) " ``` ### 6. Write Tests - Test that your provider loads and handles basic inference - Verify schema support works (if implemented) - Test error handling for your specific API ### 7.
Publish to PyPI and Share with Community ```bash # Build package python -m build # Upload to PyPI twine upload dist/* ``` **Share with the community:** - Submit a PR to add your provider to the [Community Providers Registry](../../COMMUNITY_PROVIDERS.md) - Open an issue on [LangExtract GitHub](https://github.com/google/langextract/issues) to announce your provider and get feedback ## Common Pitfalls to Avoid 1. **Forgetting to trigger plugin loading** - Plugins load lazily, use `load_plugins_once()` in tests 2. **Pattern conflicts** - Avoid patterns that conflict with built-in providers 3. **Missing dependencies** - List all requirements in `pyproject.toml` 4. **Schema mismatches** - Test schema generation with real examples 5. **Not handling None schema** - Provider must clear schema when `apply_schema(None)` is called (see provider.py for implementation) ## License Apache License 2.0 ================================================ FILE: examples/custom_provider_plugin/langextract_provider_example/__init__.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Example custom provider plugin for LangExtract.""" from langextract_provider_example.provider import CustomGeminiProvider __all__ = ["CustomGeminiProvider"] __version__ = "0.1.0" ================================================ FILE: examples/custom_provider_plugin/langextract_provider_example/provider.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Minimal example of a custom provider plugin for LangExtract.""" from __future__ import annotations import dataclasses from typing import Any, Iterator, Sequence from langextract_provider_example import schema as custom_schema import langextract as lx @lx.providers.registry.register( r'^gemini', # Matches Gemini model IDs (same as default provider) ) @dataclasses.dataclass(init=False) class CustomGeminiProvider(lx.inference.BaseLanguageModel): """Example custom LangExtract provider implementation. This demonstrates how to create a custom provider for LangExtract that can intercept and handle model requests. This example wraps the actual Gemini API to show how custom schemas integrate, but you would replace the Gemini calls with your own API or model implementation. 
Note: Since this registers the same pattern as the default Gemini provider, you must explicitly specify this provider when creating a model: config = lx.factory.ModelConfig( model_id="gemini-2.5-flash", provider="CustomGeminiProvider" ) model = lx.factory.create_model(config) """ model_id: str api_key: str | None temperature: float response_schema: dict[str, Any] | None = None enable_structured_output: bool = False _client: Any = dataclasses.field(repr=False, compare=False) def __init__( self, model_id: str = 'gemini-2.5-flash', api_key: str | None = None, temperature: float = 0.0, **kwargs: Any, ) -> None: """Initialize the custom provider. Args: model_id: The model ID. api_key: API key for the service. temperature: Sampling temperature. **kwargs: Additional parameters. """ super().__init__() # TODO: Replace with your own client initialization try: from google import genai # pylint: disable=import-outside-toplevel except ImportError as e: raise lx.exceptions.InferenceConfigError( 'This example requires google-genai package. ' 'Install with: pip install google-genai' ) from e self.model_id = model_id self.api_key = api_key self.temperature = temperature # Schema kwargs from CustomProviderSchema.to_provider_config() self.response_schema = kwargs.get('response_schema') self.enable_structured_output = kwargs.get( 'enable_structured_output', False ) # Store any additional kwargs for potential use self._extra_kwargs = kwargs if not self.api_key: raise lx.exceptions.InferenceConfigError( 'API key required. Set GEMINI_API_KEY or pass api_key parameter.' ) self._client = genai.Client(api_key=self.api_key) @classmethod def get_schema_class(cls) -> type[lx.schema.BaseSchema] | None: """Return our custom schema class. This allows LangExtract to use our custom schema implementation when use_schema_constraints=True is specified. Returns: Our custom schema class that will be used to generate constraints. 
""" return custom_schema.CustomProviderSchema def apply_schema(self, schema_instance: lx.schema.BaseSchema | None) -> None: """Apply or clear schema configuration. This method is called by LangExtract to dynamically apply schema constraints after the provider is instantiated. It's important to handle both the application of a new schema and clearing (None). Args: schema_instance: The schema to apply, or None to clear existing schema. """ super().apply_schema(schema_instance) if schema_instance: # Apply the new schema configuration config = schema_instance.to_provider_config() self.response_schema = config.get('response_schema') self.enable_structured_output = config.get( 'enable_structured_output', False ) else: # Clear the schema configuration self.response_schema = None self.enable_structured_output = False def infer( self, batch_prompts: Sequence[str], **kwargs: Any ) -> Iterator[Sequence[lx.inference.ScoredOutput]]: """Run inference on a batch of prompts. Args: batch_prompts: Input prompts to process. **kwargs: Additional generation parameters. Yields: Lists of ScoredOutputs, one per prompt. 
""" config = { 'temperature': kwargs.get('temperature', self.temperature), } # Add other parameters if provided for key in ['max_output_tokens', 'top_p', 'top_k']: if key in kwargs: config[key] = kwargs[key] # Apply schema constraints if configured if self.response_schema and self.enable_structured_output: # For Gemini, this ensures the model outputs JSON matching our schema # Adapt this section based on your actual provider's API requirements config['response_schema'] = self.response_schema config['response_mime_type'] = 'application/json' for prompt in batch_prompts: try: # TODO: Replace this with your own API/model calls response = self._client.models.generate_content( model=self.model_id, contents=prompt, config=config ) output = response.text.strip() yield [lx.inference.ScoredOutput(score=1.0, output=output)] except Exception as e: raise lx.exceptions.InferenceRuntimeError( f'API error: {str(e)}', original=e ) from e ================================================ FILE: examples/custom_provider_plugin/langextract_provider_example/schema.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Example custom schema implementation for provider plugins.""" from __future__ import annotations from typing import Any, Sequence import langextract as lx class CustomProviderSchema(lx.schema.BaseSchema): """Example custom schema implementation for a provider plugin. 
This demonstrates how plugins can provide their own schema implementations that integrate with LangExtract's schema system. Custom schemas allow providers to: 1. Generate provider-specific constraints from examples 2. Control output formatting and validation 3. Optimize for their specific model capabilities This example generates a JSON schema from the examples and passes it to the Gemini backend (which this example provider wraps) for structured output. """ def __init__(self, schema_dict: dict[str, Any], strict_mode: bool = True): """Initialize the custom schema. Args: schema_dict: The generated JSON schema dictionary. strict_mode: Whether the provider guarantees valid output. """ self._schema_dict = schema_dict self._strict_mode = strict_mode @classmethod def from_examples( cls, examples_data: Sequence[lx.data.ExampleData], attribute_suffix: str = "_attributes", ) -> CustomProviderSchema: """Generate schema from example data. This method analyzes the provided examples to build a schema that captures the structure of expected extractions. Called automatically by LangExtract when use_schema_constraints=True. Args: examples_data: Example extractions to learn from. attribute_suffix: Suffix for attribute fields (unused in this example). Returns: A configured CustomProviderSchema instance. Example: If examples contain extractions with class "condition" and attribute "severity", the schema will constrain the model to only output those specific classes and attributes. 
""" extraction_classes = set() attribute_keys = set() for example in examples_data: for extraction in example.extractions: extraction_classes.add(extraction.extraction_class) if extraction.attributes: attribute_keys.update(extraction.attributes.keys()) schema_dict = { "type": "object", "properties": { "extractions": { "type": "array", "items": { "type": "object", "properties": { "extraction_class": { "type": "string", "enum": ( list(extraction_classes) if extraction_classes else None ), }, "extraction_text": {"type": "string"}, "attributes": { "type": "object", "properties": { key: {"type": "string"} for key in attribute_keys }, }, }, "required": ["extraction_class", "extraction_text"], }, }, }, "required": ["extractions"], } # Remove enum if no classes found if not extraction_classes: del schema_dict["properties"]["extractions"]["items"]["properties"][ "extraction_class" ]["enum"] return cls(schema_dict, strict_mode=True) def to_provider_config(self) -> dict[str, Any]: """Convert schema to provider-specific configuration. This is called after from_examples() and returns kwargs that will be passed to the provider's __init__ method. The provider can then use these during inference. Returns: Dictionary of provider kwargs that will be passed to the model. In this example, we return both the schema and a flag to enable structured output mode. Note: These kwargs are merged with user-provided kwargs, with user values taking precedence (caller-wins merge semantics). """ return { "response_schema": self._schema_dict, "enable_structured_output": True, "output_format": "json", } @property def supports_strict_mode(self) -> bool: """Whether this schema guarantees valid structured output. Returns: True if the provider will emit valid JSON without needing Markdown fences for extraction. """ return self._strict_mode @property def schema_dict(self) -> dict[str, Any]: """Access the underlying schema dictionary. Returns: The JSON schema dictionary. 
""" return self._schema_dict ================================================ FILE: examples/custom_provider_plugin/pyproject.toml ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "langextract-provider-example" # Change to your package name version = "0.1.0" # Update version for releases description = "Example custom provider plugin for LangExtract" readme = "README.md" requires-python = ">=3.10" license = {text = "Apache-2.0"} dependencies = [ # Uncomment when creating a standalone plugin package: # "langextract", # Will install latest version "google-genai>=0.2.0", # Replace with your backend's SDK ] # Register the provider with LangExtract's plugin system [project.entry-points."langextract.providers"] custom_gemini = "langextract_provider_example:CustomGeminiProvider" [tool.setuptools.packages.find] where = ["."] include = ["langextract_provider_example*"] ================================================ FILE: examples/custom_provider_plugin/test_example_provider.py ================================================ #!/usr/bin/env python3 # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Simple test for the custom provider plugin.""" import os import dotenv # Import the provider to trigger registration with LangExtract # Note: This manual import is only needed when running without installation. # After `pip install -e .`, the entry point system handles this automatically. from langextract_provider_example import CustomGeminiProvider # noqa: F401 import langextract as lx def main(): """Test the custom provider.""" dotenv.load_dotenv(override=True) api_key = os.getenv("GEMINI_API_KEY") or os.getenv("LANGEXTRACT_API_KEY") if not api_key: print("Set GEMINI_API_KEY or LANGEXTRACT_API_KEY to test") return config = lx.factory.ModelConfig( model_id="gemini-2.5-flash", provider="CustomGeminiProvider", provider_kwargs={"api_key": api_key}, ) model = lx.factory.create_model(config) print(f"✓ Created {model.__class__.__name__}") # Test inference prompts = ["Say hello"] results = list(model.infer(prompts)) if results and results[0]: print(f"✓ Inference worked: {results[0][0].output[:50]}...") else: print("✗ No response") if __name__ == "__main__": main() ================================================ FILE: examples/notebooks/romeo_juliet_extraction.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "header" }, "source": [ "# Romeo and Juliet Text Extraction with LangExtract\n", "\n", "This notebook demonstrates extracting characters, emotions, and relationships from Shakespeare's Romeo and Juliet using LangExtract.\n", "\n", "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/langextract/blob/main/examples/notebooks/romeo_juliet_extraction.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "setup_header" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "install" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# Install LangExtract\n", "%pip install -q langextract" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "api_key" }, "outputs": [], "source": [ "# Set up your Gemini API key\n", "# Get your key from: https://aistudio.google.com/app/apikey\n", "import os\n", "from getpass import getpass\n", "\n", "if 'GEMINI_API_KEY' not in os.environ:\n", " os.environ['GEMINI_API_KEY'] = getpass('Enter your Gemini API key: ')" ] }, { "cell_type": "markdown", "metadata": { "id": "define_header" }, "source": [ "## Define Extraction Task" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "setup_extraction" }, "outputs": [], "source": [ "import langextract as lx\n", "import textwrap\n", "\n", "# Define the extraction task\n", "prompt = textwrap.dedent(\"\"\"\\\n", " Extract characters, emotions, and relationships in order of appearance.\n", " Use exact text for extractions. Do not paraphrase or overlap entities.\n", " Provide meaningful attributes for each entity to add context.\"\"\")\n", "\n", "# Provide a high-quality example\n", "examples = [\n", " lx.data.ExampleData(\n", " text=\"ROMEO. But soft! What light through yonder window breaks? 
It is the east, and Juliet is the sun.\",\n", " extractions=[\n", " lx.data.Extraction(\n", " extraction_class=\"character\",\n", " extraction_text=\"ROMEO\",\n", " attributes={\"emotional_state\": \"wonder\"}\n", " ),\n", " lx.data.Extraction(\n", " extraction_class=\"emotion\",\n", " extraction_text=\"But soft!\",\n", " attributes={\"feeling\": \"gentle awe\"}\n", " ),\n", " lx.data.Extraction(\n", " extraction_class=\"relationship\",\n", " extraction_text=\"Juliet is the sun\",\n", " attributes={\"type\": \"metaphor\"}\n", " ),\n", " ]\n", " )\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "extract_header" }, "source": [ "## Extract from Sample Text" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "simple_extraction" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[94m\u001b[1mLangExtract\u001b[0m: model=\u001b[92mgemini-2.5-flash\u001b[0m, current=\u001b[92m68\u001b[0m chars, processed=\u001b[92m68\u001b[0m chars: [00:01]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[92m✓\u001b[0m Extraction processing complete\n", "\u001b[92m✓\u001b[0m Extracted \u001b[1m3\u001b[0m entities (\u001b[1m3\u001b[0m unique types)\n", " \u001b[96m•\u001b[0m Time: \u001b[1m1.96s\u001b[0m\n", " \u001b[96m•\u001b[0m Speed: \u001b[1m35\u001b[0m chars/sec\n", " \u001b[96m•\u001b[0m Chunks: \u001b[1m1\u001b[0m\n", "Extracted 3 entities:\n", "\n", "• character: 'Lady Juliet'\n", " - emotional_state: longing\n", "• emotion: 'gazed longingly at the stars, her heart aching'\n", " - feeling: melancholy longing\n", "• relationship: 'her heart aching for Romeo'\n", " - type: romantic\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Simple extraction from a short text\n", "input_text = \"Lady Juliet gazed longingly at the stars, her heart aching for Romeo\"\n", "\n", "result = lx.extract(\n", " text_or_documents=input_text,\n", " prompt_description=prompt,\n", " 
examples=examples,\n", " model_id=\"gemini-2.5-flash\",\n", ")\n", "\n", "# Display results\n", "print(f\"Extracted {len(result.extractions)} entities:\\n\")\n", "for extraction in result.extractions:\n", " print(f\"• {extraction.extraction_class}: '{extraction.extraction_text}'\")\n", " if extraction.attributes:\n", " for key, value in extraction.attributes.items():\n", " print(f\" - {key}: {value}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "viz_header" }, "source": [ "## Interactive Visualization" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "visualization" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[94m\u001b[1mLangExtract\u001b[0m: Saving to \u001b[92mromeo_juliet.jsonl\u001b[0m: 1 docs [00:00, 995.33 docs/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[92m✓\u001b[0m Saved \u001b[1m1\u001b[0m documents to \u001b[92mromeo_juliet.jsonl\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "\u001b[94m\u001b[1mLangExtract\u001b[0m: Loading \u001b[92mromeo_juliet.jsonl\u001b[0m: 100%|██████████| 961/961 [00:00<00:00, 2.49MB/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[92m✓\u001b[0m Loaded \u001b[1m1\u001b[0m documents from \u001b[92mromeo_juliet.jsonl\u001b[0m\n", "Interactive visualization (hover over highlights to see attributes):\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "\n", "
\n", "
\n", "
Highlights Legend: character emotion relationship
\n", "
\n", "
\n", "
\n", " Lady Juliet gazed longingly at the stars, her heart aching for Romeo\n", "
\n", "
\n", "
\n", " \n", " \n", " \n", "
\n", "
\n", " \n", "
\n", "
\n", " Entity 1/3 |\n", " Pos [0-11]\n", "
\n", "
\n", "
\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save results to JSONL\n", "lx.io.save_annotated_documents([result], output_name=\"romeo_juliet.jsonl\", output_dir=\".\")\n", "\n", "# Generate interactive visualization\n", "html_content = lx.visualize(\"romeo_juliet.jsonl\")\n", "\n", "# Display in notebook\n", "print(\"Interactive visualization (hover over highlights to see attributes):\")\n", "html_content" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "save_viz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✓ Visualization saved to romeo_juliet_visualization.html\n", "You can download this file from the Files panel on the left.\n" ] } ], "source": [ "# Save visualization to file (for downloading)\n", "with open(\"romeo_juliet_visualization.html\", \"w\") as f:\n", " # Handle both Jupyter (HTML object) and non-Jupyter (string) environments\n", " if hasattr(html_content, 'data'):\n", " f.write(html_content.data)\n", " else:\n", " f.write(html_content)\n", "\n", "print(\"✓ Visualization saved to romeo_juliet_visualization.html\")\n", "print(\"You can download this file from the Files panel on the left.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "experiment_header" }, "source": [ "## Try Your Own Text\n", "\n", "Experiment with your own Shakespeare quotes or any literary text!" 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "experiment" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[94m\u001b[1mLangExtract\u001b[0m: model=\u001b[92mgemini-2.5-flash\u001b[0m, current=\u001b[92m163\u001b[0m chars, processed=\u001b[92m163\u001b[0m chars: [00:05]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[92m✓\u001b[0m Extraction processing complete\n", "\u001b[92m✓\u001b[0m Extracted \u001b[1m6\u001b[0m entities (\u001b[1m3\u001b[0m unique types)\n", " \u001b[96m•\u001b[0m Time: \u001b[1m5.84s\u001b[0m\n", " \u001b[96m•\u001b[0m Speed: \u001b[1m28\u001b[0m chars/sec\n", " \u001b[96m•\u001b[0m Chunks: \u001b[1m1\u001b[0m\n", "Extractions from your text:\n", "\n", "• character: 'JULIET'\n", " - emotional_state: longing\n", "• emotion: 'O Romeo, Romeo! wherefore art thou Romeo?'\n", " - feeling: desperate questioning\n", "• relationship: 'thy father'\n", " - type: familial\n", "• relationship: 'thy name'\n", " - type: lineage\n", "• relationship: 'my love'\n", " - type: romantic bond\n", "• relationship: 'Capulet'\n", " - type: family affiliation\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Try your own text\n", "your_text = \"\"\"\n", "JULIET: O Romeo, Romeo! 
wherefore art thou Romeo?\n", "Deny thy father and refuse thy name;\n", "Or, if thou wilt not, be but sworn my love,\n", "And I'll no longer be a Capulet.\n", "\"\"\"\n", "\n", "custom_result = lx.extract(\n", " text_or_documents=your_text,\n", " prompt_description=prompt,\n", " examples=examples,\n", " model_id=\"gemini-2.5-flash\",\n", ")\n", "\n", "print(\"Extractions from your text:\\n\")\n", "for e in custom_result.extractions:\n", " print(f\"• {e.extraction_class}: '{e.extraction_text}'\")\n", " if e.attributes:\n", " for key, value in e.attributes.items():\n", " print(f\" - {key}: {value}\")" ] } ], "metadata": { "colab": { "name": "Romeo and Juliet Text Extraction with LangExtract", "provenance": [] }, "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/ollama/.dockerignore ================================================ # Ignore Python cache __pycache__/ *.pyc *.pyo *.pyd .Python # Ignore version control .git/ .gitignore # Ignore OS files .DS_Store Thumbs.db # Ignore virtual environments venv/ env/ .venv/ # Ignore IDE files .vscode/ .idea/ *.swp *.swo # Ignore test artifacts .pytest_cache/ .coverage htmlcov/ # Ignore build artifacts build/ dist/ *.egg-info/ ================================================ FILE: examples/ollama/Dockerfile ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM python:3.11-slim-bookworm WORKDIR /app RUN pip install langextract COPY demo_ollama.py . CMD ["python", "demo_ollama.py"] ================================================ FILE: examples/ollama/README.md ================================================ # Ollama Examples This directory contains examples for using LangExtract with Ollama for local LLM inference. For setup instructions and documentation, see the [main README's Ollama section](../../README.md#using-local-llms-with-ollama). ## Quick Reference **Option 1: Run locally** ```bash # Install and start Ollama ollama pull gemma2:2b ollama serve # Keep this running in a separate terminal # Run the demo python demo_ollama.py ``` **Option 2: Run with Docker** ```bash # Runs both Ollama and the demo in containers docker-compose up ``` ## Files - `demo_ollama.py` - Comprehensive extraction examples demonstrating Ollama on README examples - `docker-compose.yml` - Production-ready Docker setup with health checks - `Dockerfile` - Container definition for LangExtract ## Configuration Options ### Timeout Settings For slower models or large prompts, you may need to increase the timeout (default: 120 seconds): ```python import langextract as lx result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id="llama3.1:70b", # Larger model may need more time timeout=300, # 5 minutes model_url="http://localhost:11434", ) ``` Or using ModelConfig: ```python config = lx.factory.ModelConfig( model_id="llama3.1:70b", provider_kwargs={ "model_url": "http://localhost:11434", "timeout": 
300, # 5 minutes } ) ``` ## Model License Ollama models come with their own licenses. For example: - Gemma models: [Gemma Terms of Use](https://ai.google.dev/gemma/terms) - Llama models: [Meta Llama License](https://llama.meta.com/llama-downloads/) Please review the license for any model you use. ================================================ FILE: examples/ollama/demo_ollama.py ================================================ #!/usr/bin/env python3 # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Comprehensive demo of Ollama integration with FormatHandler. This example demonstrates: - Using the pre-configured OLLAMA_FORMAT_HANDLER for consistent configuration - Running multiple extraction examples with progress bars - Generating interactive HTML visualizations - Handling various extraction patterns (NER, relationships, dialogue extraction) Prerequisites: 1. Install Ollama: https://ollama.com/ 2. Pull the model: ollama pull gemma2:2b 3. 
Start Ollama: ollama serve Usage: python demo_ollama.py [--model MODEL_NAME] Examples: # Use default model (gemma2:2b) python demo_ollama.py # Use a different model python demo_ollama.py --model llama3.2:3b Output: Results are saved to test_output/ directory (gitignored) - JSONL files with extraction data - Interactive HTML visualizations """ import argparse import os from pathlib import Path import sys import textwrap import time import traceback import urllib.error import urllib.request import dotenv import langextract as lx from langextract.providers import ollama dotenv.load_dotenv(override=True) DEFAULT_MODEL = "gemma2:2b" DEFAULT_OLLAMA_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434") OUTPUT_DIR = "test_output" def check_ollama_available(url: str = DEFAULT_OLLAMA_URL) -> bool: """Check if Ollama is available at the specified URL.""" try: with urllib.request.urlopen(f"{url}/api/tags", timeout=2) as response: return response.status == 200 except (urllib.error.URLError, TimeoutError): return False def ensure_output_directory() -> Path: """Create output directory if it doesn't exist.""" output_path = Path(OUTPUT_DIR) output_path.mkdir(exist_ok=True) return output_path def print_header(title: str, width: int = 80) -> None: """Print a formatted header.""" print("\n" + "=" * width) print(f" {title}") print("=" * width) def print_section(title: str, width: int = 60) -> None: """Print a formatted section.""" print(f"\n▶ {title}") print("-" * width) def print_results_summary(extractions: list[lx.data.Extraction]) -> None: """Print a summary of extraction results.""" if not extractions: print(" No extractions found") return class_counts = {} for ext in extractions: class_counts[ext.extraction_class] = ( class_counts.get(ext.extraction_class, 0) + 1 ) print(f" Total extractions: {len(extractions)}") print(" By type:") for cls, count in sorted(class_counts.items()): print(f" • {cls}: {count}") def example_romeo_juliet( model_id: str, model_url: str ) -> 
lx.data.AnnotatedDocument | None: """Romeo & Juliet character and emotion extraction example.""" print_section("Example 1: Romeo & Juliet - Characters and Emotions") prompt = textwrap.dedent("""\ Extract characters, emotions, and relationships in order of appearance. Use exact text for extractions. Do not paraphrase or overlap entities. Provide meaningful attributes for each entity to add context.""") examples = [ lx.data.ExampleData( text=( "ROMEO. But soft! What light through yonder window breaks? It is" " the east, and Juliet is the sun." ), extractions=[ lx.data.Extraction( extraction_class="character", extraction_text="ROMEO", attributes={"emotional_state": "wonder"}, ), lx.data.Extraction( extraction_class="emotion", extraction_text="But soft!", attributes={"feeling": "gentle awe"}, ), lx.data.Extraction( extraction_class="relationship", extraction_text="Juliet is the sun", attributes={"type": "metaphor"}, ), ], ) ] input_text = ( "Lady Juliet gazed longingly at the stars, her heart aching for Romeo" ) print(f" Input: {input_text}") print(f" Model: {model_id}") print("\n Extracting...") result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=model_id, model_url=model_url, resolver_params={"format_handler": ollama.OLLAMA_FORMAT_HANDLER}, show_progress=True, ) print("\n Results:") print_results_summary(result.extractions) return result def example_medication_ner( model_id: str, model_url: str ) -> lx.data.AnnotatedDocument | None: """Medical named entity recognition example.""" print_section("Example 2: Medication Named Entity Recognition") input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." prompt_description = ( "Extract medication information including medication name, dosage, route," " frequency, and duration in the order they appear in the text." 
) examples = [ lx.data.ExampleData( text="Patient was given 250 mg IV Cefazolin TID for one week.", extractions=[ lx.data.Extraction( extraction_class="dosage", extraction_text="250 mg" ), lx.data.Extraction( extraction_class="route", extraction_text="IV" ), lx.data.Extraction( extraction_class="medication", extraction_text="Cefazolin" ), lx.data.Extraction( extraction_class="frequency", extraction_text="TID" ), lx.data.Extraction( extraction_class="duration", extraction_text="for one week" ), ], ) ] print(f" Input: {input_text}") print(f" Model: {model_id}") print("\n Extracting...") result = lx.extract( text_or_documents=input_text, prompt_description=prompt_description, examples=examples, model_id=model_id, model_url=model_url, resolver_params={"format_handler": ollama.OLLAMA_FORMAT_HANDLER}, show_progress=True, ) print("\n Results:") print_results_summary(result.extractions) return result def example_medication_relationships( model_id: str, model_url: str ) -> lx.data.AnnotatedDocument | None: """Medication relationship extraction with grouped attributes.""" print_section("Example 3: Medication Relationship Extraction") input_text = textwrap.dedent(""" The patient was prescribed Lisinopril and Metformin last month. He takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes. """).strip() prompt_description = textwrap.dedent(""" Extract medications with their details, using attributes to group related information: 1. Extract entities in the order they appear in the text 2. Each entity must have a 'medication_group' attribute linking it to its medication 3. All details about a medication should share the same medication_group value """).strip() examples = [ lx.data.ExampleData( text=( "Patient takes Aspirin 100mg daily for heart health and" " Simvastatin 20mg at bedtime." 
), extractions=[ lx.data.Extraction( extraction_class="medication", extraction_text="Aspirin", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class="dosage", extraction_text="100mg", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class="frequency", extraction_text="daily", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class="condition", extraction_text="heart health", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class="medication", extraction_text="Simvastatin", attributes={"medication_group": "Simvastatin"}, ), lx.data.Extraction( extraction_class="dosage", extraction_text="20mg", attributes={"medication_group": "Simvastatin"}, ), lx.data.Extraction( extraction_class="frequency", extraction_text="at bedtime", attributes={"medication_group": "Simvastatin"}, ), ], ) ] print(f" Input: {input_text[:80]}...") print(f" Model: {model_id}") print("\n Extracting...") result = lx.extract( text_or_documents=input_text, prompt_description=prompt_description, examples=examples, model_id=model_id, model_url=model_url, resolver_params={"format_handler": ollama.OLLAMA_FORMAT_HANDLER}, show_progress=True, ) print("\n Results:") print_results_summary(result.extractions) medication_groups = {} for ext in result.extractions: if ext.attributes and "medication_group" in ext.attributes: group_name = ext.attributes["medication_group"] medication_groups.setdefault(group_name, []).append(ext) if medication_groups: print("\n Grouped by medication:") for med_name in sorted(medication_groups.keys()): print(f" {med_name}: {len(medication_groups[med_name])} attributes") return result def example_shakespeare_dialogue( model_id: str, model_url: str ) -> lx.data.AnnotatedDocument | None: """Extract character dialogue from Shakespeare play excerpt.""" print_section("Example 4: Shakespeare Dialogue Extraction") long_text = textwrap.dedent(""" Act I, Scene I. Verona. 
A public place. Enter SAMPSON and GREGORY, armed with swords and bucklers. SAMPSON: Gregory, on my word, we'll not carry coals. GREGORY: No, for then we should be colliers. SAMPSON: I mean, an we be in choler, we'll draw. GREGORY: Ay, while you live, draw your neck out of collar. Enter ABRAHAM and BALTHASAR. ABRAHAM: Do you bite your thumb at us, sir? SAMPSON: I do bite my thumb, sir. ABRAHAM: Do you bite your thumb at us, sir? SAMPSON: No, sir, I do not bite my thumb at you, sir, but I bite my thumb, sir. GREGORY: Do you quarrel, sir? ABRAHAM: Quarrel, sir? No, sir. Enter BENVOLIO. BENVOLIO: Part, fools! Put up your swords. You know not what you do. Enter TYBALT. TYBALT: What, art thou drawn among these heartless hinds? Turn thee, Benvolio; look upon thy death. BENVOLIO: I do but keep the peace. Put up thy sword, Or manage it to part these men with me. TYBALT: What, drawn, and talk of peace? I hate the word, As I hate hell, all Montagues, and thee. Have at thee, coward! """).strip() prompt = ( "Extract all character names and their dialogue in order of appearance." ) examples = [ lx.data.ExampleData( text="JULIET: O Romeo, Romeo! Wherefore art thou Romeo?", extractions=[ lx.data.Extraction( extraction_class="character", extraction_text="JULIET" ), lx.data.Extraction( extraction_class="dialogue", extraction_text="O Romeo, Romeo! 
Wherefore art thou Romeo?", attributes={"speaker": "JULIET"}, ), ], ) ] print(f" Input: Romeo and Juliet Act I, Scene I ({len(long_text)} chars)") print(f" Model: {model_id}") print(" Note: Automatically chunked for longer text processing") print("\n Extracting...") result = lx.extract( text_or_documents=long_text, prompt_description=prompt, examples=examples, model_id=model_id, model_url=model_url, resolver_params={"format_handler": ollama.OLLAMA_FORMAT_HANDLER}, max_char_buffer=500, show_progress=True, ) print("\n Results:") print_results_summary(result.extractions) characters = set( ext.extraction_text for ext in result.extractions if ext.extraction_class == "character" ) if characters: print("\n Characters found: " + ", ".join(sorted(characters))) return result def save_results( results: list[tuple[str, lx.data.AnnotatedDocument | None]], output_dir: Path, ) -> None: """Save all results to JSONL and generate HTML visualizations.""" print_header("Saving Results and Generating Visualizations") saved_files = [] for name, result in results: if result is None: print(f" ✗ Skipping {name} (no result)") continue jsonl_file = f"{name}.jsonl" jsonl_path = output_dir / jsonl_file lx.io.save_annotated_documents( [result], output_name=jsonl_file, output_dir=str(output_dir) ) print(f" ✓ Saved {jsonl_path}") html_file = f"{name}.html" html_path = output_dir / html_file try: html_content = lx.visualize(str(jsonl_path)) with open(html_path, "w") as f: if hasattr(html_content, "data"): f.write(html_content.data) else: f.write(html_content) print(f" ✓ Generated {html_path}") saved_files.append((jsonl_path, html_path)) except Exception as e: print(f" ✗ Failed to generate {html_path}: {e}") return saved_files def main(): """Run all examples and generate outputs.""" parser = argparse.ArgumentParser( description="Ollama + FormatHandler Demo", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) parser.add_argument( "--model", default=DEFAULT_MODEL, help=f"Ollama 
model to use (default: {DEFAULT_MODEL})", ) parser.add_argument( "--url", default=DEFAULT_OLLAMA_URL, help=f"Ollama server URL (default: {DEFAULT_OLLAMA_URL})", ) parser.add_argument( "--skip-examples", nargs="+", choices=["1", "2", "3", "4"], help="Skip specific examples (e.g., --skip-examples 3 4)", ) args = parser.parse_args() skip_examples = set(args.skip_examples or []) print_header("Ollama + FormatHandler Demo") print("\nConfiguration:") print(f" Model: {args.model}") print(f" Server: {args.url}") print(f" Output: {OUTPUT_DIR}/") print(f" Format Handler: {ollama.OLLAMA_FORMAT_HANDLER}") print("\nChecking Ollama server...") if not check_ollama_available(args.url): print(f"\n⚠️ ERROR: Ollama not available at {args.url}") print("\nTroubleshooting:") print(" 1. Install Ollama: https://ollama.com/") print(" 2. Start server: ollama serve") print(f" 3. Pull model: ollama pull {args.model}") print("\nFor Docker setup, see examples/ollama/docker-compose.yml") sys.exit(1) print("✓ Ollama server is available") output_dir = ensure_output_directory() print("✓ Output directory ready: " + str(output_dir) + "/") print_header("Running Examples") results = [] try: if "1" not in skip_examples: result = example_romeo_juliet(args.model, args.url) results.append(("romeo_juliet", result)) time.sleep(0.5) if "2" not in skip_examples: result = example_medication_ner(args.model, args.url) results.append(("medication_ner", result)) time.sleep(0.5) if "3" not in skip_examples: result = example_medication_relationships(args.model, args.url) results.append(("medication_relationships", result)) time.sleep(0.5) if "4" not in skip_examples: result = example_shakespeare_dialogue(args.model, args.url) results.append(("shakespeare_dialogue", result)) except KeyboardInterrupt: print("\n\n⚠️ Interrupted by user") print("Saving completed results...") except Exception as e: print(f"\n\n✗ Error during execution: {e}") traceback.print_exc() print("\nSaving completed results...") if results: 
save_results(results, output_dir) print_header("Summary") successful = sum(1 for _, r in results if r is not None) print(f"\n✓ Successfully ran {successful}/{len(results)} examples") if results: print(f"\nOutput files in {output_dir}/:") for name, result in results: if result is not None: print(f" • {name}.jsonl - Extraction data") print(f" • {name}.html - Interactive visualization") print("\nTo view results:") print(" open " + str(output_dir) + "/romeo_juliet.html") print("\nOr serve locally:") print(" python -m http.server 8000 --directory " + str(output_dir)) print(" Then visit http://localhost:8000") if __name__ == "__main__": main() ================================================ FILE: examples/ollama/docker-compose.yml ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. services: ollama: image: ollama/ollama:0.5.4 ports: - "127.0.0.1:11434:11434" # Bind only to localhost for security volumes: - ollama-data:/root/.ollama # Cross-platform support command: serve healthcheck: test: ["CMD", "curl", "-f", "http://localhost:11434/api/version"] interval: 5s timeout: 3s retries: 5 start_period: 10s langextract: build: . 
depends_on: ollama: condition: service_healthy environment: - OLLAMA_HOST=http://ollama:11434 volumes: - .:/app command: python demo_ollama.py volumes: ollama-data: ================================================ FILE: langextract/__init__.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """LangExtract: Extract structured information from text with LLMs. This package provides the main extract and visualize functions, with lazy loading for other submodules accessed via attribute access. 
""" from __future__ import annotations import importlib import sys from typing import Any, Dict from langextract import visualization from langextract.extraction import extract as extract_func __all__ = [ # Public convenience functions (thin wrappers) "extract", "visualize", # Submodules exposed lazily on attribute access for ergonomics: "annotation", "data", "providers", "schema", "inference", "factory", "resolver", "prompting", "io", "visualization", "exceptions", "core", "plugins", ] _CACHE: Dict[str, Any] = {} def extract(*args: Any, **kwargs: Any): """Top-level API: lx.extract(...).""" return extract_func(*args, **kwargs) def visualize(*args: Any, **kwargs: Any): """Top-level API: lx.visualize(...).""" return visualization.visualize(*args, **kwargs) # PEP 562 lazy loading _LAZY_MODULES = { "annotation": "langextract.annotation", "chunking": "langextract.chunking", "data": "langextract.data", "data_lib": "langextract.data_lib", "debug_utils": "langextract.core.debug_utils", "exceptions": "langextract.exceptions", "factory": "langextract.factory", "inference": "langextract.inference", "io": "langextract.io", "progress": "langextract.progress", "prompting": "langextract.prompting", "providers": "langextract.providers", "resolver": "langextract.resolver", "schema": "langextract.schema", "tokenizer": "langextract.tokenizer", "visualization": "langextract.visualization", "core": "langextract.core", "plugins": "langextract.plugins", "registry": "langextract.registry", # Backward compat - will emit warning } def __getattr__(name: str) -> Any: if name in _CACHE: return _CACHE[name] modpath = _LAZY_MODULES.get(name) if modpath is None: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") module = importlib.import_module(modpath) # ensure future 'import langextract.' 
returns the same module sys.modules[f"{__name__}.{name}"] = module setattr(sys.modules[__name__], name, module) _CACHE[name] = module return module def __dir__(): return sorted(__all__) ================================================ FILE: langextract/_compat/README.md ================================================ # Backward Compatibility Layer This directory contains backward compatibility shims for deprecated imports. ## Deprecation Timeline All code in this directory will be removed in LangExtract v2.0.0. ## Migration Guide The following imports are deprecated and should be updated: ### Inference Module - `from langextract.inference import BaseLanguageModel` → `from langextract.core.base_model import BaseLanguageModel` - `from langextract.inference import ScoredOutput` → `from langextract.core.types import ScoredOutput` - `from langextract.inference import InferenceOutputError` → `from langextract.core.exceptions import InferenceOutputError` - `from langextract.inference import GeminiLanguageModel` → `from langextract.providers.gemini import GeminiLanguageModel` - `from langextract.inference import OpenAILanguageModel` → `from langextract.providers.openai import OpenAILanguageModel` - `from langextract.inference import OllamaLanguageModel` → `from langextract.providers.ollama import OllamaLanguageModel` ### Schema Module - `from langextract.schema import BaseSchema` → `from langextract.core.schema import BaseSchema` - `from langextract.schema import Constraint` → `from langextract.core.schema import Constraint` - `from langextract.schema import ConstraintType` → `from langextract.core.schema import ConstraintType` - `from langextract.schema import EXTRACTIONS_KEY` → `from langextract.core.schema import EXTRACTIONS_KEY` - `from langextract.schema import GeminiSchema` → `from langextract.providers.schemas.gemini import GeminiSchema` ### Exceptions Module - All exceptions: `from langextract.exceptions import *` → `from langextract.core.exceptions import *` ### 
Registry Module - `from langextract.registry import *` → `from langextract.plugins import *` - `from langextract.providers.registry import *` → `from langextract.providers.router import *` ## For Contributors Do not add new code to this directory. All new development should use the canonical imports from `core/` and `providers/`. ================================================ FILE: langextract/_compat/__init__.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Backward compatibility layer for LangExtract. This package contains compatibility shims for deprecated imports. All code in this directory will be removed in v2.0.0. """ from __future__ import annotations __all__ = ["inference", "schema", "exceptions", "registry"] ================================================ FILE: langextract/_compat/exceptions.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Compatibility shim for langextract.exceptions imports.""" # pylint: disable=duplicate-code from __future__ import annotations import warnings from langextract.core import exceptions # Re-export exceptions from core.exceptions with a warning-on-first-access def __getattr__(name: str): allowed = { "LangExtractError", "InferenceError", "InferenceConfigError", "InferenceRuntimeError", "InferenceOutputError", "ProviderError", "SchemaError", } if name in allowed: warnings.warn( "`langextract.exceptions` is deprecated; import from" " `langextract.core.exceptions`.", FutureWarning, stacklevel=2, ) return getattr(exceptions, name) raise AttributeError(name) ================================================ FILE: langextract/_compat/inference.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Compatibility shim for langextract.inference imports.""" from __future__ import annotations import enum import warnings class InferenceType(enum.Enum): """Enum for inference types - kept for backward compatibility.""" ITERATIVE = "iterative" MULTIPROCESS = "multiprocess" def __getattr__(name: str): moved = { "BaseLanguageModel": ("langextract.core.base_model", "BaseLanguageModel"), "ScoredOutput": ("langextract.core.types", "ScoredOutput"), "InferenceOutputError": ( "langextract.core.exceptions", "InferenceOutputError", ), "GeminiLanguageModel": ( "langextract.providers.gemini", "GeminiLanguageModel", ), "OpenAILanguageModel": ( "langextract.providers.openai", "OpenAILanguageModel", ), "OllamaLanguageModel": ( "langextract.providers.ollama", "OllamaLanguageModel", ), } if name in moved: mod, attr = moved[name] warnings.warn( f"`langextract.inference.{name}` is deprecated and will be removed in" f" v2.0.0; use `{mod}.{attr}` instead.", FutureWarning, stacklevel=2, ) module = __import__(mod, fromlist=[attr]) return getattr(module, attr) raise AttributeError(name) ================================================ FILE: langextract/_compat/registry.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Compatibility shim for langextract.registry imports.""" # pylint: disable=duplicate-code from __future__ import annotations import warnings from langextract import plugins def __getattr__(name: str): """Forward to plugins module with deprecation warning.""" warnings.warn( "`langextract.registry` is deprecated and will be removed in v2.0.0; " "use `langextract.plugins` instead.", FutureWarning, stacklevel=2, ) return getattr(plugins, name) ================================================ FILE: langextract/_compat/schema.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Compatibility shim for langextract.schema imports.""" # pylint: disable=duplicate-code from __future__ import annotations import warnings def __getattr__(name: str): moved = { "BaseSchema": ("langextract.core.schema", "BaseSchema"), "Constraint": ("langextract.core.schema", "Constraint"), "ConstraintType": ("langextract.core.schema", "ConstraintType"), "EXTRACTIONS_KEY": ("langextract.core.schema", "EXTRACTIONS_KEY"), "GeminiSchema": ("langextract.providers.schemas.gemini", "GeminiSchema"), } if name in moved: mod, attr = moved[name] warnings.warn( f"`langextract.schema.{name}` is deprecated and will be removed in" f" v2.0.0; use `{mod}.{attr}` instead.", FutureWarning, stacklevel=2, ) module = __import__(mod, fromlist=[attr]) return getattr(module, attr) raise AttributeError(name) ================================================ FILE: langextract/annotation.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provides functionality for annotating medical text using a language model. The annotation process involves tokenizing the input text, generating prompts for the language model, and resolving the language model's output into structured annotations. 
Usage example: annotator = Annotator(language_model, prompt_template) annotated_documents = annotator.annotate_documents(documents, resolver) """ from __future__ import annotations import collections from collections.abc import Iterable, Iterator import time from typing import DefaultDict from absl import logging from langextract import chunking from langextract import progress from langextract import prompting from langextract import resolver as resolver_lib from langextract.core import base_model from langextract.core import data from langextract.core import exceptions from langextract.core import format_handler as fh from langextract.core import tokenizer as tokenizer_lib def _merge_non_overlapping_extractions( all_extractions: list[Iterable[data.Extraction]], ) -> list[data.Extraction]: """Merges extractions from multiple extraction passes. When extractions from different passes overlap in their character positions, the extraction from the earlier pass is kept (first-pass wins strategy). Only non-overlapping extractions from later passes are added to the result. Args: all_extractions: List of extraction iterables from different sequential extraction passes, ordered by pass number. Returns: List of merged extractions with overlaps resolved in favor of earlier passes. 
""" if not all_extractions: return [] if len(all_extractions) == 1: return list(all_extractions[0]) merged_extractions = list(all_extractions[0]) for pass_extractions in all_extractions[1:]: for extraction in pass_extractions: overlaps = False if extraction.char_interval is not None: for existing_extraction in merged_extractions: if existing_extraction.char_interval is not None: if _extractions_overlap(extraction, existing_extraction): overlaps = True break if not overlaps: merged_extractions.append(extraction) return merged_extractions def _extractions_overlap( extraction1: data.Extraction, extraction2: data.Extraction ) -> bool: """Checks if two extractions overlap based on their character intervals. Args: extraction1: First extraction to compare. extraction2: Second extraction to compare. Returns: True if the extractions overlap, False otherwise. """ if extraction1.char_interval is None or extraction2.char_interval is None: return False start1, end1 = ( extraction1.char_interval.start_pos, extraction1.char_interval.end_pos, ) start2, end2 = ( extraction2.char_interval.start_pos, extraction2.char_interval.end_pos, ) if start1 is None or end1 is None or start2 is None or end2 is None: return False # Two intervals overlap if one starts before the other ends return start1 < end2 and start2 < end1 def _document_chunk_iterator( documents: Iterable[data.Document], max_char_buffer: int, restrict_repeats: bool = True, tokenizer: tokenizer_lib.Tokenizer | None = None, ) -> Iterator[chunking.TextChunk]: """Iterates over documents to yield text chunks along with the document ID. Args: documents: A sequence of Document objects. max_char_buffer: The maximum character buffer size for the ChunkIterator. restrict_repeats: Whether to restrict the same document id from being visited more than once. tokenizer: Optional tokenizer instance. Yields: TextChunk containing document ID for a corresponding document. 
Raises: InvalidDocumentError: If restrict_repeats is True and the same document ID is visited more than once. Valid documents prior to the error will be returned. """ visited_ids = set() for document in documents: if tokenizer: tokenized_text = tokenizer.tokenize(document.text or "") else: tokenized_text = document.tokenized_text document_id = document.document_id if restrict_repeats and document_id in visited_ids: raise exceptions.InvalidDocumentError( f"Document id {document_id} is already visited." ) chunk_iter = chunking.ChunkIterator( text=tokenized_text, max_char_buffer=max_char_buffer, document=document, tokenizer_impl=tokenizer or tokenizer_lib.RegexTokenizer(), ) visited_ids.add(document_id) yield from chunk_iter class Annotator: """Annotates documents with extractions using a language model.""" def __init__( self, language_model: base_model.BaseLanguageModel, prompt_template: prompting.PromptTemplateStructured, format_type: data.FormatType = data.FormatType.YAML, attribute_suffix: str = data.ATTRIBUTE_SUFFIX, fence_output: bool = False, format_handler: fh.FormatHandler | None = None, ): """Initializes Annotator. Args: language_model: Model which performs language model inference. prompt_template: Structured prompt template where the answer is expected to be formatted text (YAML or JSON). format_type: The format type for the output (YAML or JSON). attribute_suffix: Suffix to append to attribute keys in the output. fence_output: Whether to expect/generate fenced output (```json or ```yaml). When True, the model is prompted to generate fenced output and the resolver expects it. When False, raw JSON/YAML is expected. Defaults to False. If format_handler is provided, it takes precedence. format_handler: Optional FormatHandler for managing format-specific logic. 
""" self._language_model = language_model if format_handler is None: format_handler = fh.FormatHandler( format_type=format_type, use_wrapper=True, wrapper_key=data.EXTRACTIONS_KEY, use_fences=fence_output, attribute_suffix=attribute_suffix, ) self._prompt_generator = prompting.QAPromptGenerator( template=prompt_template, format_handler=format_handler, ) logging.debug( "Annotator initialized with format_handler: %s", format_handler ) def annotate_documents( self, documents: Iterable[data.Document], resolver: resolver_lib.AbstractResolver | None = None, max_char_buffer: int = 200, batch_length: int = 1, debug: bool = True, extraction_passes: int = 1, context_window_chars: int | None = None, show_progress: bool = True, tokenizer: tokenizer_lib.Tokenizer | None = None, **kwargs, ) -> Iterator[data.AnnotatedDocument]: """Annotates a sequence of documents with NLP extractions. Breaks documents into chunks, processes them into prompts and performs batched inference, mapping annotated extractions back to the original document. Batch processing is determined by batch_length, and can operate across documents for optimized throughput. Args: documents: Documents to annotate. Each document is expected to have a unique document_id. resolver: Resolver to use for extracting information from text. max_char_buffer: Max number of characters that we can run inference on. The text will be broken into chunks up to this length. batch_length: Number of chunks to process in a single batch. debug: Whether to populate debug fields. extraction_passes: Number of sequential extraction attempts to improve recall by finding additional entities. Defaults to 1, which performs standard single extraction. Values > 1 reprocess tokens multiple times, potentially increasing costs with the potential for a more thorough extraction. context_window_chars: Number of characters from the previous chunk to include as context for the current chunk. Helps with coreference resolution across chunk boundaries. 
Defaults to None (disabled). show_progress: Whether to show progress bar. Defaults to True. tokenizer: Optional tokenizer to use. If None, uses default tokenizer. **kwargs: Additional arguments passed to LanguageModel.infer and Resolver. Yields: Resolved annotations from input documents. Raises: ValueError: If there are no scored outputs during inference. """ if resolver is None: resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML) if extraction_passes == 1: yield from self._annotate_documents_single_pass( documents, resolver, max_char_buffer, batch_length, debug, show_progress, context_window_chars=context_window_chars, tokenizer=tokenizer, **kwargs, ) else: yield from self._annotate_documents_sequential_passes( documents, resolver, max_char_buffer, batch_length, debug, extraction_passes, show_progress, context_window_chars=context_window_chars, tokenizer=tokenizer, **kwargs, ) def _annotate_documents_single_pass( self, documents: Iterable[data.Document], resolver: resolver_lib.AbstractResolver, max_char_buffer: int, batch_length: int, debug: bool, show_progress: bool = True, context_window_chars: int | None = None, tokenizer: tokenizer_lib.Tokenizer | None = None, **kwargs, ) -> Iterator[data.AnnotatedDocument]: """Single-pass annotation with stable ordering and streaming emission. Streams input without full materialization, maintains correct attribution across batches, and emits completed documents immediately to minimize peak memory usage. Handles generators from both infer() and align(). When context_window_chars is set, includes text from the previous chunk as context for coreference resolution across chunk boundaries. 
""" doc_order: list[str] = [] doc_text_by_id: dict[str, str] = {} per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict( list ) next_emit_idx = 0 def _capture_docs(src: Iterable[data.Document]) -> Iterator[data.Document]: """Captures document order and text lazily as chunks are produced.""" for document in src: document_id = document.document_id if document_id in doc_text_by_id: raise exceptions.InvalidDocumentError( f"Duplicate document_id: {document_id}" ) doc_order.append(document_id) doc_text_by_id[document_id] = document.text or "" yield document def _emit_docs_iter( keep_last_doc: bool, ) -> Iterator[data.AnnotatedDocument]: """Yields documents that are guaranteed complete. Args: keep_last_doc: If True, retains the most recently started document for additional extractions. If False, emits all remaining documents. """ nonlocal next_emit_idx limit = max(0, len(doc_order) - 1) if keep_last_doc else len(doc_order) while next_emit_idx < limit: document_id = doc_order[next_emit_idx] yield data.AnnotatedDocument( document_id=document_id, extractions=per_doc.get(document_id, []), text=doc_text_by_id.get(document_id, ""), ) per_doc.pop(document_id, None) doc_text_by_id.pop(document_id, None) next_emit_idx += 1 chunk_iter = _document_chunk_iterator( _capture_docs(documents), max_char_buffer, tokenizer=tokenizer ) batches = chunking.make_batches_of_textchunk(chunk_iter, batch_length) model_info = progress.get_model_info(self._language_model) batch_iter = progress.create_extraction_progress_bar( batches, model_info=model_info, disable=not show_progress ) chars_processed = 0 prompt_builder = prompting.ContextAwarePromptBuilder( generator=self._prompt_generator, context_window_chars=context_window_chars, ) try: for batch in batch_iter: if not batch: continue prompts = [ prompt_builder.build_prompt( chunk.chunk_text, chunk.document_id, chunk.additional_context ) for chunk in batch ] if show_progress: current_chars = sum( len(text_chunk.chunk_text) for 
text_chunk in batch ) try: batch_iter.set_description( progress.format_extraction_progress( model_info, current_chars=current_chars, processed_chars=chars_processed, ) ) except AttributeError: pass outputs = self._language_model.infer(batch_prompts=prompts, **kwargs) if not isinstance(outputs, list): outputs = list(outputs) for text_chunk, scored_outputs in zip(batch, outputs): if not isinstance(scored_outputs, list): scored_outputs = list(scored_outputs) if not scored_outputs: raise exceptions.InferenceOutputError( "No scored outputs from language model." ) resolved_extractions = resolver.resolve( scored_outputs[0].output, debug=debug, **kwargs ) token_offset = ( text_chunk.token_interval.start_index if text_chunk.token_interval else 0 ) char_offset = ( text_chunk.char_interval.start_pos if text_chunk.char_interval else 0 ) aligned_extractions = resolver.align( resolved_extractions, text_chunk.chunk_text, token_offset, char_offset, tokenizer_inst=tokenizer, **kwargs, ) for extraction in aligned_extractions: per_doc[text_chunk.document_id].append(extraction) if show_progress and text_chunk.char_interval is not None: chars_processed += ( text_chunk.char_interval.end_pos - text_chunk.char_interval.start_pos ) yield from _emit_docs_iter(keep_last_doc=True) finally: batch_iter.close() yield from _emit_docs_iter(keep_last_doc=False) def _annotate_documents_sequential_passes( self, documents: Iterable[data.Document], resolver: resolver_lib.AbstractResolver, max_char_buffer: int, batch_length: int, debug: bool, extraction_passes: int, show_progress: bool = True, context_window_chars: int | None = None, tokenizer: tokenizer_lib.Tokenizer | None = None, **kwargs, ) -> Iterator[data.AnnotatedDocument]: """Sequential extraction passes logic for improved recall.""" logging.info( "Starting sequential extraction passes for improved recall with %d" " passes.", extraction_passes, ) document_list = list(documents) document_extractions_by_pass: dict[str, list[list[data.Extraction]]] 
= {} document_texts: dict[str, str] = {} # Preserve text up-front so we can emit documents even if later passes # produce no extractions. for _doc in document_list: document_texts[_doc.document_id] = _doc.text or "" for pass_num in range(extraction_passes): logging.info( "Starting extraction pass %d of %d", pass_num + 1, extraction_passes ) for annotated_doc in self._annotate_documents_single_pass( document_list, resolver, max_char_buffer, batch_length, debug=(debug and pass_num == 0), show_progress=show_progress if pass_num == 0 else False, context_window_chars=context_window_chars, tokenizer=tokenizer, **kwargs, ): doc_id = annotated_doc.document_id if doc_id not in document_extractions_by_pass: document_extractions_by_pass[doc_id] = [] # Keep first-seen text (already pre-filled above). document_extractions_by_pass[doc_id].append( annotated_doc.extractions or [] ) # Emit results strictly in original input order. for doc in document_list: doc_id = doc.document_id all_pass_extractions = document_extractions_by_pass.get(doc_id, []) merged_extractions = _merge_non_overlapping_extractions( all_pass_extractions ) if debug: total_extractions = sum( len(extractions) for extractions in all_pass_extractions ) logging.info( "Document %s: Merged %d extractions from %d passes into " "%d non-overlapping extractions.", doc_id, total_extractions, extraction_passes, len(merged_extractions), ) yield data.AnnotatedDocument( document_id=doc_id, extractions=merged_extractions, text=document_texts.get(doc_id, doc.text or ""), ) logging.info("Sequential extraction passes completed.") def annotate_text( self, text: str, resolver: resolver_lib.AbstractResolver | None = None, max_char_buffer: int = 200, batch_length: int = 1, additional_context: str | None = None, debug: bool = True, extraction_passes: int = 1, context_window_chars: int | None = None, show_progress: bool = True, tokenizer: tokenizer_lib.Tokenizer | None = None, **kwargs, ) -> data.AnnotatedDocument: """Annotates text with 
NLP extractions for text input. Args: text: Source text to annotate. resolver: Resolver to use for extracting information from text. max_char_buffer: Max number of characters that we can run inference on. The text will be broken into chunks up to this length. batch_length: Number of chunks to process in a single batch. additional_context: Additional context to supplement prompt instructions. debug: Whether to populate debug fields. extraction_passes: Number of sequential extraction passes to improve recall by finding additional entities. Defaults to 1, which performs standard single extraction. Values > 1 reprocess tokens multiple times, potentially increasing costs. context_window_chars: Number of characters from the previous chunk to include as context for coreference resolution. Defaults to None (disabled). show_progress: Whether to show progress bar. Defaults to True. tokenizer: Optional tokenizer instance. **kwargs: Additional arguments for inference and resolver_lib. Returns: Resolved annotations from text for document. """ if resolver is None: resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, ) start_time = time.time() if debug else None documents = [ data.Document( text=text, document_id=None, additional_context=additional_context, ) ] annotations = list( self.annotate_documents( documents=documents, resolver=resolver, max_char_buffer=max_char_buffer, batch_length=batch_length, debug=debug, extraction_passes=extraction_passes, context_window_chars=context_window_chars, show_progress=show_progress, tokenizer=tokenizer, **kwargs, ) ) assert ( len(annotations) == 1 ), f"Expected 1 annotation but got {len(annotations)} annotations." 
if debug and annotations[0].extractions: elapsed_time = time.time() - start_time if start_time else None num_extractions = len(annotations[0].extractions) unique_classes = len( set(e.extraction_class for e in annotations[0].extractions) ) num_chunks = len(text) // max_char_buffer + ( 1 if len(text) % max_char_buffer else 0 ) progress.print_extraction_summary( num_extractions, unique_classes, elapsed_time=elapsed_time, chars_processed=len(text), num_chunks=num_chunks, ) return data.AnnotatedDocument( document_id=annotations[0].document_id, extractions=annotations[0].extractions, text=annotations[0].text, ) ================================================ FILE: langextract/chunking.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Library for breaking documents into chunks of sentences. When a text-to-text model (e.g. a large language model with a fixed context size) can not accommodate a large document, this library can help us break the document into chunks of a required maximum length that we can perform inference on. 
""" from collections.abc import Iterable, Iterator, Sequence import dataclasses import re from absl import logging import more_itertools from langextract.core import data from langextract.core import exceptions from langextract.core import tokenizer as tokenizer_lib class TokenUtilError(exceptions.LangExtractError): """Error raised when token_util returns unexpected values.""" @dataclasses.dataclass class TextChunk: """Stores a text chunk with attributes to the source document. Attributes: token_interval: The token interval of the chunk in the source document. document: The source document. """ token_interval: tokenizer_lib.TokenInterval document: data.Document | None = None _chunk_text: str | None = dataclasses.field( default=None, init=False, repr=False ) _sanitized_chunk_text: str | None = dataclasses.field( default=None, init=False, repr=False ) _char_interval: data.CharInterval | None = dataclasses.field( default=None, init=False, repr=False ) def __str__(self): interval_repr = ( f"start_index: {self.token_interval.start_index}, end_index:" f" {self.token_interval.end_index}" ) doc_id_repr = ( f"Document ID: {self.document_id}" if self.document_id else "Document ID: None" ) try: chunk_text_repr = f"'{self.chunk_text}'" except ValueError: chunk_text_repr = "" return ( "TextChunk(\n" f" interval=[{interval_repr}],\n" f" {doc_id_repr},\n" f" Chunk Text: {chunk_text_repr}\n" ")" ) @property def document_id(self) -> str | None: """Gets the document ID from the source document.""" if self.document is not None: return self.document.document_id return None @property def document_text(self) -> tokenizer_lib.TokenizedText | None: """Gets the tokenized text from the source document.""" if self.document is not None: return self.document.tokenized_text return None @property def chunk_text(self) -> str: """Gets the chunk text. 
Raises an error if `document_text` is not set.""" if self.document_text is None: raise ValueError("document_text must be set to access chunk_text.") if self._chunk_text is None: self._chunk_text = get_token_interval_text( self.document_text, self.token_interval ) return self._chunk_text @property def sanitized_chunk_text(self) -> str: """Gets the sanitized chunk text.""" if self._sanitized_chunk_text is None: self._sanitized_chunk_text = _sanitize(self.chunk_text) return self._sanitized_chunk_text @property def additional_context(self) -> str | None: """Gets the additional context for prompting from the source document.""" if self.document is not None: return self.document.additional_context return None @property def char_interval(self) -> data.CharInterval: """Gets the character interval corresponding to the token interval. Returns: data.CharInterval: The character interval for this chunk. Raises: ValueError: If document_text is not set. """ if self._char_interval is None: if self.document_text is None: raise ValueError("document_text must be set to compute char_interval.") self._char_interval = get_char_interval( self.document_text, self.token_interval ) return self._char_interval def create_token_interval( start_index: int, end_index: int ) -> tokenizer_lib.TokenInterval: """Creates a token interval. Args: start_index: first token's index (inclusive). end_index: last token's index + 1 (exclusive). Returns: Token interval. Raises: ValueError: If the token indices are invalid. """ if start_index < 0: raise ValueError(f"Start index {start_index} must be positive.") if start_index >= end_index: raise ValueError( f"Start index {start_index} must be < end index {end_index}." ) return tokenizer_lib.TokenInterval( start_index=start_index, end_index=end_index ) def get_token_interval_text( tokenized_text: tokenizer_lib.TokenizedText, token_interval: tokenizer_lib.TokenInterval, ) -> str: """Get the text within an interval of tokens. 
Args: tokenized_text: Tokenized documents. token_interval: An interval specifying the start (inclusive) and end (exclusive) indices of the tokens to extract. These indices refer to the positions in the list of tokens within `tokenized_text.tokens`, not the value of the field `index` of `token_pb2.Token`. If the tokens are [(index:0, text:A), (index:5, text:B), (index:10, text:C)], we should use token_interval=[0, 2] to represent taking A and B, not [0, 6]. Please see details from the implementation of tokenizer_lib.tokens_text Returns: Text within the token interval. Raises: ValueError: If the token indices are invalid. TokenUtilError: If tokenizer_lib.tokens_text returns an empty string. """ if token_interval.start_index >= token_interval.end_index: raise ValueError( f"Start index {token_interval.start_index} must be < end index " f"{token_interval.end_index}." ) return_string = tokenizer_lib.tokens_text(tokenized_text, token_interval) logging.debug( "Token util returns string: %s for tokenized_text: %s, token_interval:" " %s", return_string, tokenized_text, token_interval, ) if tokenized_text.text and not return_string: raise TokenUtilError( "Token util returns an empty string unexpectedly. Number of tokens is" f" tokenized_text: {len(tokenized_text.tokens)}, token_interval is" f" {token_interval.start_index} to {token_interval.end_index}, which" " should not lead to empty string." ) return return_string def get_char_interval( tokenized_text: tokenizer_lib.TokenizedText, token_interval: tokenizer_lib.TokenInterval, ) -> data.CharInterval: """Returns the char interval corresponding to the token interval. Args: tokenized_text: Document. token_interval: Token interval. Returns: Char interval of the token interval of interest. Raises: ValueError: If the token_interval is invalid. """ if token_interval.start_index >= token_interval.end_index: raise ValueError( f"Start index {token_interval.start_index} must be < end index " f"{token_interval.end_index}." 
) start_token = tokenized_text.tokens[token_interval.start_index] # Penultimate token prior to interval.end_index final_token = tokenized_text.tokens[token_interval.end_index - 1] return data.CharInterval( start_pos=start_token.char_interval.start_pos, end_pos=final_token.char_interval.end_pos, ) def _sanitize(text: str) -> str: """Converts all whitespace characters in input text to a single space. Args: text: Input to sanitize. Returns: Sanitized text with newlines and excess spaces removed. Raises: ValueError: If the sanitized text is empty. """ sanitized_text = re.sub(r"\s+", " ", text.strip()) if not sanitized_text: raise ValueError("Sanitized text is empty.") return sanitized_text def make_batches_of_textchunk( chunk_iter: Iterator[TextChunk], batch_length: int, ) -> Iterable[Sequence[TextChunk]]: """Processes chunks into batches of TextChunk for inference, using itertools.batched. Args: chunk_iter: Iterator of TextChunks. batch_length: Number of chunks to include in each batch. Yields: Batches of TextChunks. """ for batch in more_itertools.batched(chunk_iter, batch_length): yield list(batch) class SentenceIterator: """Iterate through sentences of a tokenized text.""" def __init__( self, tokenized_text: tokenizer_lib.TokenizedText, curr_token_pos: int = 0, ): """Constructor. Args: tokenized_text: Document to iterate through. curr_token_pos: Iterate through sentences from this token position. Raises: IndexError: if curr_token_pos is not within the document. """ self.tokenized_text = tokenized_text self.token_len = len(tokenized_text.tokens) if curr_token_pos < 0: raise IndexError( f"Current token position {curr_token_pos} can not be negative." ) elif curr_token_pos > self.token_len: raise IndexError( f"Current token position {curr_token_pos} is past the length of the " f"document {self.token_len}." 
) self.curr_token_pos = curr_token_pos def __iter__(self) -> Iterator[tokenizer_lib.TokenInterval]: return self def __next__(self) -> tokenizer_lib.TokenInterval: """Returns next sentence's interval starting from current token position. Returns: Next sentence token interval starting from current token position. Raises: StopIteration: If end of text is reached. """ assert self.curr_token_pos <= self.token_len if self.curr_token_pos == self.token_len: raise StopIteration # This locates the sentence which contains the current token position. sentence_range = tokenizer_lib.find_sentence_range( self.tokenized_text.text, self.tokenized_text.tokens, self.curr_token_pos, ) assert sentence_range # Start the sentence from the current token position. # If we are in the middle of a sentence, we should start from there. sentence_range = create_token_interval( self.curr_token_pos, sentence_range.end_index ) self.curr_token_pos = sentence_range.end_index return sentence_range class ChunkIterator: r"""Iterate through chunks of a tokenized text. Chunks may consist of sentences or sentence fragments that can fit into the maximum character buffer that we can run inference on. A) If a sentence length exceeds the max char buffer, then it needs to be broken into chunks that can fit within the max char buffer. We do this in a way that maximizes the chunk length while respecting newlines (if present) and token boundaries. Consider this sentence from a poem by John Donne: ``` No man is an island, Entire of itself, Every man is a piece of the continent, A part of the main. ``` With max_char_buffer=40, the chunks are: * "No man is an island,\nEntire of itself," len=38 * "Every man is a piece of the continent," len=38 * "A part of the main." len=19 B) If a single token exceeds the max char buffer, it comprises the whole chunk. Consider the sentence: "This is antidisestablishmentarianism." With max_char_buffer=20, the chunks are: * "This is" len=7 * "antidisestablishmentarianism" len=28 * "." 
len(1) C) If multiple *whole* sentences can fit within the max char buffer, then they are used to form the chunk. Consider the sentences: "Roses are red. Violets are blue. Flowers are nice. And so are you." With max_char_buffer=60, the chunks are: * "Roses are red. Violets are blue. Flowers are nice." len=50 * "And so are you." len=15 """ def __init__( self, text: str | tokenizer_lib.TokenizedText | None, max_char_buffer: int, tokenizer_impl: tokenizer_lib.Tokenizer, document: data.Document | None = None, ): """Constructor. Args: text: Document to chunk. Can be either a string or a tokenized text. max_char_buffer: Size of buffer that we can run inference on. tokenizer_impl: Tokenizer instance to use. document: Optional source document. """ if text is None: if document is None: raise ValueError("Either text or document must be provided.") text = document.text or "" if isinstance(text, str): text = tokenizer_impl.tokenize(text) elif isinstance(text, tokenizer_lib.TokenizedText) and not text.tokens: text_to_tokenize = text.text or (document.text if document else "") text = tokenizer_impl.tokenize(text_to_tokenize) self.tokenized_text = text self.max_char_buffer = max_char_buffer self.sentence_iter = SentenceIterator(self.tokenized_text) self.broken_sentence = False # TODO: Refactor redundancy between document and text. if document is None: self.document = data.Document(text=text.text) else: self.document = document self.document.tokenized_text = self.tokenized_text def __iter__(self) -> Iterator[TextChunk]: return self def _tokens_exceed_buffer( self, token_interval: tokenizer_lib.TokenInterval ) -> bool: """Check if the token interval exceeds the maximum buffer size. Args: token_interval: Token interval to check. Returns: True if the token interval exceeds the maximum buffer size. 
""" char_interval = get_char_interval(self.tokenized_text, token_interval) return ( char_interval.end_pos - char_interval.start_pos ) > self.max_char_buffer def __next__(self) -> TextChunk: sentence = next(self.sentence_iter) # If the next token is greater than the max_char_buffer, let it be the # entire chunk. curr_chunk = create_token_interval( sentence.start_index, sentence.start_index + 1 ) if self._tokens_exceed_buffer(curr_chunk): self.sentence_iter = SentenceIterator( self.tokenized_text, curr_token_pos=sentence.start_index + 1 ) self.broken_sentence = curr_chunk.end_index < sentence.end_index return TextChunk( token_interval=curr_chunk, document=self.document, ) # Append tokens to the chunk up to the max_char_buffer. start_of_new_line = -1 for token_index in range(curr_chunk.start_index, sentence.end_index): if self.tokenized_text.tokens[token_index].first_token_after_newline: start_of_new_line = token_index test_chunk = create_token_interval( curr_chunk.start_index, token_index + 1 ) if self._tokens_exceed_buffer(test_chunk): # Only break at newline if: 1) newline exists (> 0) and # 2) it's after chunk start (prevents empty intervals) if start_of_new_line > 0 and start_of_new_line > curr_chunk.start_index: # Terminate the curr_chunk at the start of the most recent newline. 
curr_chunk = create_token_interval( curr_chunk.start_index, start_of_new_line ) self.sentence_iter = SentenceIterator( self.tokenized_text, curr_token_pos=curr_chunk.end_index ) self.broken_sentence = True return TextChunk( token_interval=curr_chunk, document=self.document, ) else: curr_chunk = test_chunk if self.broken_sentence: self.broken_sentence = False else: for sentence in self.sentence_iter: test_chunk = create_token_interval( curr_chunk.start_index, sentence.end_index ) if self._tokens_exceed_buffer(test_chunk): self.sentence_iter = SentenceIterator( self.tokenized_text, curr_token_pos=curr_chunk.end_index ) return TextChunk( token_interval=curr_chunk, document=self.document, ) else: curr_chunk = test_chunk return TextChunk( token_interval=curr_chunk, document=self.document, ) ================================================ FILE: langextract/core/__init__.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Core abstractions for LangExtract. This package contains the foundational base models and types used throughout LangExtract. Each module can be imported independently for fine-grained dependency management in build systems. 
""" from __future__ import annotations __all__ = [ "base_model", "types", "exceptions", "schema", "data", "tokenizer", ] ================================================ FILE: langextract/core/base_model.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Base interfaces for language models.""" from __future__ import annotations import abc from collections.abc import Iterator, Sequence import json from typing import Any, Mapping import yaml from langextract.core import schema from langextract.core import types __all__ = ['BaseLanguageModel'] class BaseLanguageModel(abc.ABC): """An abstract inference class for managing LLM inference. Attributes: _constraint: A `Constraint` object specifying constraints for model output. """ def __init__(self, constraint: types.Constraint | None = None, **kwargs: Any): """Initializes the BaseLanguageModel with an optional constraint. Args: constraint: Applies constraints when decoding the output. Defaults to no constraint. **kwargs: Additional keyword arguments passed to the model. 
""" self._constraint = constraint or types.Constraint() self._schema: schema.BaseSchema | None = None self._fence_output_override: bool | None = None self._extra_kwargs: dict[str, Any] = kwargs.copy() @classmethod def get_schema_class(cls) -> type[Any] | None: """Return the schema class this provider supports.""" return None def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None: """Apply a schema instance to this provider. Optional method that providers can override to store the schema instance for runtime use. The default implementation stores it as _schema. Args: schema_instance: The schema instance to apply, or None to clear. """ self._schema = schema_instance @property def schema(self) -> schema.BaseSchema | None: """The current schema instance if one is configured. Returns: The schema instance or None if no schema is applied. """ return self._schema def set_fence_output(self, fence_output: bool | None) -> None: """Set explicit fence output preference. Args: fence_output: True to force fences, False to disable, None for auto. """ if not hasattr(self, '_fence_output_override'): self._fence_output_override = None self._fence_output_override = fence_output @property def requires_fence_output(self) -> bool: """Whether this model requires fence output for parsing. Uses explicit override if set, otherwise computes from schema. Returns True if no schema or schema doesn't require raw output. """ if ( hasattr(self, '_fence_output_override') and self._fence_output_override is not None ): return self._fence_output_override schema_obj = self.schema if schema_obj is None: return True return not schema_obj.requires_raw_output def merge_kwargs( self, runtime_kwargs: Mapping[str, Any] | None = None ) -> dict[str, Any]: """Merge stored extra kwargs with runtime kwargs. Runtime kwargs take precedence over stored kwargs. Args: runtime_kwargs: Kwargs provided at inference time, or None. Returns: Merged kwargs dictionary. 
""" base = getattr(self, '_extra_kwargs', {}) or {} incoming = dict(runtime_kwargs or {}) return {**base, **incoming} @abc.abstractmethod def infer( self, batch_prompts: Sequence[str], **kwargs ) -> Iterator[Sequence[types.ScoredOutput]]: """Implements language model inference. Args: batch_prompts: Batch of inputs for inference. Single element list can be used for a single input. **kwargs: Additional arguments for inference, like temperature and max_decode_steps. Returns: Batch of Sequence of probable output text outputs, sorted by descending score. """ def infer_batch( self, prompts: Sequence[str], batch_size: int = 32 # pylint: disable=unused-argument ) -> list[list[types.ScoredOutput]]: """Batch inference with configurable batch size. This is a convenience method that collects all results from infer(). Args: prompts: List of prompts to process. batch_size: Batch size (currently unused, for future optimization). Returns: List of lists of ScoredOutput objects. """ results = [] for output in self.infer(prompts): results.append(list(output)) return results def parse_output(self, output: str) -> Any: """Parses model output as JSON or YAML. Note: This expects raw JSON/YAML without code fences. Code fence extraction is handled by resolver.py. Args: output: Raw output string from the model. Returns: Parsed Python object (dict or list). Raises: ValueError: If output cannot be parsed as JSON or YAML. """ # Check if we have a format_type attribute (providers should set this) format_type = getattr(self, 'format_type', types.FormatType.JSON) try: if format_type == types.FormatType.JSON: return json.loads(output) else: return yaml.safe_load(output) except Exception as e: raise ValueError( f'Failed to parse output as {format_type.name}: {str(e)}' ) from e ================================================ FILE: langextract/core/data.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Classes used to represent core data types of annotation pipeline.""" from __future__ import annotations import dataclasses import enum import uuid from langextract.core import tokenizer from langextract.core import types FormatType = types.FormatType # Backward compat EXTRACTIONS_KEY = "extractions" ATTRIBUTE_SUFFIX = "_attributes" __all__ = [ "AlignmentStatus", "CharInterval", "Extraction", "Document", "AnnotatedDocument", "ExampleData", "FormatType", "EXTRACTIONS_KEY", "ATTRIBUTE_SUFFIX", ] class AlignmentStatus(enum.Enum): MATCH_EXACT = "match_exact" MATCH_GREATER = "match_greater" MATCH_LESSER = "match_lesser" MATCH_FUZZY = "match_fuzzy" @dataclasses.dataclass class CharInterval: """Class for representing a character interval. Attributes: start_pos: The starting position of the interval (inclusive). end_pos: The ending position of the interval (exclusive). """ start_pos: int | None = None end_pos: int | None = None @dataclasses.dataclass(init=False) class Extraction: """Represents an extraction extracted from text. This class encapsulates an extraction's characteristics and its position within the source text. It can represent a diverse range of information for NLP information extraction tasks. Attributes: extraction_class: The class of the extraction. extraction_text: The text of the extraction. char_interval: The character interval of the extraction in the original text. 
alignment_status: The alignment status of the extraction. extraction_index: The index of the extraction in the list of extractions. group_index: The index of the group the extraction belongs to. description: A description of the extraction. attributes: A list of attributes of the extraction. token_interval: The token interval of the extraction. """ extraction_class: str extraction_text: str char_interval: CharInterval | None = None alignment_status: AlignmentStatus | None = None extraction_index: int | None = None group_index: int | None = None description: str | None = None attributes: dict[str, str | list[str]] | None = None _token_interval: tokenizer.TokenInterval | None = dataclasses.field( default=None, repr=False, compare=False ) def __init__( self, extraction_class: str, extraction_text: str, *, token_interval: tokenizer.TokenInterval | None = None, char_interval: CharInterval | None = None, alignment_status: AlignmentStatus | None = None, extraction_index: int | None = None, group_index: int | None = None, description: str | None = None, attributes: dict[str, str | list[str]] | None = None, ): self.extraction_class = extraction_class self.extraction_text = extraction_text self.char_interval = char_interval self._token_interval = token_interval self.alignment_status = alignment_status self.extraction_index = extraction_index self.group_index = group_index self.description = description self.attributes = attributes @property def token_interval(self) -> tokenizer.TokenInterval | None: return self._token_interval @token_interval.setter def token_interval(self, value: tokenizer.TokenInterval | None) -> None: self._token_interval = value @dataclasses.dataclass class Document: """Document class for annotating documents. Attributes: text: Raw text representation for the document. document_id: Unique identifier for each document and is auto-generated if not set. additional_context: Additional context to supplement prompt instructions. 
tokenized_text: Tokenized text for the document, computed from `text`. """ text: str additional_context: str | None = None _document_id: str | None = dataclasses.field( default=None, init=False, repr=False, compare=False ) _tokenized_text: tokenizer.TokenizedText | None = dataclasses.field( init=False, default=None, repr=False, compare=False ) def __init__( self, text: str, *, document_id: str | None = None, additional_context: str | None = None, ): self.text = text self.additional_context = additional_context self._document_id = document_id @property def document_id(self) -> str: """Returns the document ID, generating a unique one if not set.""" if self._document_id is None: self._document_id = f"doc_{uuid.uuid4().hex[:8]}" return self._document_id @document_id.setter def document_id(self, value: str | None) -> None: """Sets the document ID.""" self._document_id = value @property def tokenized_text(self) -> tokenizer.TokenizedText: if self._tokenized_text is None: self._tokenized_text = tokenizer.tokenize(self.text) return self._tokenized_text @tokenized_text.setter def tokenized_text(self, value: tokenizer.TokenizedText) -> None: self._tokenized_text = value @dataclasses.dataclass class AnnotatedDocument: """Class for representing annotated documents. Attributes: document_id: Unique identifier for each document - autogenerated if not set. extractions: List of extractions in the document. text: Raw text representation of the document. tokenized_text: Tokenized text of the document, computed from `text`. 
""" extractions: list[Extraction] | None = None text: str | None = None _document_id: str | None = dataclasses.field( default=None, init=False, repr=False, compare=False ) _tokenized_text: tokenizer.TokenizedText | None = dataclasses.field( init=False, default=None, repr=False, compare=False ) def __init__( self, *, document_id: str | None = None, extractions: list[Extraction] | None = None, text: str | None = None, ): self.extractions = extractions self.text = text self._document_id = document_id @property def document_id(self) -> str: """Returns the document ID, generating a unique one if not set.""" if self._document_id is None: self._document_id = f"doc_{uuid.uuid4().hex[:8]}" return self._document_id @document_id.setter def document_id(self, value: str | None) -> None: """Sets the document ID.""" self._document_id = value @property def tokenized_text(self) -> tokenizer.TokenizedText | None: if self._tokenized_text is None and self.text is not None: self._tokenized_text = tokenizer.tokenize(self.text) return self._tokenized_text @tokenized_text.setter def tokenized_text(self, value: tokenizer.TokenizedText) -> None: self._tokenized_text = value @dataclasses.dataclass class ExampleData: """A single training/example data instance for a structured prompting. Attributes: text: The raw input text (sentence, paragraph, etc.). extractions: A list of Extraction objects extracted from the text. """ text: str extractions: list[Extraction] = dataclasses.field(default_factory=list) ================================================ FILE: langextract/core/debug_utils.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Debug utilities for LangExtract.""" from __future__ import annotations import functools import inspect import logging import reprlib import time from typing import Any, Callable, Mapping from absl import logging as absl_logging _LOG = logging.getLogger("langextract.debug") # Add NullHandler to prevent "No handler found" warnings _langextract_logger = logging.getLogger("langextract") if not _langextract_logger.handlers: _langextract_logger.addHandler(logging.NullHandler()) # Sensitive keys to redact _REDACT_KEYS = { "api_key", "apikey", "token", "secret", "password", "authorization", "bearer", "jwt", } _MAX_STR = 500 _MAX_SEQ = 20 def _safe_repr(obj: Any) -> str: """Truncate object repr for safe logging.""" r = reprlib.Repr() r.maxstring = _MAX_STR r.maxlist = r.maxtuple = r.maxset = r.maxdict = _MAX_SEQ return r.repr(obj) def _redact_value(name: str, value: Any) -> str: """Redact sensitive values based on parameter name.""" if isinstance(name, str) and name.lower() in _REDACT_KEYS: return "" # If a nested mapping, redact its sensitive keys too if isinstance(value, Mapping): redacted = {} for k, v in value.items(): if isinstance(k, str) and k.lower() in _REDACT_KEYS: redacted[k] = "" else: redacted[k] = _safe_repr(v) return _safe_repr(redacted) return _safe_repr(value) def _redact_mapping(mapping: Mapping[str, Any]) -> dict[str, str]: """Replace sensitive values with .""" out = {} for k, v in mapping.items(): out[k] = _redact_value(k, v) return out def _format_bound_args( fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> str: """Format function 
arguments using signature inspection.""" try: sig = inspect.signature(fn) bound = sig.bind_partial(*args, **kwargs) bound.apply_defaults() except Exception: # Fallback (no names) if binding fails parts = [_safe_repr(a) for a in args] if kwargs: red = _redact_mapping(kwargs) parts += [f"{k}={v}" for k, v in sorted(red.items())] return ", ".join(parts) parts: list[str] = [] for name, value in bound.arguments.items(): if name in ("self", "cls"): parts.append(f"{name}=<{type(value).__name__}>") else: parts.append(f"{name}={_redact_value(name, value)}") return ", ".join(parts) def debug_log_calls(fn: Callable) -> Callable: """Log function calls with redacted sensitive data and timing. Automatically redacts api_key, token, etc. and truncates large outputs. """ @functools.wraps(fn) def wrapper(*args, **kwargs): logger = _LOG if not logger.isEnabledFor(logging.DEBUG): return fn(*args, **kwargs) fn_qual = getattr(fn, "__qualname__", fn.__name__) mod = getattr(fn, "__module__", "") # Format arguments using signature inspection arg_str = _format_bound_args(fn, args, kwargs) logger.debug("[%s] CALL: %s(%s)", mod, fn_qual, arg_str, stacklevel=2) start = time.perf_counter() try: result = fn(*args, **kwargs) except Exception: dur_ms = (time.perf_counter() - start) * 1000 logger.exception( "[%s] EXCEPTION: %s (%.1f ms)", mod, fn_qual, dur_ms, stacklevel=2 ) raise dur_ms = (time.perf_counter() - start) * 1000 result_repr = _safe_repr(result) logger.debug( "[%s] RETURN: %s -> %s (%.1f ms)", mod, fn_qual, result_repr, dur_ms, stacklevel=2, ) return result return wrapper def configure_debug_logging() -> None: """Enable debug logging for the 'langextract' namespace only.""" logger = logging.getLogger("langextract") # Skip if we already added our handler our_handler_exists = any( isinstance(h, logging.StreamHandler) and getattr(h, "langextract_debug", False) for h in logger.handlers ) if our_handler_exists: return # Respect host handlers - only set level if they exist non_null_handlers 
= [ h for h in logger.handlers if not isinstance(h, logging.NullHandler) ] if non_null_handlers: logger.setLevel(logging.DEBUG) else: logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() handler.setLevel(logging.DEBUG) fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" handler.setFormatter(logging.Formatter(fmt)) handler.langextract_debug = True logger.addHandler(handler) logger.propagate = False # Best-effort absl configuration try: absl_logging.set_verbosity(absl_logging.DEBUG) except Exception: pass ================================================ FILE: langextract/core/exceptions.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Core error types for LangExtract. This module defines all base exceptions for LangExtract. These are the foundational error types that are used throughout the codebase. """ from __future__ import annotations __all__ = [ "LangExtractError", "InferenceError", "InferenceConfigError", "InferenceRuntimeError", "InferenceOutputError", "InternalError", "InvalidDocumentError", "ProviderError", "SchemaError", "FormatError", "FormatParseError", ] class LangExtractError(Exception): """Base exception for all LangExtract errors. All exceptions raised by LangExtract should inherit from this class. This allows users to catch all LangExtract-specific errors with a single except clause. 
""" class InferenceError(LangExtractError): """Base exception for inference-related errors.""" class InferenceConfigError(InferenceError): """Exception raised for configuration errors. This includes missing API keys, invalid model IDs, or other configuration-related issues that prevent model instantiation. """ class InferenceRuntimeError(InferenceError): """Exception raised for runtime inference errors. This includes API call failures, network errors, or other issues that occur during inference execution. """ def __init__( self, message: str, *, original: BaseException | None = None, provider: str | None = None, ) -> None: """Initialize the runtime error. Args: message: Error message. original: Original exception from the provider SDK. provider: Name of the provider that raised the error. """ super().__init__(message) self.original = original self.provider = provider class InferenceOutputError(LangExtractError): """Exception raised when no scored outputs are available from the language model.""" def __init__(self, message: str): self.message = message super().__init__(self.message) class InvalidDocumentError(LangExtractError): """Exception raised when document input is invalid. This includes cases like duplicate document IDs or malformed documents. """ class InternalError(LangExtractError): """Exception raised for internal invariant violations. This indicates a bug in LangExtract itself rather than user error. """ class ProviderError(LangExtractError): """Provider/backend specific error.""" class SchemaError(LangExtractError): """Schema validation/serialization error.""" class FormatError(LangExtractError): """Base exception for format handling errors.""" class FormatParseError(FormatError): """Raised when format parsing fails. 
This consolidates all parsing errors including: - Missing fence markers when required - Multiple fenced blocks - JSON/YAML decode errors - Missing wrapper keys - Invalid structure """ ================================================ FILE: langextract/core/format_handler.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Centralized format handler for prompts and parsing.""" from __future__ import annotations import json import re from typing import Mapping, Sequence import warnings import yaml from langextract.core import data from langextract.core import exceptions ExtractionValueType = str | int | float | dict | list | None _JSON_FORMAT = "json" _YAML_FORMAT = "yaml" _YML_FORMAT = "yml" _FENCE_START = r"```" _LANGUAGE_TAG = r"(?P[A-Za-z0-9_+-]+)?" _FENCE_NEWLINE = r"(?:\s*\n)?" _FENCE_BODY = r"(?P[\s\S]*?)" _FENCE_END = r"```" _FENCE_RE = re.compile( _FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END, re.MULTILINE, ) _THINK_TAG_RE = re.compile(r"[\s\S]*?\s*", re.IGNORECASE) class FormatHandler: """Handles all format-specific logic for prompts and parsing. This class centralizes format handling for JSON and YAML outputs, including fence detection, wrapper management, and parsing. Attributes: format_type: The output format ('json' or 'yaml'). use_wrapper: Whether to wrap extractions in a container dictionary. 
wrapper_key: The key name for the container dictionary (e.g., creates {"extractions": [...]} instead of just [...]). use_fences: Whether to use code fences in formatted output. attribute_suffix: Suffix for attribute fields in extractions. strict_fences: Whether to enforce strict fence validation. allow_top_level_list: Whether to allow top-level lists in parsing. """ def __init__( self, format_type: data.FormatType = data.FormatType.JSON, use_wrapper: bool = True, wrapper_key: str | None = None, use_fences: bool = True, attribute_suffix: str = data.ATTRIBUTE_SUFFIX, strict_fences: bool = False, allow_top_level_list: bool = True, ) -> None: """Initialize format handler. Args: format_type: Output format type enum. use_wrapper: Whether to wrap extractions in a container dictionary. True: {"extractions": [...]}, False: [...] wrapper_key: Key name for the container dictionary. When use_wrapper=True: - If None: defaults to EXTRACTIONS_KEY ("extractions") - If provided: uses the specified key as container When use_wrapper=False, this parameter is ignored. use_fences: Whether to use ```json or ```yaml fences. attribute_suffix: Suffix for attribute fields. strict_fences: If True, require exact fence format. If False, be lenient with model output variations. allow_top_level_list: Allow top-level list when not strict and wrapper not required. 
""" self.format_type = format_type self.use_wrapper = use_wrapper if use_wrapper: self.wrapper_key = ( wrapper_key if wrapper_key is not None else data.EXTRACTIONS_KEY ) else: self.wrapper_key = None self.use_fences = use_fences self.attribute_suffix = attribute_suffix self.strict_fences = strict_fences self.allow_top_level_list = allow_top_level_list def __repr__(self) -> str: return ( "FormatHandler(" f"format_type={self.format_type!r}, use_wrapper={self.use_wrapper}, " f"wrapper_key={self.wrapper_key!r}, use_fences={self.use_fences}, " f"attribute_suffix={self.attribute_suffix!r}, " f"strict_fences={self.strict_fences}, " f"allow_top_level_list={self.allow_top_level_list})" ) def format_extraction_example( self, extractions: list[data.Extraction] ) -> str: """Format extractions for a prompt example. Args: extractions: List of extractions to format Returns: Formatted string for the prompt """ items = [ { ext.extraction_class: ext.extraction_text, f"{ext.extraction_class}{self.attribute_suffix}": ( ext.attributes or {} ), } for ext in extractions ] if self.use_wrapper and self.wrapper_key: payload = {self.wrapper_key: items} else: payload = items if self.format_type == data.FormatType.YAML: formatted = yaml.safe_dump( payload, default_flow_style=False, sort_keys=False ) else: formatted = json.dumps(payload, indent=2, ensure_ascii=False) return self._add_fences(formatted) if self.use_fences else formatted def parse_output( self, text: str, *, strict: bool | None = None ) -> Sequence[Mapping[str, ExtractionValueType]]: """Parse model output to extract data. Args: text: Raw model output. strict: If True, enforce strict schema validation. When strict is True, always require wrapper object if wrapper_key is configured, reject top-level lists even if allow_top_level_list is True, and enforce exact format compliance. Returns: List of extraction dictionaries. Raises: FormatError: Various subclasses for specific parsing failures. 
""" if not text: raise exceptions.FormatParseError("Empty or invalid input string.") content = self._extract_content(text) try: parsed = self._parse_with_fallback(content, strict) except (yaml.YAMLError, json.JSONDecodeError) as e: msg = ( f"Failed to parse {self.format_type.value.upper()} content:" f" {str(e)[:200]}" ) raise exceptions.FormatParseError(msg) from e if parsed is None: if self.use_wrapper: raise exceptions.FormatParseError( f"Content must be a mapping with an '{self.wrapper_key}' key." ) else: raise exceptions.FormatParseError( "Content must be a list of extractions or a dict." ) require_wrapper = self.wrapper_key is not None and ( self.use_wrapper or bool(strict) ) if isinstance(parsed, dict): if require_wrapper: if self.wrapper_key not in parsed: raise exceptions.FormatParseError( f"Content must contain an '{self.wrapper_key}' key." ) items = parsed[self.wrapper_key] else: if data.EXTRACTIONS_KEY in parsed: items = parsed[data.EXTRACTIONS_KEY] elif self.wrapper_key and self.wrapper_key in parsed: items = parsed[self.wrapper_key] else: items = [parsed] elif isinstance(parsed, list): if require_wrapper and (strict or not self.allow_top_level_list): raise exceptions.FormatParseError( f"Content must be a mapping with an '{self.wrapper_key}' key." ) if strict and self.use_wrapper: raise exceptions.FormatParseError( "Strict mode requires a wrapper object." ) if not self.allow_top_level_list: raise exceptions.FormatParseError("Top-level list is not allowed.") # Some models return [...] instead of {"extractions": [...]}. items = parsed else: raise exceptions.FormatParseError( f"Expected list or dict, got {type(parsed)}" ) if not isinstance(items, list): raise exceptions.FormatParseError( "The extractions must be a sequence (list) of mappings." ) for item in items: if not isinstance(item, dict): raise exceptions.FormatParseError( "Each item in the sequence must be a mapping." 
) for k in item.keys(): if not isinstance(k, str): raise exceptions.FormatParseError( "All extraction keys must be strings (got a non-string key)." ) return items def _add_fences(self, content: str) -> str: """Add code fences around content.""" fence_type = self.format_type.value return f"```{fence_type}\n{content.strip()}\n```" def _is_valid_language_tag( self, lang: str | None, valid_tags: dict[data.FormatType, set[str]] ) -> bool: """Check if language tag is valid for the format type.""" if lang is None: return True tag = lang.strip().lower() return tag in valid_tags.get(self.format_type, set()) def _parse_with_fallback(self, content: str, strict: bool): """Parse content, retrying without tags on failure.""" try: if self.format_type == data.FormatType.YAML: return yaml.safe_load(content) return json.loads(content) except (yaml.YAMLError, json.JSONDecodeError): if strict: raise # Reasoning models (DeepSeek-R1, QwQ) emit tags before JSON. if _THINK_TAG_RE.search(content): stripped = _THINK_TAG_RE.sub("", content).strip() if self.format_type == data.FormatType.YAML: return yaml.safe_load(stripped) return json.loads(stripped) raise def _extract_content(self, text: str) -> str: """Extract content from text, handling fences if configured. Args: text: Input text that may contain fenced blocks Returns: Extracted content Raises: FormatParseError: When fences required but not found or multiple blocks found. """ if not self.use_fences: return text.strip() matches = list(_FENCE_RE.finditer(text)) valid_tags = { data.FormatType.YAML: {_YAML_FORMAT, _YML_FORMAT}, data.FormatType.JSON: {_JSON_FORMAT}, } candidates = [ m for m in matches if self._is_valid_language_tag(m.group("lang"), valid_tags) ] if self.strict_fences: if len(candidates) != 1: if len(candidates) == 0: raise exceptions.FormatParseError( "Input string does not contain valid fence markers." ) else: raise exceptions.FormatParseError( "Multiple fenced blocks found. Expected exactly one." 
) return candidates[0].group("body").strip() if len(candidates) == 1: return candidates[0].group("body").strip() elif len(candidates) > 1: raise exceptions.FormatParseError( "Multiple fenced blocks found. Expected exactly one." ) if matches: if not self.strict_fences and len(matches) == 1: return matches[0].group("body").strip() raise exceptions.FormatParseError( f"No {self.format_type.value} code block found." ) return text.strip() # ---- Backward compatibility methods (to be removed in v2.0.0) ---- _LEGACY_FORMAT_KEYS = frozenset({ "fence_output", "format_type", "strict_fences", "require_extractions_key", "extraction_attributes_suffix", "attribute_suffix", "format_handler", }) @classmethod def from_resolver_params( cls, *, resolver_params: dict | None, base_format_type: data.FormatType, base_use_fences: bool, base_attribute_suffix: str = data.ATTRIBUTE_SUFFIX, base_use_wrapper: bool = True, base_wrapper_key: str | None = data.EXTRACTIONS_KEY, warn_on_legacy: bool = True, ) -> tuple[FormatHandler, dict]: """Create FormatHandler from resolver_params with legacy support. This method handles backward compatibility for legacy resolver parameters and will be removed in v2.0.0. Args: resolver_params: May contain legacy keys or a 'format_handler'. base_format_type: Default format when not overridden. base_use_fences: Default fence usage from the model. base_attribute_suffix: Default attribute suffix. base_use_wrapper: Default wrapper behavior. base_wrapper_key: Default wrapper key. warn_on_legacy: If True, emit DeprecationWarnings. 
Returns: (format_handler, remaining_resolver_params) """ rp = dict(resolver_params or {}) if rp.get("format_handler") is not None: handler = rp.pop("format_handler") for k in list(rp.keys()): if k in cls._LEGACY_FORMAT_KEYS: rp.pop(k, None) return handler, rp kwargs = { "format_type": base_format_type, "use_fences": base_use_fences, "attribute_suffix": base_attribute_suffix, "use_wrapper": base_use_wrapper, "wrapper_key": base_wrapper_key if base_use_wrapper else None, } mapping = { "fence_output": "use_fences", "format_type": "format_type", "strict_fences": "strict_fences", "require_extractions_key": "use_wrapper", "extraction_attributes_suffix": "attribute_suffix", "attribute_suffix": "attribute_suffix", } used_legacy = [] for legacy_key, fh_key in mapping.items(): if legacy_key in rp and rp[legacy_key] is not None: val = rp.pop(legacy_key) if fh_key == "format_type" and hasattr(val, "value"): val = val.value kwargs[fh_key] = val used_legacy.append(legacy_key) if warn_on_legacy and used_legacy: warnings.warn( "Resolver legacy params are deprecated and will be removed in" f" v2.0.0: {used_legacy}. Pass a FormatHandler explicitly via" " `resolver_params={'format_handler': FormatHandler(...)}` or rely" " on defaults configured by the model.", DeprecationWarning, stacklevel=3, ) handler = cls(**kwargs) return handler, rp @classmethod def from_kwargs(cls, **kwargs) -> FormatHandler: """Create FormatHandler from legacy resolver keyword arguments. This method will be removed in v2.0.0. Args: **kwargs: Legacy parameters like fence_output, format_type, etc. Returns: FormatHandler configured with legacy parameters. """ legacy_params = { "fence_output", "format_type", "strict_fences", "require_extractions_key", } used_legacy = legacy_params.intersection(kwargs.keys()) if used_legacy: warnings.warn( f"Using legacy Resolver parameters {used_legacy} is deprecated. " "Please use FormatHandler directly. 
" "This compatibility layer will be removed in v2.0.0.", DeprecationWarning, stacklevel=3, ) fence_output = kwargs.pop("fence_output", True) format_type = kwargs.pop("format_type", None) strict_fences = kwargs.pop("strict_fences", False) require_extractions_key = kwargs.pop("require_extractions_key", True) attribute_suffix = kwargs.pop("attribute_suffix", data.ATTRIBUTE_SUFFIX) if format_type is None: format_type = data.FormatType.JSON elif hasattr(format_type, "value"): pass else: format_type = ( data.FormatType.JSON if str(format_type).lower() == "json" else data.FormatType.YAML ) return cls( format_type=format_type, use_wrapper=require_extractions_key, wrapper_key=data.EXTRACTIONS_KEY if require_extractions_key else None, use_fences=fence_output, strict_fences=strict_fences, attribute_suffix=attribute_suffix, ) ================================================ FILE: langextract/core/schema.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Core schema abstractions for LangExtract.""" from __future__ import annotations import abc from collections.abc import Sequence from typing import Any from langextract.core import data from langextract.core import format_handler as fh from langextract.core import types __all__ = [ "ConstraintType", "Constraint", "BaseSchema", "FormatModeSchema", ] # Backward compat re-exports ConstraintType = types.ConstraintType Constraint = types.Constraint class BaseSchema(abc.ABC): """Abstract base class for generating structured constraints from examples.""" @classmethod @abc.abstractmethod def from_examples( cls, examples_data: Sequence[data.ExampleData], attribute_suffix: str = data.ATTRIBUTE_SUFFIX, ) -> BaseSchema: """Factory method to build a schema instance from example data.""" @abc.abstractmethod def to_provider_config(self) -> dict[str, Any]: """Convert schema to provider-specific configuration. Returns: Dictionary of provider kwargs (e.g., response_schema for Gemini). Should be a pure data mapping with no side effects. """ @property @abc.abstractmethod def requires_raw_output(self) -> bool: """Whether this schema outputs raw JSON/YAML without fence markers. When True, the provider emits syntactically valid JSON directly. When False, the provider needs fence markers for structure. """ def validate_format(self, format_handler: fh.FormatHandler) -> None: """Validate format compatibility and warn about issues. Override in subclasses to check format settings. Default implementation does nothing (no validation needed). Args: format_handler: The format configuration to validate. """ def sync_with_provider_kwargs(self, kwargs: dict[str, Any]) -> None: """Hook to update schema state based on provider kwargs. This allows schemas to adjust their behavior based on caller overrides. For example, FormatModeSchema uses this to sync its format when the caller overrides it, ensuring requires_raw_output stays accurate. Default implementation does nothing. 
Override if your schema needs to respond to provider kwargs. Args: kwargs: The effective provider kwargs after merging. """ class FormatModeSchema(BaseSchema): """Generic schema for providers that support format modes (JSON/YAML). This schema doesn't enforce structure, only output format. Useful for providers that can guarantee syntactically valid JSON or YAML but don't support field-level constraints. """ def __init__(self, format_type: types.FormatType = types.FormatType.JSON): """Initialize with a format type.""" self.format_type = format_type # Keep _format for backward compatibility with tests self._format = "json" if format_type == types.FormatType.JSON else "yaml" @classmethod def from_examples( cls, examples_data: Sequence[data.ExampleData], attribute_suffix: str = data.ATTRIBUTE_SUFFIX, ) -> FormatModeSchema: """Factory method to build a schema instance from example data.""" # Default to JSON format return cls(format_type=types.FormatType.JSON) def to_provider_config(self) -> dict[str, Any]: """Convert schema to provider-specific configuration.""" return {"format": self._format} @property def requires_raw_output(self) -> bool: """JSON format schemas output raw JSON without fences, YAML does not.""" return self._format == "json" def sync_with_provider_kwargs(self, kwargs: dict[str, Any]) -> None: """Sync format type with provider kwargs.""" if "format_type" in kwargs: self.format_type = kwargs["format_type"] self._format = ( "json" if self.format_type == types.FormatType.JSON else "yaml" ) if "format" in kwargs: self._format = kwargs["format"] self.format_type = ( types.FormatType.JSON if self._format == "json" else types.FormatType.YAML ) ================================================ FILE: langextract/core/tokenizer.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization utilities for text.

Provides methods to split text into regex-based or Unicode-aware tokens.
Tokenization is used for alignment in `resolver.py` and for determining
sentence boundaries for smaller context use cases. This module is not used
for tokenization within the language model during inference.
"""

import abc
from collections.abc import Sequence, Set
import dataclasses
import enum
import functools
import unicodedata

import regex

from langextract.core import debug_utils
from langextract.core import exceptions

__all__ = [
    "BaseTokenizerError",
    "InvalidTokenIntervalError",
    "SentenceRangeError",
    "CharInterval",
    "TokenInterval",
    "TokenType",
    "Token",
    "TokenizedText",
    "Tokenizer",
    "RegexTokenizer",
    "UnicodeTokenizer",
    "tokenize",
    "tokens_text",
    "find_sentence_range",
]


class BaseTokenizerError(exceptions.LangExtractError):
  """Root of the tokenizer error hierarchy."""


class InvalidTokenIntervalError(BaseTokenizerError):
  """Raised for a token interval that is malformed or out of bounds."""


class SentenceRangeError(BaseTokenizerError):
  """Raised when a sentence's starting token index is out of bounds."""


@dataclasses.dataclass(slots=True)
class CharInterval:
  """Half-open span [start_pos, end_pos) of character offsets in a text.

  Attributes:
    start_pos: First character offset covered (inclusive).
    end_pos: Offset just past the last covered character (exclusive).
  """

  start_pos: int
  end_pos: int


@dataclasses.dataclass(slots=True)
class TokenInterval:
  """Half-open span [start_index, end_index) of token positions.

  Attributes:
    start_index: Position of the first token in the span (inclusive).
    end_index: Position just past the last token in the span (exclusive).
  """

  start_index: int = 0
  end_index: int = 0


class TokenType(enum.IntEnum):
  """Category a tokenizer assigns to each token.

  Attributes:
    WORD: An alphabetic word.
    NUMBER: A run of digits.
    PUNCTUATION: Punctuation or other symbol characters.
  """

  WORD = 0
  NUMBER = 1
  PUNCTUATION = 2


@dataclasses.dataclass(slots=True)
class Token:
  """One token, its category, and the slice of source text it came from.

  Attributes:
    index: Zero-based position of this token in the token sequence.
    token_type: Category of the token (see TokenType).
    char_interval: Span of the original text covered by this token; defaults
      to the empty interval (0, 0).
    first_token_after_newline: True when a newline or carriage return occurs
      between the previous token and this one.
  """

  index: int
  token_type: TokenType
  char_interval: CharInterval = dataclasses.field(
      default_factory=functools.partial(CharInterval, 0, 0)
  )
  first_token_after_newline: bool = False


@dataclasses.dataclass
class TokenizedText:
  """A tokenizer's output: the input string together with its tokens.

  Attributes:
    text: The exact string that was tokenized. UnicodeTokenizer does NOT
      NFC-normalize it, so token offsets index directly into this string.
    tokens: The extracted Token objects, in order.
  """

  text: str
  tokens: list[Token] = dataclasses.field(default_factory=list)


# Letters only: word characters excluding digits and underscore.
_LETTERS_PATTERN = r"[^\W\d_]+"
_DIGITS_PATTERN = r"\d+"
# Group identical symbols (e.g. "!!") but split mixed ones.
_SYMBOLS_PATTERN = r"([^\w\s]|_)\1*" _END_OF_SENTENCE_PATTERN = regex.compile(r"[.?!。!?\u0964][\"'”’»)\]}]*$") _TOKEN_PATTERN = regex.compile( rf"{_LETTERS_PATTERN}|{_DIGITS_PATTERN}|{_SYMBOLS_PATTERN}" ) _WORD_PATTERN = regex.compile(rf"(?:{_LETTERS_PATTERN}|{_DIGITS_PATTERN})\Z") # Abbreviations that do not end sentences. # TODO: Evaluate removal for large-context use cases. _KNOWN_ABBREVIATIONS = frozenset({"Mr.", "Mrs.", "Ms.", "Dr.", "Prof.", "St."}) _CLOSING_PUNCTUATION = frozenset({'"', "'", "”", "’", "»", ")", "]", "}"}) class Tokenizer(abc.ABC): """Abstract base class for tokenizers.""" @abc.abstractmethod def tokenize(self, text: str) -> TokenizedText: """Splits text into tokens. Args: text: The text to tokenize. Returns: A TokenizedText object. """ class RegexTokenizer(Tokenizer): """Regex-based tokenizer (default). The RegexTokenizer is faster than UnicodeTokenizer for English text because it skips involved Unicode handling. """ @debug_utils.debug_log_calls def tokenize(self, text: str) -> TokenizedText: """Splits text into tokens (words, digits, or punctuation). Each token is annotated with its character position and type. Tokens following a newline or carriage return have `first_token_after_newline` set to True. Args: text: The text to tokenize. Returns: A TokenizedText object containing all extracted tokens. """ tokenized = TokenizedText(text=text) previous_end = 0 for token_index, match in enumerate(_TOKEN_PATTERN.finditer(text)): start_pos, end_pos = match.span() matched_text = match.group() token = Token( index=token_index, char_interval=CharInterval(start_pos=start_pos, end_pos=end_pos), token_type=TokenType.WORD, first_token_after_newline=False, ) if token_index > 0: # Optimization: Check gap without slicing. 
has_newline = text.find("\n", previous_end, start_pos) != -1 if not has_newline: has_newline = text.find("\r", previous_end, start_pos) != -1 if has_newline: token.first_token_after_newline = True if regex.fullmatch(_DIGITS_PATTERN, matched_text): token.token_type = TokenType.NUMBER elif _WORD_PATTERN.fullmatch(matched_text): token.token_type = TokenType.WORD else: token.token_type = TokenType.PUNCTUATION tokenized.tokens.append(token) previous_end = end_pos return tokenized # Default tokenizer instance for backward compatibility _DEFAULT_TOKENIZER = RegexTokenizer() def tokenize( text: str, tokenizer: Tokenizer = _DEFAULT_TOKENIZER ) -> TokenizedText: """Splits text into tokens using the provided tokenizer (default: RegexTokenizer). Args: text: The text to tokenize. tokenizer: The tokenizer instance to use. Returns: A TokenizedText object. """ return tokenizer.tokenize(text) _CJK_PATTERN = regex.compile( r"\p{Is_Han}|\p{Is_Hiragana}|\p{Is_Katakana}|\p{Is_Hangul}" ) _NON_SPACED_PATTERN = regex.compile( r"\p{Is_Thai}|\p{Is_Lao}|\p{Is_Khmer}|\p{Is_Myanmar}" ) class Sentinel: """Sentinel class for unique object identification.""" def __init__(self, name: str): self.name = name def __repr__(self) -> str: return f"<{self.name}>" _NO_GROUP_SCRIPT = Sentinel("NO_GROUP") _UNKNOWN_SCRIPT = Sentinel("UNKNOWN") _LATIN_SCRIPT = "Latin" # Optimization: Direct mapping for common scripts avoids regex overhead. def _get_script_fast(char: str) -> str | Sentinel: # Fast path for ASCII: Avoids regex and unicodedata lookups. 
if ord(char) < 128: return _LATIN_SCRIPT # Fallback to the robust regex method return _get_common_script_cached(char) def _classify_grapheme(g: str) -> TokenType: if not g: return TokenType.PUNCTUATION c = g[0] cat = unicodedata.category(c) if cat.startswith("L"): return TokenType.WORD if cat.startswith("N"): return TokenType.NUMBER return TokenType.PUNCTUATION _COMMON_SCRIPTS = [ "Latin", "Cyrillic", "Greek", "Arabic", "Hebrew", "Devanagari", ] _COMMON_SCRIPTS_PATTERN = regex.compile( "|".join( rf"(?P<{script}>\p{{Script={script}}})" for script in _COMMON_SCRIPTS ) ) _GRAPHEME_CLUSTER_PATTERN = regex.compile(r"\X") @functools.lru_cache(maxsize=4096) def _get_common_script_cached(c: str) -> str | Sentinel: """Determines script using regex, cached for performance.""" match = _COMMON_SCRIPTS_PATTERN.match(c) if match: return match.lastgroup return _UNKNOWN_SCRIPT class UnicodeTokenizer(Tokenizer): """Unicode-aware tokenizer for better non-English support. This tokenizer uses Unicode character properties (Unicode Standard Annex #29) via the `regex` library's `\\X` pattern to correctly handle grapheme clusters like Emojis and Hangul. Unlike some Unicode tokenizers, this class does NOT normalize text to NFC. This ensures that token indices exactly match the original input string. Note: Grapheme clustering makes this tokenizer slower than RegexTokenizer. """ @debug_utils.debug_log_calls def tokenize(self, text: str) -> TokenizedText: """Splits text into tokens using Unicode properties. Args: text: The text to tokenize. Returns: A TokenizedText object. """ tokens: list[Token] = [] current_start = 0 current_type = None current_script = None previous_end = 0 for match in regex.finditer(r"\X", text): grapheme = match.group() start, _ = match.span() # 1. 
Handle Whitespace if grapheme.isspace(): if current_type is not None: self._emit_token( tokens, text, current_start, start, current_type, previous_end ) previous_end = start current_type = None current_script = None # Keep `previous_end` to detect newlines within the whitespace gap. continue g_type = _classify_grapheme(grapheme) # 2. Determine if we should merge with the current token should_merge = False if current_type is not None: if current_type == g_type: if current_type == TokenType.WORD: # Script Check first_char = grapheme[0] # Fast path: Explicit NO_GROUP (CJK/Thai) never merges. if current_script is _NO_GROUP_SCRIPT: should_merge = False # CJK and Non-Spaced scripts require fragmentation. elif _CJK_PATTERN.match(first_char) or _NON_SPACED_PATTERN.match( first_char ): should_merge = False else: g_script = _get_script_fast(first_char) # Safety: Do not merge distinct unknown scripts. if ( current_script == g_script and current_script is not _UNKNOWN_SCRIPT ): should_merge = True elif current_type == TokenType.NUMBER: should_merge = True elif current_type == TokenType.PUNCTUATION: # Heuristic: Merge punctuation only if identical (e.g. "!!"). last_grapheme = text[current_start:start] if last_grapheme == grapheme: should_merge = True elif len(last_grapheme) >= len(grapheme) and last_grapheme.endswith( grapheme ): should_merge = True # 3. State Transition if should_merge: # Extend current token pass else: # Flush previous token if exists if current_type is not None: self._emit_token( tokens, text, current_start, start, current_type, previous_end ) previous_end = start # Start new token current_start = start current_type = g_type # Determine script for the new token if current_type == TokenType.WORD: c = grapheme[0] if _CJK_PATTERN.match(c) or _NON_SPACED_PATTERN.match(c): current_script = _NO_GROUP_SCRIPT else: current_script = _get_script_fast(c) else: current_script = None # 4. 
Flush final token if current_type is not None: self._emit_token( tokens, text, current_start, len(text), current_type, previous_end ) return TokenizedText(text=text, tokens=tokens) def _emit_token( self, tokens: list[Token], text: str, start: int, end: int, token_type: TokenType, previous_end: int, ): """Helper to create and append a token.""" token = Token( index=len(tokens), char_interval=CharInterval(start_pos=start, end_pos=end), token_type=token_type, first_token_after_newline=False, ) # Check for newlines in the gap between the previous token and this one if start > previous_end: gap = text[previous_end:start] if "\n" in gap or "\r" in gap: token.first_token_after_newline = True tokens.append(token) def tokens_text( tokenized_text: TokenizedText, token_interval: TokenInterval, ) -> str: """Reconstructs the substring of the original text spanning a given token interval. Args: tokenized_text: A TokenizedText object containing token data. token_interval: The interval specifying the range [start_index, end_index) of tokens. Returns: The exact substring of the original text corresponding to the token interval. Raises: InvalidTokenIntervalError: If the token_interval is invalid or out of range. """ if token_interval.start_index == token_interval.end_index: return "" if ( token_interval.start_index < 0 or token_interval.end_index > len(tokenized_text.tokens) or token_interval.start_index > token_interval.end_index ): raise InvalidTokenIntervalError( f"Invalid token interval. start_index={token_interval.start_index}, " f"end_index={token_interval.end_index}, " f"total_tokens={len(tokenized_text.tokens)}." 
) start_token = tokenized_text.tokens[token_interval.start_index] end_token = tokenized_text.tokens[token_interval.end_index - 1] return tokenized_text.text[ start_token.char_interval.start_pos : end_token.char_interval.end_pos ] def _is_end_of_sentence_token( text: str, tokens: Sequence[Token], current_idx: int, known_abbreviations: Set[str] = _KNOWN_ABBREVIATIONS, ) -> bool: """Checks if the punctuation token at `current_idx` ends a sentence. A token is considered a sentence terminator and is not part of a known abbreviation. Only searches the text corresponding to the current token. Args: text: The entire input text. tokens: The sequence of Token objects. current_idx: The current token index to check. known_abbreviations: Abbreviations that should not count as sentence enders (e.g., "Dr."). Returns: True if the token at `current_idx` ends a sentence, otherwise False. """ current_token_text = text[ tokens[current_idx] .char_interval.start_pos : tokens[current_idx] .char_interval.end_pos ] if _END_OF_SENTENCE_PATTERN.search(current_token_text): if current_idx > 0: prev_token_text = text[ tokens[current_idx - 1] .char_interval.start_pos : tokens[current_idx - 1] .char_interval.end_pos ] if f"{prev_token_text}{current_token_text}" in known_abbreviations: return False return True return False def _is_sentence_break_after_newline( text: str, tokens: Sequence[Token], current_idx: int, ) -> bool: """Checks if the next token starts uppercase and follows a newline. Args: text: The entire input text. tokens: The sequence of Token objects. current_idx: The current token index. Returns: True if a newline is found between current_idx and current_idx+1, and the next token (if any) begins with an uppercase character. 
""" if current_idx + 1 >= len(tokens): return False next_token = tokens[current_idx + 1] if not next_token.first_token_after_newline: return False next_token_text = text[ next_token.char_interval.start_pos : next_token.char_interval.end_pos ] # Assume break unless lowercase (covers numbers/quotes). return bool(next_token_text) and not next_token_text[0].islower() def find_sentence_range( text: str, tokens: Sequence[Token], start_token_index: int, known_abbreviations: Set[str] = _KNOWN_ABBREVIATIONS, ) -> TokenInterval: """Finds a 'sentence' interval from a given start index. Sentence boundaries are defined by: - punctuation tokens in _END_OF_SENTENCE_PATTERN - newline breaks followed by an uppercase letter - not abbreviations in _KNOWN_ABBREVIATIONS (e.g., "Dr.") This favors terminating a sentence prematurely over missing a sentence boundary, and will terminate a sentence early if the first line ends with new line and the second line begins with a capital letter. Args: text: The text to analyze. tokens: The tokens that make up `text`. Note: For UnicodeTokenizer, use normalized text. start_token_index: The index of the token to start the sentence from. known_abbreviations: A set of strings that are known abbreviations and should not be treated as sentence boundaries. Returns: A TokenInterval representing the sentence range [start_token_index, end). If no sentence boundary is found, the end index will be the length of `tokens`. Raises: SentenceRangeError: If `start_token_index` is out of range. """ if not tokens: return TokenInterval(0, 0) if start_token_index < 0 or start_token_index >= len(tokens): raise SentenceRangeError( f"start_token_index={start_token_index} out of range. " f"Total tokens: {len(tokens)}." ) i = start_token_index while i < len(tokens): if tokens[i].token_type == TokenType.PUNCTUATION: if _is_end_of_sentence_token(text, tokens, i, known_abbreviations): end_index = i + 1 # Consume any trailing closing punctuation (e.g. 
quotes, parens) while end_index < len(tokens): next_token_text = text[ tokens[end_index] .char_interval.start_pos : tokens[end_index] .char_interval.end_pos ] if ( tokens[end_index].token_type == TokenType.PUNCTUATION and next_token_text in _CLOSING_PUNCTUATION ): end_index += 1 else: break return TokenInterval(start_index=start_token_index, end_index=end_index) if _is_sentence_break_after_newline(text, tokens, i): return TokenInterval(start_index=start_token_index, end_index=i + 1) i += 1 return TokenInterval(start_index=start_token_index, end_index=len(tokens)) ================================================ FILE: langextract/core/types.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Core data types for LangExtract.""" from __future__ import annotations import dataclasses import enum import textwrap __all__ = [ 'ScoredOutput', 'FormatType', 'ConstraintType', 'Constraint', ] class FormatType(enum.Enum): """Enumeration of prompt output formats.""" YAML = 'yaml' JSON = 'json' class ConstraintType(enum.Enum): """Enumeration of constraint types.""" NONE = 'none' @dataclasses.dataclass class Constraint: """Represents a constraint for model output decoding. Attributes: constraint_type: The type of constraint applied. 
""" constraint_type: ConstraintType = ConstraintType.NONE @dataclasses.dataclass(frozen=True) class ScoredOutput: """Scored output from language model inference.""" score: float | None = None output: str | None = None def __str__(self) -> str: score_str = '-' if self.score is None else f'{self.score:.2f}' if self.output is None: return f'Score: {score_str}\nOutput: None' formatted_lines = textwrap.indent(self.output, prefix=' ') return f'Score: {score_str}\nOutput:\n{formatted_lines}' ================================================ FILE: langextract/data.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Compatibility shim for langextract.data imports. This module provides backward compatibility for code that imports from langextract.data. All functionality has moved to langextract.core.data. """ from __future__ import annotations # Re-export everything from core.data for backward compatibility # pylint: disable=unused-wildcard-import from langextract.core.data import * ================================================ FILE: langextract/data_lib.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Library for data conversion between AnnotatedDocument and JSON.""" from __future__ import annotations import dataclasses import enum import numbers from typing import Any, Iterable, Mapping from langextract.core import data from langextract.core import tokenizer def enum_asdict_factory(items: Iterable[tuple[str, Any]]) -> dict[str, Any]: """Custom dict_factory for dataclasses.asdict. Recursively converts dataclass instances, converts enum values to their underlying values, converts integral numeric types to int, and skips any field whose name starts with an underscore. Args: items: An iterable of (key, value) pairs from fields of a dataclass. Returns: A mapping of field names to their values, with special handling for dataclasses, enums, and numeric types. """ result: dict[str, Any] = {} for key, value in items: # Skip internal fields. if key.startswith("_"): continue if dataclasses.is_dataclass(value): result[key] = dataclasses.asdict(value, dict_factory=enum_asdict_factory) elif isinstance(value, enum.Enum): result[key] = value.value elif isinstance(value, numbers.Integral) and not isinstance(value, bool): result[key] = int(value) else: result[key] = value return result def annotated_document_to_dict( adoc: data.AnnotatedDocument | None, ) -> dict[str, Any]: """Converts an AnnotatedDocument into a Python dict. This function converts an AnnotatedDocument object into a Python dict, making it easier to serialize or deserialize the document. Enum values and NumPy integers are converted to their underlying values, while other data types are left unchanged. 
Private fields with an underscore prefix are not included in the output. Args: adoc: The AnnotatedDocument object to convert. Returns: A Python dict representing the AnnotatedDocument. """ if not adoc: return {} result = dataclasses.asdict(adoc, dict_factory=enum_asdict_factory) result["document_id"] = adoc.document_id return result def dict_to_annotated_document( adoc_dic: Mapping[str, Any], ) -> data.AnnotatedDocument: """Converts a Python dict back to an AnnotatedDocument. Args: adoc_dic: A Python dict representing an AnnotatedDocument. Returns: An AnnotatedDocument object. """ if not adoc_dic: return data.AnnotatedDocument() for extractions in adoc_dic.get("extractions", []): token_int = extractions.get("token_interval") if token_int: extractions["token_interval"] = tokenizer.TokenInterval(**token_int) else: extractions["token_interval"] = None char_int = extractions.get("char_interval") if char_int: extractions["char_interval"] = data.CharInterval(**char_int) else: extractions["char_interval"] = None status_str = extractions.get("alignment_status") if status_str: extractions["alignment_status"] = data.AlignmentStatus(status_str) else: extractions["alignment_status"] = None return data.AnnotatedDocument( document_id=adoc_dic.get("document_id"), text=adoc_dic.get("text"), extractions=[ data.Extraction(**ent) for ent in adoc_dic.get("extractions", []) ], ) ================================================ FILE: langextract/exceptions.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Public exceptions API for LangExtract. This module re-exports exceptions from core.exceptions for backward compatibility. All new code should import directly from langextract.core.exceptions. """ # pylint: disable=duplicate-code from __future__ import annotations from langextract.core import exceptions as core_exceptions # Backward compat re-exports InferenceConfigError = core_exceptions.InferenceConfigError InferenceError = core_exceptions.InferenceError InferenceOutputError = core_exceptions.InferenceOutputError InferenceRuntimeError = core_exceptions.InferenceRuntimeError LangExtractError = core_exceptions.LangExtractError ProviderError = core_exceptions.ProviderError SchemaError = core_exceptions.SchemaError __all__ = [ "LangExtractError", "InferenceError", "InferenceConfigError", "InferenceRuntimeError", "InferenceOutputError", "ProviderError", "SchemaError", ] ================================================ FILE: langextract/extraction.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Main extraction API for LangExtract.""" from __future__ import annotations from collections.abc import Iterable import typing from typing import cast import warnings from langextract import annotation from langextract import factory from langextract import io from langextract import prompt_validation as pv from langextract import prompting from langextract import resolver from langextract.core import base_model from langextract.core import data from langextract.core import format_handler as fh from langextract.core import tokenizer as tokenizer_lib def extract( text_or_documents: typing.Any, prompt_description: str | None = None, examples: typing.Sequence[typing.Any] | None = None, model_id: str = "gemini-2.5-flash", api_key: str | None = None, language_model_type: typing.Type[typing.Any] | None = None, format_type: typing.Any = None, max_char_buffer: int = 1000, temperature: float | None = None, fence_output: bool | None = None, use_schema_constraints: bool = True, batch_length: int = 10, max_workers: int = 10, additional_context: str | None = None, resolver_params: dict | None = None, language_model_params: dict | None = None, debug: bool = False, model_url: str | None = None, extraction_passes: int = 1, context_window_chars: int | None = None, config: typing.Any = None, model: typing.Any = None, *, fetch_urls: bool = True, prompt_validation_level: pv.PromptValidationLevel = pv.PromptValidationLevel.WARNING, prompt_validation_strict: bool = False, show_progress: bool = True, tokenizer: tokenizer_lib.Tokenizer | None = None, ) -> list[data.AnnotatedDocument] | data.AnnotatedDocument: """Extracts structured information from text. Retrieves structured information from the provided text or documents using a language model based on the instructions in prompt_description and guided by examples. Supports sequential extraction passes to improve recall at the cost of additional API calls. 
Args: text_or_documents: The source text to extract information from, a URL to download text from (starting with http:// or https:// when fetch_urls is True), or an iterable of Document objects. prompt_description: Instructions for what to extract from the text. examples: List of ExampleData objects to guide the extraction. tokenizer: Optional Tokenizer instance to use for chunking and alignment. If None, defaults to RegexTokenizer. api_key: API key for Gemini or other LLM services (can also use environment variable LANGEXTRACT_API_KEY). Cost considerations: Most APIs charge by token volume. Smaller max_char_buffer values increase the number of API calls, while extraction_passes > 1 reprocesses tokens multiple times. Note that max_workers improves processing speed without additional token costs. Refer to your API provider's pricing details and monitor usage with small test runs to estimate costs. model_id: The model ID to use for extraction (e.g., 'gemini-2.5-flash'). If your model ID is not recognized or you need to use a custom provider, use the 'config' parameter with factory.ModelConfig to specify the provider explicitly. language_model_type: [DEPRECATED] The type of language model to use for inference. Warning triggers when value differs from the legacy default (GeminiLanguageModel). This parameter will be removed in v2.0.0. Use the model, config, or model_id parameters instead. format_type: The format type for the output (JSON or YAML). max_char_buffer: Max number of characters for inference. temperature: The sampling temperature for generation. When None (default), uses the model's default temperature. Set to 0.0 for deterministic output or higher values for more variation. fence_output: Whether to expect/generate fenced output (```json or ```yaml). When True, the model is prompted to generate fenced output and the resolver expects it. When False, raw JSON/YAML is expected. 
When None, automatically determined based on provider schema capabilities: if a schema is applied and requires_raw_output is True, defaults to False; otherwise True. If your model utilizes schema constraints, this can generally be set to False unless the constraint also accounts for code fence delimiters. use_schema_constraints: Whether to generate schema constraints for models. For supported models, this enables structured outputs. Defaults to True. batch_length: Number of text chunks processed per batch. Higher values enable greater parallelization when batch_length >= max_workers. Defaults to 10. max_workers: Maximum parallel workers for concurrent processing. Effective parallelization is limited by min(batch_length, max_workers). Supported by Gemini models. Defaults to 10. additional_context: Additional context to be added to the prompt during inference. resolver_params: Parameters for the `resolver.Resolver`, which parses the raw language model output string (e.g., extracting JSON from ```json ... ``` blocks) into structured `data.Extraction` objects. This dictionary overrides default settings. Keys include: - 'extraction_index_suffix' (str | None): Suffix for keys indicating extraction order. Default is None (order by appearance). Additional alignment parameters can be included: 'enable_fuzzy_alignment' (bool): Whether to use fuzzy matching if exact matching fails. Disabling this can improve performance but may reduce recall. Default is True. 'fuzzy_alignment_threshold' (float): Minimum token overlap ratio for fuzzy match (0.0-1.0). Default is 0.75. 'accept_match_lesser' (bool): Whether to accept partial exact matches. Default is True. 'suppress_parse_errors' (bool): Whether to suppress parsing errors and continue pipeline. Default is False. language_model_params: Additional parameters for the language model. debug: Whether to enable debug logging. 
When True, enables detailed logging of function calls, arguments, return values, and timing for the langextract namespace. Note: Debug logging remains enabled for the process once activated. model_url: Endpoint URL for self-hosted or on-prem models. Only forwarded when the selected `language_model_type` accepts this argument. extraction_passes: Number of sequential extraction attempts to improve recall and find additional entities. Defaults to 1 (standard single extraction). When > 1, the system performs multiple independent extractions and merges non-overlapping results (first extraction wins for overlaps). WARNING: Each additional pass reprocesses tokens, potentially increasing API costs. For example, extraction_passes=3 reprocesses tokens 3x. context_window_chars: Number of characters from the previous chunk to include as context for the current chunk. This helps with coreference resolution across chunk boundaries (e.g., resolving "She" to a person mentioned in the previous chunk). Defaults to None (disabled). config: Model configuration to use for extraction. Takes precedence over model_id, api_key, and language_model_type parameters. When both model and config are provided, model takes precedence. model: Pre-configured language model to use for extraction. Takes precedence over all other parameters including config. fetch_urls: Whether to automatically download content when the input is a URL string. When True (default), strings starting with http:// or https:// are fetched. When False, all strings are treated as literal text to analyze. This is a keyword-only parameter. prompt_validation_level: Controls pre-flight alignment checks on few-shot examples. OFF skips validation, WARNING logs issues but continues, ERROR raises on failures. Defaults to WARNING. prompt_validation_strict: When True and prompt_validation_level is ERROR, raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to False. 
show_progress: Whether to show progress bar during extraction. Defaults to True. Returns: An AnnotatedDocument with the extracted information when input is a string or URL, or an iterable of AnnotatedDocuments when input is an iterable of Documents. Raises: ValueError: If examples is None or empty. ValueError: If no API key is provided or found in environment variables. requests.RequestException: If URL download fails. pv.PromptAlignmentError: If validation fails in ERROR mode. """ if not examples: raise ValueError( "Examples are required for reliable extraction. Please provide at least" " one ExampleData object with sample extractions." ) if prompt_validation_level is not pv.PromptValidationLevel.OFF: report = pv.validate_prompt_alignment( examples=examples, aligner=resolver.WordAligner(), policy=pv.AlignmentPolicy(), tokenizer=tokenizer, ) pv.handle_alignment_report( report, level=prompt_validation_level, strict_non_exact=prompt_validation_strict, ) if debug: # pylint: disable=import-outside-toplevel from langextract.core import debug_utils debug_utils.configure_debug_logging() if format_type is None: format_type = data.FormatType.JSON if max_workers is not None and batch_length < max_workers: warnings.warn( f"batch_length ({batch_length}) < max_workers ({max_workers}). " f"Only {batch_length} workers will be used. " "Set batch_length >= max_workers for optimal parallelization.", UserWarning, ) if ( fetch_urls and isinstance(text_or_documents, str) and io.is_url(text_or_documents) ): text_or_documents = io.download_text_from_url(text_or_documents) prompt_template = prompting.PromptTemplateStructured( description=prompt_description ) prompt_template.examples.extend(examples) language_model: base_model.BaseLanguageModel | None = None if model: language_model = model if fence_output is not None: language_model.set_fence_output(fence_output) if use_schema_constraints: warnings.warn( "'use_schema_constraints' is ignored when 'model' is provided. 
" "The model should already be configured with schema constraints.", UserWarning, stacklevel=2, ) elif config: if use_schema_constraints: warnings.warn( "With 'config', schema constraints are still applied via examples. " "Or pass explicit schema in config.provider_kwargs.", UserWarning, stacklevel=2, ) language_model = factory.create_model( config=config, examples=prompt_template.examples if use_schema_constraints else None, use_schema_constraints=use_schema_constraints, fence_output=fence_output, ) else: if language_model_type is not None: warnings.warn( "'language_model_type' is deprecated and will be removed in v2.0.0. " "Use model, config, or model_id parameters instead.", FutureWarning, stacklevel=2, ) base_lm_kwargs: dict[str, typing.Any] = { "api_key": api_key, "format_type": format_type, "temperature": temperature, "model_url": model_url, "base_url": model_url, "max_workers": max_workers, } # TODO(v2.0.0): Remove gemini_schema parameter if "gemini_schema" in (language_model_params or {}): warnings.warn( "'gemini_schema' is deprecated. Schema constraints are now " "automatically handled. 
This parameter will be ignored.", FutureWarning, stacklevel=2, ) language_model_params = dict(language_model_params or {}) language_model_params.pop("gemini_schema", None) base_lm_kwargs.update(language_model_params or {}) filtered_kwargs = {k: v for k, v in base_lm_kwargs.items() if v is not None} config = factory.ModelConfig( model_id=model_id, provider_kwargs=filtered_kwargs ) language_model = factory.create_model( config=config, examples=prompt_template.examples if use_schema_constraints else None, use_schema_constraints=use_schema_constraints, fence_output=fence_output, ) format_handler, remaining_params = fh.FormatHandler.from_resolver_params( resolver_params=resolver_params, base_format_type=format_type, base_use_fences=language_model.requires_fence_output, base_attribute_suffix=data.ATTRIBUTE_SUFFIX, base_use_wrapper=True, base_wrapper_key=data.EXTRACTIONS_KEY, ) if language_model.schema is not None: language_model.schema.validate_format(format_handler) # Pull alignment settings from normalized params alignment_kwargs = {} for key in resolver.ALIGNMENT_PARAM_KEYS: val = remaining_params.pop(key, None) if val is not None: alignment_kwargs[key] = val effective_params = {"format_handler": format_handler, **remaining_params} try: res = resolver.Resolver(**effective_params) except TypeError as e: msg = str(e) if ( "unexpected keyword argument" in msg or "got an unexpected keyword argument" in msg ): raise TypeError( f"Unknown key in resolver_params; check spelling: {e}" ) from e raise annotator = annotation.Annotator( language_model=language_model, prompt_template=prompt_template, format_handler=format_handler, ) if isinstance(text_or_documents, str): result = annotator.annotate_text( text=text_or_documents, resolver=res, max_char_buffer=max_char_buffer, batch_length=batch_length, additional_context=additional_context, debug=debug, extraction_passes=extraction_passes, context_window_chars=context_window_chars, show_progress=show_progress, 
max_workers=max_workers, tokenizer=tokenizer, **alignment_kwargs, ) return result else: documents = cast(Iterable[data.Document], text_or_documents) result = annotator.annotate_documents( documents=documents, resolver=res, max_char_buffer=max_char_buffer, batch_length=batch_length, debug=debug, extraction_passes=extraction_passes, context_window_chars=context_window_chars, show_progress=show_progress, max_workers=max_workers, tokenizer=tokenizer, **alignment_kwargs, ) return list(result) ================================================ FILE: langextract/factory.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Factory for creating language model instances. This module provides a factory pattern for instantiating language models based on configuration, with support for environment variable resolution and provider-specific defaults. """ from __future__ import annotations import dataclasses import os import typing import warnings from langextract import providers from langextract.core import base_model from langextract.core import exceptions from langextract.providers import router @dataclasses.dataclass(slots=True, frozen=True) class ModelConfig: """Configuration for instantiating a language model provider. Attributes: model_id: The model identifier (e.g., "gemini-2.5-flash", "gpt-4o"). provider: Optional explicit provider name or class name. 
Use this to disambiguate when multiple providers support the same model_id. provider_kwargs: Optional provider-specific keyword arguments. """ model_id: str | None = None provider: str | None = None provider_kwargs: dict[str, typing.Any] = dataclasses.field( default_factory=dict ) def _kwargs_with_environment_defaults( model_id: str, kwargs: dict[str, typing.Any] ) -> dict[str, typing.Any]: """Add environment-based defaults to provider kwargs. Args: model_id: The model identifier. kwargs: Existing keyword arguments. Returns: Updated kwargs with environment defaults. """ resolved = dict(kwargs) if "api_key" not in resolved and not resolved.get("vertexai", False): model_lower = model_id.lower() env_vars_by_provider = { "gemini": ("GEMINI_API_KEY", "LANGEXTRACT_API_KEY"), "gpt": ("OPENAI_API_KEY", "LANGEXTRACT_API_KEY"), } for provider_prefix, env_vars in env_vars_by_provider.items(): if provider_prefix in model_lower: found_keys = [] for env_var in env_vars: key_val = os.getenv(env_var) if key_val: found_keys.append((env_var, key_val)) if found_keys: resolved["api_key"] = found_keys[0][1] if len(found_keys) > 1: keys_list = ", ".join(k[0] for k in found_keys) warnings.warn( f"Multiple API keys detected in environment: {keys_list}. " f"Using {found_keys[0][0]} and ignoring others.", UserWarning, stacklevel=3, ) break if "ollama" in model_id.lower() and "base_url" not in resolved: resolved["base_url"] = os.getenv( "OLLAMA_BASE_URL", "http://localhost:11434" ) return resolved def create_model( config: ModelConfig, examples: typing.Sequence[typing.Any] | None = None, use_schema_constraints: bool = False, fence_output: bool | None = None, return_fence_output: bool = False, ) -> base_model.BaseLanguageModel | tuple[base_model.BaseLanguageModel, bool]: """Create a language model instance from configuration. Args: config: Model configuration with optional model_id and/or provider. examples: Optional examples for schema generation (if use_schema_constraints=True). 
use_schema_constraints: Whether to apply schema constraints from examples. fence_output: Explicit fence output preference. If None, computed from schema. return_fence_output: If True, also return computed fence_output value. Returns: An instantiated language model provider. If return_fence_output=True: Tuple of (model, model.requires_fence_output). Raises: ValueError: If neither model_id nor provider is specified. ValueError: If no provider is registered for the model_id. InferenceConfigError: If provider instantiation fails. """ if use_schema_constraints or fence_output is not None: model = _create_model_with_schema( config=config, examples=examples, use_schema_constraints=use_schema_constraints, fence_output=fence_output, ) if return_fence_output: return model, model.requires_fence_output return model if not config.model_id and not config.provider: raise ValueError("Either model_id or provider must be specified") providers.load_builtins_once() providers.load_plugins_once() try: if config.provider: provider_class = router.resolve_provider(config.provider) else: provider_class = router.resolve(config.model_id) except (ModuleNotFoundError, ImportError) as e: raise exceptions.InferenceConfigError( "Failed to load provider. " "This may be due to missing dependencies. " f"Check that all required packages are installed. 
Error: {e}" ) from e model_id = config.model_id model_id = config.model_id kwargs = _kwargs_with_environment_defaults( model_id or config.provider or "", config.provider_kwargs ) if model_id: kwargs["model_id"] = model_id try: model = provider_class(**kwargs) if return_fence_output: return model, model.requires_fence_output return model except (ValueError, TypeError) as e: raise exceptions.InferenceConfigError( f"Failed to create provider {provider_class.__name__}: {e}" ) from e def create_model_from_id( model_id: str | None = None, provider: str | None = None, **provider_kwargs: typing.Any, ) -> base_model.BaseLanguageModel: """Convenience function to create a model. Args: model_id: The model identifier (e.g., "gemini-2.5-flash"). provider: Optional explicit provider name to disambiguate. **provider_kwargs: Optional provider-specific keyword arguments. Returns: An instantiated language model provider. """ config = ModelConfig( model_id=model_id, provider=provider, provider_kwargs=provider_kwargs ) return create_model(config) def _create_model_with_schema( config: ModelConfig, examples: typing.Sequence[typing.Any] | None = None, use_schema_constraints: bool = True, fence_output: bool | None = None, ) -> base_model.BaseLanguageModel: """Internal helper to create a model with optional schema constraints. This function creates a language model and optionally configures it with schema constraints derived from the provided examples. It also computes appropriate fence defaulting based on the schema's capabilities. Args: config: Model configuration with model_id and/or provider. examples: Optional sequence of ExampleData for schema generation. use_schema_constraints: Whether to generate and apply schema constraints. fence_output: Whether to wrap output in markdown fences. If None, will be computed based on schema's requires_raw_output. Returns: A model instance with fence_output configured appropriately. 
""" if config.provider: provider_class = router.resolve_provider(config.provider) else: providers.load_builtins_once() providers.load_plugins_once() provider_class = router.resolve(config.model_id) schema_instance = None if use_schema_constraints and examples: schema_class = provider_class.get_schema_class() if schema_class is not None: schema_instance = schema_class.from_examples(examples) if schema_instance: kwargs = schema_instance.to_provider_config() kwargs.update(config.provider_kwargs) else: kwargs = dict(config.provider_kwargs) if schema_instance: schema_instance.sync_with_provider_kwargs(kwargs) # Add environment defaults model_id = config.model_id kwargs = _kwargs_with_environment_defaults( model_id or config.provider or "", kwargs ) if model_id: kwargs["model_id"] = model_id try: model = provider_class(**kwargs) except (ValueError, TypeError) as e: raise exceptions.InferenceConfigError( f"Failed to create provider {provider_class.__name__}: {e}" ) from e model.apply_schema(schema_instance) model.set_fence_output(fence_output) return model ================================================ FILE: langextract/inference.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Language model inference compatibility layer. This module provides backward compatibility for the inference module. New code should import from langextract.core.base_model instead. 
""" from __future__ import annotations from langextract._compat import inference def __getattr__(name: str): """Forward to _compat.inference for backward compatibility.""" # Handle InferenceType specially since it's defined in _compat if name == "InferenceType": return inference.InferenceType return inference.__getattr__(name) ================================================ FILE: langextract/io.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Supports Input and Output Operations for Data Annotations.""" from __future__ import annotations import abc import dataclasses import ipaddress import json import os import pathlib from typing import Any, Iterator from urllib import parse as urlparse import pandas as pd import requests from langextract import data_lib from langextract import progress from langextract.core import data from langextract.core import exceptions DEFAULT_TIMEOUT_SECONDS = 30 class InvalidDatasetError(exceptions.LangExtractError): """Error raised when Dataset is empty or invalid.""" @dataclasses.dataclass(frozen=True) class Dataset(abc.ABC): """A dataset for inputs to LLM Labeler.""" input_path: pathlib.Path id_key: str text_key: str def load(self, delimiter: str = ',') -> Iterator[data.Document]: """Loads the dataset from a CSV file. Args: delimiter: The delimiter to use when reading the CSV file. Yields: A Document for each row in the dataset. 
Raises: IOError: If the file does not exist. InvalidDatasetError: If the dataset is empty or invalid. NotImplementedError: If the file type is not supported. """ if not os.path.exists(self.input_path): raise IOError(f'File does not exist: {self.input_path}') if str(self.input_path).endswith('.csv'): try: csv_data = _read_csv( self.input_path, column_names=[self.text_key, self.id_key], delimiter=delimiter, ) except InvalidDatasetError as e: raise InvalidDatasetError(f'Empty dataset: {self.input_path}') from e for row in csv_data: yield data.Document( text=row[self.text_key], document_id=row[self.id_key], ) else: raise NotImplementedError(f'Unsupported file type: {self.input_path}') def save_annotated_documents( annotated_documents: Iterator[data.AnnotatedDocument], output_dir: pathlib.Path | str | None = None, output_name: str = 'data.jsonl', show_progress: bool = True, ) -> None: """Saves annotated documents to a JSON Lines file. Args: annotated_documents: Iterator over AnnotatedDocument objects to save. output_dir: The directory to which the JSONL file should be written. Can be a Path object or a string. Defaults to 'test_output/' if None. output_name: File name for the JSONL file. show_progress: Whether to show a progress bar during saving. Raises: IOError: If the output directory cannot be created. InvalidDatasetError: If no documents are produced. 
""" if output_dir is None: output_dir = pathlib.Path('test_output') else: output_dir = pathlib.Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) output_file = output_dir / output_name has_data = False doc_count = 0 # Create progress bar progress_bar = progress.create_save_progress_bar( output_path=str(output_file), disable=not show_progress ) with open(output_file, 'w', encoding='utf-8') as f: for adoc in annotated_documents: if not adoc.document_id: continue doc_dict = data_lib.annotated_document_to_dict(adoc) f.write(json.dumps(doc_dict, ensure_ascii=False) + '\n') has_data = True doc_count += 1 progress_bar.update(1) progress_bar.close() if not has_data: raise InvalidDatasetError(f'No documents to save in: {output_file}') if show_progress: progress.print_save_complete(doc_count, str(output_file)) def load_annotated_documents_jsonl( jsonl_path: pathlib.Path, show_progress: bool = True, ) -> Iterator[data.AnnotatedDocument]: """Loads annotated documents from a JSON Lines file. Args: jsonl_path: The file path to the JSON Lines file. show_progress: Whether to show a progress bar during loading. Yields: AnnotatedDocument objects. Raises: IOError: If the file does not exist or is invalid. 
""" if not os.path.exists(jsonl_path): raise IOError(f'File does not exist: {jsonl_path}') # Get file size for progress bar file_size = os.path.getsize(jsonl_path) # Create progress bar progress_bar = progress.create_load_progress_bar( file_path=str(jsonl_path), total_size=file_size if show_progress else None, disable=not show_progress, ) doc_count = 0 bytes_read = 0 with open(jsonl_path, 'r', encoding='utf-8') as f: for line in f: line_bytes = len(line.encode('utf-8')) bytes_read += line_bytes progress_bar.update(line_bytes) line = line.strip() if not line: continue doc_dict = json.loads(line) doc_count += 1 yield data_lib.dict_to_annotated_document(doc_dict) progress_bar.close() if show_progress: progress.print_load_complete(doc_count, str(jsonl_path)) def _read_csv( filepath: pathlib.Path, column_names: list[str], delimiter: str = ',' ) -> Iterator[dict[str, Any]]: """Reads a CSV file and yields rows as dicts. Args: filepath: The path to the file. column_names: The names of the columns to read. delimiter: The delimiter to use when reading the CSV file. Yields: An iterator of dicts representing each row. Raises: IOError: If the file does not exist. InvalidDatasetError: If the dataset is empty or invalid. """ if not os.path.exists(filepath): raise IOError(f'File does not exist: {filepath}') try: with open(filepath, 'r', encoding='utf-8') as f: df = pd.read_csv(f, usecols=column_names, dtype=str, delimiter=delimiter) for _, row in df.iterrows(): yield row.to_dict() except pd.errors.EmptyDataError as e: raise InvalidDatasetError(f'Empty dataset: {filepath}') from e except ValueError as e: raise InvalidDatasetError(f'Invalid dataset file: {filepath}') from e def is_url(text: str) -> bool: """Check if the given text is a valid URL. Uses urllib.parse to validate that the text is a properly formed URL with http or https scheme and a valid network location. Args: text: The string to check. Returns: True if the text is a valid URL with http(s) scheme, False otherwise. 
""" if not text or not isinstance(text, str): return False text = text.strip() # Reject text with whitespace (not a pure URL) if ' ' in text or '\n' in text or '\t' in text: return False try: result = urlparse.urlparse(text) hostname = result.hostname # Must have valid scheme, netloc, and hostname if not (result.scheme in ('http', 'https') and result.netloc and hostname): return False # Accept IPs, localhost, or domains with dots try: ipaddress.ip_address(hostname) return True except ValueError: return hostname == 'localhost' or '.' in hostname except (ValueError, AttributeError): return False def download_text_from_url( url: str, timeout: int = DEFAULT_TIMEOUT_SECONDS, show_progress: bool = True, chunk_size: int = 8192, ) -> str: """Download text content from a URL with optional progress bar. Args: url: The URL to download from. timeout: Request timeout in seconds. show_progress: Whether to show a progress bar during download. chunk_size: Size of chunks to download at a time. Returns: The text content of the URL. Raises: requests.RequestException: If the download fails. ValueError: If the content is not text-based. 
""" try: # Make initial request to get headers response = requests.get(url, stream=True, timeout=timeout) response.raise_for_status() # Check content type content_type = response.headers.get('Content-Type', '').lower() if not any( ct in content_type for ct in ['text/', 'application/json', 'application/xml'] ): # Try to proceed anyway, but warn print(f"Warning: Content-Type '{content_type}' may not be text-based") # Get content length for progress bar total_size = int(response.headers.get('Content-Length', 0)) filename = url.split('/')[-1][:50] # Download content with progress bar chunks = [] if show_progress and total_size > 0: progress_bar = progress.create_download_progress_bar( total_size=total_size, url=url ) for chunk in response.iter_content(chunk_size=chunk_size): if chunk: chunks.append(chunk) progress_bar.update(len(chunk)) progress_bar.close() else: # Download without progress bar for chunk in response.iter_content(chunk_size=chunk_size): if chunk: chunks.append(chunk) # Combine chunks and decode content = b''.join(chunks) # Try to decode as text encodings = ['utf-8', 'latin-1', 'ascii', 'utf-16'] text_content = None for encoding in encodings: try: text_content = content.decode(encoding) break except UnicodeDecodeError: continue if text_content is None: raise ValueError(f'Could not decode content from {url} as text') # Show content summary with clean formatting if show_progress: char_count = len(text_content) word_count = len(text_content.split()) progress.print_download_complete(char_count, word_count, filename) return text_content except requests.RequestException as e: raise requests.RequestException( f'Failed to download from {url}: {str(e)}' ) from e ================================================ FILE: langextract/plugins.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provider discovery and registration system. This module provides centralized provider discovery without circular imports. It supports both built-in providers and third-party providers via entry points. """ from __future__ import annotations import functools import importlib from importlib import metadata from absl import logging from langextract.core import base_model __all__ = ["available_providers", "get_provider_class"] # Static mapping for built-in providers (always available) _BUILTINS: dict[str, str] = { "gemini": "langextract.providers.gemini:GeminiLanguageModel", "ollama": "langextract.providers.ollama:OllamaLanguageModel", } # Optional built-in providers (require extra dependencies) _OPTIONAL_BUILTINS: dict[str, str] = { "openai": "langextract.providers.openai:OpenAILanguageModel", } def _safe_entry_points(group: str) -> list: """Get entry points with Python 3.8-3.12 compatibility. Args: group: Entry point group name. Returns: List of entry points in the specified group. """ eps = metadata.entry_points() try: # Python 3.10+ return list(eps.select(group=group)) except AttributeError: # Python 3.8-3.9 return list(getattr(eps, "get")(group, [])) @functools.lru_cache(maxsize=1) def _discovered() -> dict[str, str]: """Cache discovered third-party providers. Returns: Dictionary mapping provider names to import specs. 
""" discovered: dict[str, str] = {} for ep in _safe_entry_points("langextract.providers"): # Handle both old and new entry_points API if hasattr(ep, "value"): discovered.setdefault(ep.name, ep.value) else: # Legacy API - construct from module and attr value = f"{ep.module}:{ep.attr}" if ep.attr else ep.module discovered.setdefault(ep.name, value) if discovered: logging.debug( "Discovered third-party providers: %s", list(discovered.keys()) ) return discovered def available_providers( allow_override: bool = False, include_optional: bool = True ) -> dict[str, str]: """Get all available providers (built-in + optional + third-party). Args: allow_override: If True, third-party providers can override built-ins. If False (default), built-ins take precedence. include_optional: If True (default), include optional built-in providers that may require extra dependencies. Returns: Dictionary mapping provider names to import specifications. """ providers = dict(_discovered()) if include_optional: if allow_override: # Third-party can override optional built-ins providers.update(_OPTIONAL_BUILTINS) else: # Optional built-ins override third-party providers = {**providers, **_OPTIONAL_BUILTINS} # Always add core built-ins with highest precedence (unless allow_override) if allow_override: # Third-party and optional can override core built-ins providers.update(_BUILTINS) else: # Core built-ins take precedence over everything providers = {**providers, **_BUILTINS} return providers def _load_class(spec: str) -> type[base_model.BaseLanguageModel]: """Load a provider class from module:Class specification. Args: spec: Import specification in format "module.path:ClassName". Returns: The loaded provider class. Raises: ImportError: If the spec is invalid or module cannot be imported. TypeError: If the loaded class is not a BaseLanguageModel. 
""" module_path, _, class_name = spec.partition(":") if not module_path or not class_name: raise ImportError( f"Invalid provider spec '{spec}' - expected 'module:Class'" ) try: module = importlib.import_module(module_path) except ImportError as e: raise ImportError( f"Failed to import provider module '{module_path}': {e}" ) from e try: cls = getattr(module, class_name) except AttributeError as e: raise ImportError( f"Provider class '{class_name}' not found in module '{module_path}'" ) from e # Validate it's a language model if not isinstance(cls, type) or not issubclass( cls, base_model.BaseLanguageModel ): # Fallback: check structural compatibility for non-ABC classes missing = [] for method in ("infer", "parse_output"): if not hasattr(cls, method): missing.append(method) if missing: raise TypeError( f"{cls} is not a BaseLanguageModel and missing required methods:" f" {missing}" ) logging.warning( "Provider %s does not inherit from BaseLanguageModel but appears" " compatible", cls, ) return cls @functools.lru_cache(maxsize=None) # Cache all loaded classes def get_provider_class( name: str, allow_override: bool = False, include_optional: bool = True ) -> type[base_model.BaseLanguageModel]: """Get a provider class by name. Args: name: Provider name (e.g., "gemini", "openai", "ollama"). allow_override: If True, allow third-party providers to override built-ins. include_optional: If True (default), include optional providers that may require extra dependencies. Returns: The provider class. Raises: KeyError: If the provider name is not found. ImportError: If the provider module cannot be imported (including missing optional dependencies). TypeError: If the provider class is not compatible. """ providers = available_providers(allow_override, include_optional) if name not in providers: available = sorted(providers.keys()) raise KeyError( f"Unknown provider '{name}'. 
Available providers:" f" {', '.join(available) if available else 'none'}.\nHint: Did you" " install the necessary extras (e.g., pip install" f" langextract[{name}])?" ) return _load_class(providers[name]) ================================================ FILE: langextract/progress.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Progress and visualization utilities for LangExtract.""" from __future__ import annotations from typing import Any import urllib.parse import tqdm # ANSI color codes for terminal output BLUE = "\033[94m" GREEN = "\033[92m" CYAN = "\033[96m" BOLD = "\033[1m" RESET = "\033[0m" # Google Blue color for progress bars GOOGLE_BLUE = "#4285F4" def create_download_progress_bar( total_size: int, url: str, ncols: int = 100, max_url_length: int = 50 ) -> tqdm.tqdm: """Create a styled progress bar for downloads. Args: total_size: Total size in bytes. url: The URL being downloaded. ncols: Number of columns for the progress bar. max_url_length: Maximum length to show for the URL. Returns: A configured tqdm progress bar. 
""" # Truncate URL if too long, keeping the domain and end if len(url) > max_url_length: parsed = urllib.parse.urlparse(url) domain = parsed.netloc or parsed.hostname or "unknown" path_parts = parsed.path.strip("/").split("/") filename = path_parts[-1] if path_parts and path_parts[-1] else "file" available = max_url_length - len(domain) - len(filename) - 5 if available > 0: url_display = f"{domain}/.../{filename}" else: url_display = url[: max_url_length - 3] + "..." else: url_display = url return tqdm.tqdm( total=total_size, unit="B", unit_scale=True, desc=( f"{BLUE}{BOLD}LangExtract{RESET}: Downloading" f" {GREEN}{url_display}{RESET}" ), bar_format=( "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}" " [{elapsed}<{remaining}, {rate_fmt}]" ), colour=GOOGLE_BLUE, ncols=ncols, ) def create_extraction_progress_bar( iterable: Any, model_info: str | None = None, disable: bool = False ) -> tqdm.tqdm: """Create a styled progress bar for extraction. Args: iterable: The iterable to wrap with progress bar. model_info: Optional model information to display (e.g., "gemini-1.5-pro"). disable: Whether to disable the progress bar. Returns: A configured tqdm progress bar. """ desc = format_extraction_progress(model_info) return tqdm.tqdm( iterable, desc=desc, bar_format="{desc} [{elapsed}]", disable=disable, dynamic_ncols=True, ) def print_download_complete( char_count: int, word_count: int, filename: str ) -> None: """Print a styled download completion message. Args: char_count: Number of characters downloaded. word_count: Number of words downloaded. filename: Name of the downloaded file. 
""" print( f"{GREEN}✓{RESET} Downloaded {BOLD}{char_count:,}{RESET} characters " f"({BOLD}{word_count:,}{RESET} words) from {BLUE}{filename}{RESET}", flush=True, ) def print_extraction_complete() -> None: """Print a generic extraction completion message.""" print(f"{GREEN}✓{RESET} Extraction processing complete", flush=True) def print_extraction_summary( num_extractions: int, unique_classes: int, elapsed_time: float | None = None, chars_processed: int | None = None, num_chunks: int | None = None, ) -> None: """Print a styled extraction summary with optional performance metrics. Args: num_extractions: Total number of extractions. unique_classes: Number of unique extraction classes. elapsed_time: Optional elapsed time in seconds. chars_processed: Optional number of characters processed. num_chunks: Optional number of chunks processed. """ print( f"{GREEN}✓{RESET} Extracted {BOLD}{num_extractions}{RESET} entities " f"({BOLD}{unique_classes}{RESET} unique types)", flush=True, ) if elapsed_time is not None: metrics = [] # Time metrics.append(f"Time: {BOLD}{elapsed_time:.2f}s{RESET}") # Speed if chars_processed is not None and elapsed_time > 0: speed = chars_processed / elapsed_time metrics.append(f"Speed: {BOLD}{speed:,.0f}{RESET} chars/sec") if num_chunks is not None: metrics.append(f"Chunks: {BOLD}{num_chunks}{RESET}") for metric in metrics: print(f" {CYAN}•{RESET} {metric}", flush=True) def create_save_progress_bar( output_path: str, disable: bool = False ) -> tqdm.tqdm: """Create a progress bar for saving documents. Args: output_path: The output file path. disable: Whether to disable the progress bar. Returns: A configured tqdm progress bar. 
""" filename = output_path.split("/")[-1] return tqdm.tqdm( desc=( f"{BLUE}{BOLD}LangExtract{RESET}: Saving to {GREEN}{filename}{RESET}" ), unit=" docs", disable=disable, ) def create_load_progress_bar( file_path: str, total_size: int | None = None, disable: bool = False ) -> tqdm.tqdm: """Create a progress bar for loading documents. Args: file_path: The file path being loaded. total_size: Optional total file size in bytes. disable: Whether to disable the progress bar. Returns: A configured tqdm progress bar. """ filename = file_path.split("/")[-1] if total_size: return tqdm.tqdm( total=total_size, desc=( f"{BLUE}{BOLD}LangExtract{RESET}: Loading {GREEN}{filename}{RESET}" ), unit="B", unit_scale=True, disable=disable, ) else: return tqdm.tqdm( desc=( f"{BLUE}{BOLD}LangExtract{RESET}: Loading {GREEN}{filename}{RESET}" ), unit=" docs", disable=disable, ) def print_save_complete(num_docs: int, file_path: str) -> None: """Print a save completion message. Args: num_docs: Number of documents saved. file_path: Path to the saved file. """ filename = file_path.split("/")[-1] print( f"{GREEN}✓{RESET} Saved {BOLD}{num_docs}{RESET} documents to" f" {GREEN}{filename}{RESET}", flush=True, ) def print_load_complete(num_docs: int, file_path: str) -> None: """Print a load completion message. Args: num_docs: Number of documents loaded. file_path: Path to the loaded file. """ filename = file_path.split("/")[-1] print( f"{GREEN}✓{RESET} Loaded {BOLD}{num_docs}{RESET} documents from" f" {GREEN}{filename}{RESET}", flush=True, ) def get_model_info(language_model: Any) -> str | None: """Extract model information from a language model instance. Args: language_model: A language model instance. Returns: A string describing the model, or None if not available. 
""" if hasattr(language_model, "model_id"): return language_model.model_id if hasattr(language_model, "model_url"): return language_model.model_url return None def format_extraction_stats(current_chars: int, processed_chars: int) -> str: """Format extraction progress statistics with colors. Args: current_chars: Number of characters in current batch. processed_chars: Total number of characters processed so far. Returns: Formatted string with colored statistics. """ current_str = f"{GREEN}{current_chars:,}{RESET}" processed_str = f"{GREEN}{processed_chars:,}{RESET}" return f"current={current_str} chars, processed={processed_str} chars" def create_extraction_postfix(current_chars: int, processed_chars: int) -> str: """Create a formatted postfix string for extraction progress. Args: current_chars: Number of characters in current batch. processed_chars: Total number of characters processed so far. Returns: Formatted string with statistics. """ current_str = f"{GREEN}{current_chars:,}{RESET}" processed_str = f"{GREEN}{processed_chars:,}{RESET}" return f"current={current_str} chars, processed={processed_str} chars" def format_extraction_progress( model_info: str | None, current_chars: int | None = None, processed_chars: int | None = None, ) -> str: """Format the complete extraction progress bar description. Args: model_info: Optional model information (e.g., "gemini-2.0-flash"). current_chars: Number of characters in current batch (optional). processed_chars: Total number of characters processed so far (optional). Returns: Formatted description string. 
""" # Base description if model_info: desc = f"{BLUE}{BOLD}LangExtract{RESET}: model={GREEN}{model_info}{RESET}" else: desc = f"{BLUE}{BOLD}LangExtract{RESET}: Processing" # Add stats if provided if current_chars is not None and processed_chars is not None: current_str = f"{GREEN}{current_chars:,}{RESET}" processed_str = f"{GREEN}{processed_chars:,}{RESET}" desc += f", current={current_str} chars, processed={processed_str} chars" return desc def create_pass_progress_bar( total_passes: int, disable: bool = False ) -> tqdm.tqdm: """Create a progress bar for sequential extraction passes. Args: total_passes: Total number of sequential passes. disable: Whether to disable the progress bar. Returns: A configured tqdm progress bar. """ desc = f"{BLUE}{BOLD}LangExtract{RESET}: Extraction passes" return tqdm.tqdm( total=total_passes, desc=desc, bar_format=( "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}]" ), disable=disable, colour=GOOGLE_BLUE, ncols=100, ) ================================================ FILE: langextract/prompt_validation.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Prompt validation for alignment checks on few-shot examples.""" from __future__ import annotations from collections.abc import Sequence import copy import dataclasses import enum from absl import logging from langextract import resolver from langextract.core import data from langextract.core import tokenizer as tokenizer_lib __all__ = [ "PromptValidationLevel", "ValidationIssue", "ValidationReport", "PromptAlignmentError", "AlignmentPolicy", "validate_prompt_alignment", "handle_alignment_report", ] _FUZZY_ALIGNMENT_MIN_THRESHOLD = 0.75 class PromptValidationLevel(enum.Enum): """Validation levels for prompt alignment checks.""" OFF = "off" WARNING = "warning" ERROR = "error" class _IssueKind(enum.Enum): """Internal categorization of alignment issues.""" FAILED = "failed" # alignment_status is None NON_EXACT = "non_exact" # MATCH_FUZZY or MATCH_LESSER @dataclasses.dataclass(frozen=True) class ValidationIssue: """Represents a single validation issue found during alignment.""" example_index: int example_id: str | None extraction_class: str extraction_text_preview: str alignment_status: data.AlignmentStatus | None issue_kind: _IssueKind char_interval: tuple[int, int] | None = None token_interval: tuple[int, int] | None = None def short_msg(self) -> str: """Returns a concise message describing the issue.""" ex_id = f" id={self.example_id}" if self.example_id else "" span = "" if self.char_interval: span = f" char_span={self.char_interval}" return ( f"[example#{self.example_index}{ex_id}] " f"class='{self.extraction_class}' " f"status={self.alignment_status} " f"text='{self.extraction_text_preview}'{span}" ) @dataclasses.dataclass class ValidationReport: """Collection of validation issues from prompt alignment checks.""" issues: list[ValidationIssue] @property def has_failed(self) -> bool: """Returns True if any extraction failed to align.""" return any(i.issue_kind is _IssueKind.FAILED for i in self.issues) @property def has_non_exact(self) -> bool: """Returns True if 
any extraction has non-exact alignment.""" return any(i.issue_kind is _IssueKind.NON_EXACT for i in self.issues) class PromptAlignmentError(RuntimeError): """Raised when prompt alignment validation fails under ERROR mode.""" @dataclasses.dataclass(frozen=True) class AlignmentPolicy: """Configuration for alignment validation behavior.""" enable_fuzzy_alignment: bool = True fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD accept_match_lesser: bool = True def _preview(s: str, n: int = 120) -> str: """Creates a preview of text for logging, collapsing whitespace.""" s = " ".join(s.split()) # Collapse whitespace for logs return s if len(s) <= n else s[: n - 1] + "…" def validate_prompt_alignment( examples: Sequence[data.ExampleData], aligner: resolver.WordAligner | None = None, policy: AlignmentPolicy | None = None, tokenizer: tokenizer_lib.Tokenizer | None = None, ) -> ValidationReport: """Align extractions to their own example text and collect issues. Args: examples: The few-shot examples to validate. aligner: WordAligner instance to use (creates new if None). policy: Alignment configuration (uses defaults if None). tokenizer: Optional tokenizer to use for alignment. If None, defaults to RegexTokenizer. Returns: ValidationReport containing any alignment issues found. """ if not examples: return ValidationReport(issues=[]) aligner = aligner or resolver.WordAligner() policy = policy or AlignmentPolicy() issues: list[ValidationIssue] = [] for idx, ex in enumerate(examples): # Defensive copy so validation never mutates user examples. 
copied_extractions = [[copy.deepcopy(e) for e in ex.extractions]] aligned_groups = aligner.align_extractions( extraction_groups=copied_extractions, source_text=ex.text, token_offset=0, char_offset=0, enable_fuzzy_alignment=policy.enable_fuzzy_alignment, fuzzy_alignment_threshold=policy.fuzzy_alignment_threshold, accept_match_lesser=policy.accept_match_lesser, tokenizer_impl=tokenizer, ) for aligned in aligned_groups[0]: status = getattr(aligned, "alignment_status", None) char_interval = getattr(aligned, "char_interval", None) token_interval = getattr(aligned, "token_interval", None) klass = getattr(aligned, "extraction_class", "") text = getattr(aligned, "extraction_text", "") if status is None: issues.append( ValidationIssue( example_index=idx, example_id=getattr(ex, "example_id", None), extraction_class=klass, extraction_text_preview=_preview(text), alignment_status=None, issue_kind=_IssueKind.FAILED, char_interval=None, token_interval=None, ) ) elif status in ( data.AlignmentStatus.MATCH_FUZZY, data.AlignmentStatus.MATCH_LESSER, ): char_interval_tuple = None token_interval_tuple = None if char_interval: char_interval_tuple = (char_interval.start_pos, char_interval.end_pos) if token_interval: token_interval_tuple = ( token_interval.start_index, token_interval.end_index, ) issues.append( ValidationIssue( example_index=idx, example_id=getattr(ex, "example_id", None), extraction_class=klass, extraction_text_preview=_preview(text), alignment_status=status, issue_kind=_IssueKind.NON_EXACT, char_interval=char_interval_tuple, token_interval=token_interval_tuple, ) ) return ValidationReport(issues=issues) def handle_alignment_report( report: ValidationReport, level: PromptValidationLevel, *, strict_non_exact: bool = False, ) -> None: """Log or raise based on validation level. Args: report: The validation report to handle. level: The validation level determining behavior. strict_non_exact: If True, treat non-exact matches as errors in ERROR mode. 
Raises: PromptAlignmentError: If validation fails in ERROR mode. """ if level is PromptValidationLevel.OFF: return for issue in report.issues: if issue.issue_kind is _IssueKind.NON_EXACT: logging.warning( "Prompt alignment: non-exact match: %s", issue.short_msg() ) else: logging.warning( "Prompt alignment: FAILED to align: %s", issue.short_msg() ) if level is PromptValidationLevel.ERROR: failed = [i for i in report.issues if i.issue_kind is _IssueKind.FAILED] non_exact = [ i for i in report.issues if i.issue_kind is _IssueKind.NON_EXACT ] if failed: sample = failed[0].short_msg() raise PromptAlignmentError( f"Prompt alignment validation failed: {len(failed)} extraction(s) " f"could not be aligned (e.g., {sample})" ) if strict_non_exact and non_exact: sample = non_exact[0].short_msg() raise PromptAlignmentError( "Prompt alignment validation failed under strict mode: " f"{len(non_exact)} non-exact match(es) found (e.g., {sample})" ) ================================================ FILE: langextract/prompting.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Library for building prompts.""" from __future__ import annotations import dataclasses import json import pathlib import pydantic from typing_extensions import override import yaml from langextract.core import data from langextract.core import exceptions from langextract.core import format_handler class PromptBuilderError(exceptions.LangExtractError): """Failure to build prompt.""" class ParseError(PromptBuilderError): """Prompt template cannot be parsed.""" @dataclasses.dataclass class PromptTemplateStructured: """A structured prompt template for few-shot examples. Attributes: description: Instructions or guidelines for the LLM. examples: ExampleData objects demonstrating expected input→output behavior. """ description: str examples: list[data.ExampleData] = dataclasses.field(default_factory=list) def read_prompt_template_structured_from_file( prompt_path: str, format_type: data.FormatType = data.FormatType.YAML, ) -> PromptTemplateStructured: """Reads a structured prompt template from a file. Args: prompt_path: Path to a file containing PromptTemplateStructured data. format_type: The format of the file; YAML or JSON. Returns: A PromptTemplateStructured object loaded from the file. Raises: ParseError: If the file cannot be parsed successfully. 
""" adapter = pydantic.TypeAdapter(PromptTemplateStructured) try: with pathlib.Path(prompt_path).open("rt") as f: data_dict = {} prompt_content = f.read() if format_type == data.FormatType.YAML: data_dict = yaml.safe_load(prompt_content) elif format_type == data.FormatType.JSON: data_dict = json.loads(prompt_content) return adapter.validate_python(data_dict) except Exception as e: raise ParseError( f"Failed to parse prompt template from file: {prompt_path}" ) from e @dataclasses.dataclass class QAPromptGenerator: """Generates question-answer prompts from the provided template.""" template: PromptTemplateStructured format_handler: format_handler.FormatHandler examples_heading: str = "Examples" question_prefix: str = "Q: " answer_prefix: str = "A: " def __str__(self) -> str: """Returns a string representation of the prompt with an empty question.""" return self.render("") def format_example_as_text(self, example: data.ExampleData) -> str: """Formats a single example for the prompt. Args: example: The example data to format. Returns: A string representation of the example, including the question and answer. """ question = example.text answer = self.format_handler.format_extraction_example(example.extractions) return "\n".join([ f"{self.question_prefix}{question}", f"{self.answer_prefix}{answer}\n", ]) def render(self, question: str, additional_context: str | None = None) -> str: """Generate a text representation of the prompt. Args: question: That will be presented to the model. additional_context: Additional context to include in the prompt. An empty string is ignored. Returns: Text prompt with a question to be presented to a language model. 
""" prompt_lines: list[str] = [f"{self.template.description}\n"] if additional_context: prompt_lines.append(f"{additional_context}\n") if self.template.examples: prompt_lines.append(self.examples_heading) for ex in self.template.examples: prompt_lines.append(self.format_example_as_text(ex)) prompt_lines.append(f"{self.question_prefix}{question}") prompt_lines.append(self.answer_prefix) return "\n".join(prompt_lines) class PromptBuilder: """Builds prompts for text chunks using a QAPromptGenerator. This base class provides a simple interface for prompt generation. Subclasses can extend this to add stateful behavior like cross-chunk context tracking. """ def __init__(self, generator: QAPromptGenerator): """Initializes the builder with the given prompt generator. Args: generator: The underlying prompt generator to use. """ self._generator = generator def build_prompt( self, chunk_text: str, document_id: str, additional_context: str | None = None, ) -> str: """Builds a prompt for the given chunk. Args: chunk_text: The text of the current chunk to process. document_id: Identifier for the source document. additional_context: Optional additional context from the document. Returns: The rendered prompt string ready for the language model. """ del document_id # Unused in base class. return self._generator.render( question=chunk_text, additional_context=additional_context, ) class ContextAwarePromptBuilder(PromptBuilder): """Prompt builder with cross-chunk context tracking. Extends PromptBuilder to inject text from the previous chunk into each prompt. This helps language models resolve coreferences across chunk boundaries (e.g., connecting "She" to "Dr. Sarah Johnson" from the previous chunk). Context is tracked per document_id, so multiple documents can be processed without context bleeding between them. """ _CONTEXT_PREFIX = "[Previous text]: ..." 
def __init__( self, generator: QAPromptGenerator, context_window_chars: int | None = None, ): """Initializes the builder with context tracking configuration. Args: generator: The underlying prompt generator to use. context_window_chars: Number of characters from the previous chunk's tail to include as context. Defaults to None (disabled). """ super().__init__(generator) self._context_window_chars = context_window_chars self._prev_chunk_by_doc_id: dict[str, str] = {} @property def context_window_chars(self) -> int | None: """Number of trailing characters from previous chunk to include.""" return self._context_window_chars @override def build_prompt( self, chunk_text: str, document_id: str, additional_context: str | None = None, ) -> str: """Builds a prompt, injecting previous chunk context if enabled. Args: chunk_text: The text of the current chunk to process. document_id: Identifier for the source document (used to track context per document). additional_context: Optional additional context from the document. Returns: The rendered prompt string ready for the language model. """ effective_context = self._build_effective_context( document_id, additional_context ) prompt = self._generator.render( question=chunk_text, additional_context=effective_context, ) self._update_state(document_id, chunk_text) return prompt def _build_effective_context( self, document_id: str, additional_context: str | None, ) -> str | None: """Combines previous chunk context with any additional context. Args: document_id: Identifier for the source document. additional_context: Optional additional context from the document. Returns: Combined context string, or None if no context is available. 
""" context_parts: list[str] = [] if self._context_window_chars and document_id in self._prev_chunk_by_doc_id: prev_text = self._prev_chunk_by_doc_id[document_id] window = prev_text[-self._context_window_chars :] context_parts.append(f"{self._CONTEXT_PREFIX}{window}") if additional_context: context_parts.append(additional_context) return "\n\n".join(context_parts) if context_parts else None def _update_state(self, document_id: str, chunk_text: str) -> None: """Stores current chunk as context for the next chunk in this document. Args: document_id: Identifier for the source document. chunk_text: The current chunk text to store. """ if self._context_window_chars: self._prev_chunk_by_doc_id[document_id] = chunk_text ================================================ FILE: langextract/providers/README.md ================================================ # LangExtract Provider System This directory contains the provider system for LangExtract, which enables support for different Large Language Model (LLM) backends. **Quick Start**: Use the [provider plugin generator script](../../scripts/create_provider_plugin.py) to create a new provider in minutes: ```bash python scripts/create_provider_plugin.py MyProvider --with-schema ``` ## Architecture Overview The provider system uses a **registry pattern** with **automatic discovery**: 1. **Registry** (`registry.py`): Maps model ID patterns to provider classes 2. **Factory** (`../factory.py`): Creates provider instances based on model IDs 3. 
**Providers**: Implement the `BaseLanguageModel` interface ### Provider Resolution Flow ``` User Code LangExtract Provider ───────── ─────────── ──────── | | | | lx.extract( | | | model_id="gemini-2.5-flash") | |─────────────────────────────> | | | | | factory.create_model() | | | | | registry.resolve("gemini-2.5-flash") | | Pattern match: ^gemini | | ↓ | | GeminiLanguageModel | | | | | Instantiate provider | | |─────────────────────────────>| | | | | | Provider API calls | | |<─────────────────────────────| | | | |<──────────────────────────── | | AnnotatedDocument | | ``` ### Explicit Provider Selection When multiple providers might support the same model ID, or when you want to use a specific provider, you can explicitly specify the provider: ```python import langextract as lx # Method 1: Using factory directly with provider parameter config = lx.factory.ModelConfig( model_id="gpt-4", provider="OpenAILanguageModel", # Explicit provider provider_kwargs={"api_key": "..."} ) model = lx.factory.create_model(config) # Method 2: Using provider without model_id (uses provider's default) config = lx.factory.ModelConfig( provider="GeminiLanguageModel", # Will use default gemini-2.5-flash provider_kwargs={"api_key": "..."} ) model = lx.factory.create_model(config) # Method 3: Auto-detection (when no conflicts exist) config = lx.factory.ModelConfig( model_id="gemini-2.5-flash" # Provider auto-detected ) model = lx.factory.create_model(config) ``` Provider names can be: - Full class name: `"GeminiLanguageModel"`, `"OpenAILanguageModel"`, `"OllamaLanguageModel"` - Partial match: `"gemini"`, `"openai"`, `"ollama"` (case-insensitive) ## Provider Types ### 1. Core Providers (Always Available) Ships with langextract, dependencies included: - **Gemini** (`gemini.py`): Google's Gemini models - **Ollama** (`ollama.py`): Local models via Ollama ### 2. 
Built-in Provider with Optional Dependencies Ships with langextract, but requires extra installation: - **OpenAI** (`openai.py`): OpenAI's GPT models - Code included in package - Requires: `pip install langextract[openai]` to install OpenAI SDK - Future: May be moved to external plugin package ### 3. External Plugins (Third-party) Separate packages that extend LangExtract with new providers: - **Installed separately**: `pip install langextract-yourprovider` - **Auto-discovered**: Uses Python entry points for automatic registration - **Zero configuration**: Import langextract and the provider is available - **Independent updates**: Update providers without touching core ```bash # Install a third-party provider pip install langextract-yourprovider ``` ```python # Use it immediately - no provider-specific import needed! import langextract as lx result = lx.extract( text="...", model_id="yourmodel-latest" # Automatically finds the provider ) ``` #### How Plugin Discovery Works ``` 1. pip install langextract-yourprovider └── Installs package containing: • Provider class with @lx.providers.registry.register decorator • Python entry point pointing to this class 2. import langextract └── Loads providers/__init__.py └── Plugin loading is lazy (on-demand) 3. lx.extract(model_id="yourmodel-latest") └── Triggers plugin discovery via entry points └── @lx.providers.registry.register decorator fires └── Provider patterns added to registry └── Registry matches pattern and uses your provider ``` **Important Notes:** - Plugin loading is **lazy** - plugins are discovered when first needed - To manually trigger plugin loading: `lx.providers.load_plugins_once()` - Set `LANGEXTRACT_DISABLE_PLUGINS=1` to disable plugin loading - Registry entries are tuples: `(patterns_list, priority_int)` ## How Provider Selection Works When you call `lx.extract(model_id="gemini-2.5-flash", ...)`, here's what happens: 1. **Factory receives model_id**: "gemini-2.5-flash" 2.
**Registry searches patterns**: Each provider registers regex patterns 3. **First match wins**: Returns the matching provider class 4. **Provider instantiated**: With model_id and any kwargs 5. **Inference runs**: Using the selected provider ### Pattern Registration Example ```python import langextract as lx # Gemini provider registration: @lx.providers.registry.register( r'^GeminiLanguageModel$', # Explicit: model_id="GeminiLanguageModel" r'^gemini', # Prefix: model_id="gemini-2.5-flash" r'^palm' # Legacy: model_id="palm-2" ) class GeminiLanguageModel(lx.inference.BaseLanguageModel): def __init__(self, model_id: str, api_key: str = None, **kwargs): # Initialize Gemini client ... def infer(self, batch_prompts, **kwargs): # Call Gemini API ... ``` ## Usage Examples ### Using Default Provider Selection ```python import langextract as lx # Automatically selects Gemini provider result = lx.extract( text="...", model_id="gemini-2.5-flash" ) ``` ### Passing Parameters to Providers Parameters flow from `lx.extract()` to providers through several mechanisms: ```python # 1. Common parameters handled by lx.extract itself: result = lx.extract( text="Your document", model_id="gemini-2.5-flash", prompt_description="Extract key facts", examples=[...], # Used for few-shot prompting num_workers=4, # Parallel processing max_chunk_size=3000, # Document chunking ) # 2. 
Provider-specific parameters passed via **kwargs: result = lx.extract( text="Your document", model_id="gemini-2.5-flash", prompt_description="Extract entities", # These go directly to the Gemini provider: temperature=0.7, # Sampling temperature api_key="your-key", # Override environment variable max_output_tokens=1000, # Token limit ) ``` ### Using the Factory for Advanced Control ```python # When you need explicit provider selection or advanced configuration from langextract import factory # Specify both model and provider (useful when multiple providers support same model) config = factory.ModelConfig( model_id="gemma2:2b", provider="OllamaLanguageModel", # Explicitly use Ollama provider_kwargs={ "model_url": "http://localhost:11434" } ) model = factory.create_model(config) ``` ### Direct Provider Usage ```python import langextract as lx # Direct import if you prefer (optional) from langextract.providers.gemini import GeminiLanguageModel model = GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="your-key" ) outputs = model.infer(["prompt1", "prompt2"]) ``` ## Creating a New Provider **📁 Complete Example**: See [examples/custom_provider_plugin/](../../examples/custom_provider_plugin/) for a fully-functional plugin template with testing and documentation. ### Quick Start Checklist Creating a provider plugin? Follow this checklist: #### ☐ **1. Setup Package Structure** ``` langextract-yourprovider/ ├── pyproject.toml # Package config with entry point ├── README.md # Documentation ├── LICENSE # License file └── langextract_yourprovider/ # Package directory ├── __init__.py # Exports provider class ├── provider.py # Provider implementation └── schema.py # (Optional) Custom schema ``` #### ☐ **2. 
Configure Entry Point** (`pyproject.toml`) ```toml [build-system] requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "langextract-yourprovider" version = "0.1.0" dependencies = ["langextract>=1.0.0"] [project.entry-points."langextract.providers"] yourprovider = "langextract_yourprovider:YourProviderLanguageModel" ``` #### ☐ **3. Implement Provider** (`provider.py`) - [ ] Import required modules - [ ] Add `@lx.providers.registry.register()` decorator with patterns - [ ] Inherit from `lx.inference.BaseLanguageModel` - [ ] Implement `__init__()` method - [ ] Implement `infer()` method returning `ScoredOutput` objects - [ ] Export class from `__init__.py` #### ☐ **4. (Optional) Add Schema Support** (`schema.py`) - [ ] Create schema class inheriting from `lx.schema.BaseSchema` - [ ] Implement `from_examples()` class method - [ ] Implement `to_provider_config()` method - [ ] Add `get_schema_class()` to provider - [ ] Handle schema in provider's `__init__()` and `infer()` #### ☐ **5. Testing** - [ ] Install plugin with `pip install -e .` - [ ] Test that your provider loads and handles basic inference - [ ] Verify schema support works (if implemented) #### ☐ **6. Documentation** - [ ] Document supported model IDs and patterns - [ ] List required environment variables - [ ] Provide usage examples - [ ] Document any provider-specific parameters #### ☐ **7. Distribution & Community** - [ ] Test installation with `pip install -e .` - [ ] Build package with `python -m build` - [ ] Test in clean environment - [ ] Publish to PyPI with `twine upload dist/*` - [ ] Share your provider by opening an issue on [LangExtract GitHub](https://github.com/google/langextract/issues) to get feedback and help others discover it - [ ] Consider submitting a PR to add your provider to the community providers list in [COMMUNITY_PROVIDERS.md](../../COMMUNITY_PROVIDERS.md) ### Option 1: External Plugin (Recommended) External plugins are the recommended approach for adding new providers.
They're easy to maintain, distribute, and don't require changes to the core package. #### For Users (Installing an External Plugin) Simply install the plugin package: ```bash pip install langextract-yourprovider # That's it! The provider is now available in langextract ``` #### For Developers (Creating an External Plugin) 1. Create a new package: ``` langextract-myprovider/ ├── pyproject.toml ├── README.md └── langextract_myprovider/ └── __init__.py ``` 2. Configure entry point in `pyproject.toml`: ```toml [build-system] requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "langextract-myprovider" version = "0.1.0" dependencies = ["langextract>=1.0.0", "your-sdk"] [project.entry-points."langextract.providers"] # Pattern 1: Register the class directly myprovider = "langextract_myprovider:MyProviderLanguageModel" # Pattern 2: Register a module that self-registers # myprovider = "langextract_myprovider" ``` 3. Implement your provider: ```python # langextract_myprovider/__init__.py import os import langextract as lx @lx.providers.registry.register(r'^mymodel', r'^custom', priority=10) class MyProviderLanguageModel(lx.inference.BaseLanguageModel): def __init__(self, model_id: str, api_key: str = None, **kwargs): super().__init__() self.model_id = model_id self.api_key = api_key or os.environ.get('MYPROVIDER_API_KEY') # Initialize your client self.client = MyProviderClient(api_key=self.api_key) def infer(self, batch_prompts, **kwargs): # Implement inference for prompt in batch_prompts: result = self.client.generate(prompt, **kwargs) yield [lx.inference.ScoredOutput(score=1.0, output=result)] ``` **Pattern Registration Explained:** - The `@register` decorator patterns (e.g., `r'^mymodel'`, `r'^custom'`) define which model IDs your provider supports - When users call `lx.extract(model_id="mymodel-3b")`, the registry matches against these patterns - Your provider will handle any model_id starting with "mymodel" or "custom" - 
Users can explicitly select your provider using its class name: ```python config = lx.factory.ModelConfig(provider="MyProviderLanguageModel") # Or partial match: provider="myprovider" (matches class name) ``` 4. Publish your package to PyPI: ```bash pip install build twine python -m build twine upload dist/* ``` Now users can install and use your provider with just `pip install langextract-myprovider`! ### Adding Schema Support Schemas enable structured output with strict JSON constraints. Here's how to add schema support to your provider: #### 1. Create a Schema Class ```python # langextract_myprovider/schema.py import langextract as lx from langextract import schema class MyProviderSchema(lx.schema.BaseSchema): def __init__(self, schema_dict: dict): self._schema_dict = schema_dict @property def schema_dict(self) -> dict: return self._schema_dict @classmethod def from_examples(cls, examples_data, attribute_suffix="_attributes"): """Build schema from example extractions.""" # Analyze examples to determine structure extraction_types = {} for example in examples_data: for extraction in example.extractions: class_name = extraction.extraction_class if class_name not in extraction_types: extraction_types[class_name] = set() if extraction.attributes: extraction_types[class_name].update(extraction.attributes.keys()) # Build JSON schema schema_dict = { "type": "object", "properties": { "extractions": { "type": "array", "items": {"type": "object"} # Simplified } } } return cls(schema_dict) def to_provider_config(self) -> dict: """Convert to provider-specific configuration.""" return { "response_schema": self._schema_dict, "structured_output": True } @property def supports_strict_mode(self) -> bool: """Return True if provider enforces valid JSON output.""" return True ``` #### 2.
Update Your Provider ```python # langextract_myprovider/provider.py class MyProviderLanguageModel(lx.inference.BaseLanguageModel): def __init__(self, model_id: str, **kwargs): super().__init__() self.model_id = model_id # Schema config will be in kwargs when use_schema_constraints=True self.response_schema = kwargs.get('response_schema') self.structured_output = kwargs.get('structured_output', False) @classmethod def get_schema_class(cls): """Tell LangExtract about our schema support.""" from langextract_myprovider.schema import MyProviderSchema return MyProviderSchema def apply_schema(self, schema_instance): """Apply or clear schema configuration.""" super().apply_schema(schema_instance) if schema_instance: config = schema_instance.to_provider_config() self.response_schema = config.get('response_schema') self.structured_output = config.get('structured_output', False) else: self.response_schema = None self.structured_output = False def infer(self, batch_prompts, **kwargs): for prompt in batch_prompts: # Use schema in API call if available api_params = {} if self.response_schema: api_params['response_schema'] = self.response_schema result = self.client.generate(prompt, **api_params) yield [lx.inference.ScoredOutput(score=1.0, output=result)] ``` #### 3. Schema Usage When users set `use_schema_constraints=True`, LangExtract will: 1. Call your provider's `get_schema_class()` 2. Use `from_examples()` to build a schema from provided examples 3. Call `to_provider_config()` to get provider-specific kwargs 4. Pass these kwargs to your provider's `__init__()` 5. 
Your provider uses the schema for structured output ### Option 2: Built-in Provider (Requires Core Team Approval) **⚠️ Note**: Adding a provider to the core package requires: - Significant community demand and support - Commitment to long-term maintenance - Approval from the LangExtract maintainers - A pull request to the main repository This approach should only be used for providers that benefit a large portion of the user base. 1. Create your provider file: ```python # langextract/providers/myprovider.py import langextract as lx @lx.providers.registry.register(r'^mymodel', r'^custom') class MyProviderLanguageModel(lx.inference.BaseLanguageModel): # Implementation same as above ``` 2. Import it in `providers/__init__.py`: ```python # In langextract/providers/__init__.py from langextract.providers import myprovider # noqa: F401 ``` 3. Submit a pull request with: - Provider implementation - Comprehensive tests - Documentation - Justification for inclusion in core ## Environment Variables The factory automatically resolves API keys from environment: | Provider | Environment Variables (in priority order) | |----------|------------------------------------------| | Gemini | `GEMINI_API_KEY`, `LANGEXTRACT_API_KEY` | | OpenAI | `OPENAI_API_KEY`, `LANGEXTRACT_API_KEY` | | Ollama | `OLLAMA_BASE_URL` (default: http://localhost:11434) | ## Design Principles 1. **Zero Configuration**: Providers auto-register when imported 2. **Extensible**: Easy to add new providers without modifying core 3. **Lazy Loading**: Optional dependencies only loaded when needed 4. **Explicit Control**: Users can force specific providers when needed 5. **Pattern Priority**: All patterns have equal priority (0) by default ## Common Issues ### Provider Not Found ```python ValueError: No provider registered for model_id='unknown-model' ``` **Solution**: Check available patterns with `registry.list_entries()` ### Plugin Not Loading ```python # Your plugin isn't being discovered ``` **Solutions**: 1. 
Manually trigger loading: `lx.providers.load_plugins_once()` 2. Check entry points are installed: `pip show -f your-package` 3. Verify no typos in `pyproject.toml` entry point 4. Ensure package is installed: `pip list | grep your-package` ### Missing Dependencies ```python InferenceConfigError: OpenAI provider requires openai package ``` **Solution**: Install optional dependencies: `pip install langextract[openai]` ### Schema Not Working ```python # Schema constraints not being applied ``` **Solutions**: 1. Ensure provider implements `get_schema_class()` 2. Check `use_schema_constraints=True` is set 3. Verify schema's `supports_strict_mode` returns `True` 4. Test schema creation with `Schema.from_examples(examples)` ### Pattern Conflicts ```python # Multiple providers match the same model_id ``` **Solution**: Use explicit provider selection: ```python config = lx.factory.ModelConfig( model_id="model-name", provider="YourProviderClass" # Explicit selection ) ``` ================================================ FILE: langextract/providers/__init__.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provider package for LangExtract. This package contains provider implementations for various LLM backends. Each provider can be imported independently for fine-grained dependency management in build systems.
""" import importlib from importlib import metadata import os from absl import logging from langextract.providers import builtin_registry from langextract.providers import router registry = router # Backward compat alias __all__ = [ "gemini", "openai", "ollama", "router", "registry", # Backward compat "schemas", "load_plugins_once", "load_builtins_once", ] # Track provider loading for lazy initialization _plugins_loaded = False # pylint: disable=invalid-name _builtins_loaded = False # pylint: disable=invalid-name def load_builtins_once() -> None: """Load built-in providers to register their patterns. Idempotent function that ensures provider patterns are available for model resolution. Uses lazy registration to ensure providers can be re-registered after registry.clear() even if their modules are already in sys.modules. """ global _builtins_loaded # pylint: disable=global-statement if _builtins_loaded: return # Register built-ins lazily so they can be re-registered after a registry.clear() # even if their modules were already imported earlier in the test run. for config in builtin_registry.BUILTIN_PROVIDERS: router.register_lazy( *config["patterns"], target=config["target"], priority=config["priority"], ) _builtins_loaded = True def load_plugins_once() -> None: """Load provider plugins from installed packages. Discovers and loads langextract provider plugins using entry points. This function is idempotent - multiple calls have no effect. 
""" global _plugins_loaded # pylint: disable=global-statement if _plugins_loaded: return if os.environ.get("LANGEXTRACT_DISABLE_PLUGINS", "").lower() in ( "1", "true", "yes", ): logging.info("Plugin loading disabled via LANGEXTRACT_DISABLE_PLUGINS") _plugins_loaded = True return load_builtins_once() try: eps = metadata.entry_points() # Try different APIs based on what's available if hasattr(eps, "select"): # Python 3.10+ API provider_eps = eps.select(group="langextract.providers") elif hasattr(eps, "get"): # Python 3.9 API provider_eps = eps.get("langextract.providers", []) else: # Fallback for older versions provider_eps = [ ep for ep in eps if getattr(ep, "group", None) == "langextract.providers" ] for entry_point in provider_eps: try: provider_class = entry_point.load() logging.info("Loaded provider plugin: %s", entry_point.name) if hasattr(provider_class, "get_model_patterns"): patterns = provider_class.get_model_patterns() for pattern in patterns: router.register( pattern, priority=getattr( provider_class, "pattern_priority", 20, # Default plugin priority ), )(provider_class) logging.info( "Registered %d patterns for %s", len(patterns), entry_point.name ) except Exception as e: logging.warning( "Failed to load provider plugin %s: %s", entry_point.name, e ) except Exception as e: logging.warning("Error discovering provider plugins: %s", e) _plugins_loaded = True def _reset_for_testing() -> None: """Reset plugin loading state for testing. 
Should only be used in tests.""" global _plugins_loaded, _builtins_loaded # pylint: disable=global-statement _plugins_loaded = False _builtins_loaded = False def __getattr__(name: str): """Lazy loading for submodules.""" if name == "router": return importlib.import_module("langextract.providers.router") elif name == "schemas": return importlib.import_module("langextract.providers.schemas") elif name == "_plugins_loaded": return _plugins_loaded elif name == "_builtins_loaded": return _builtins_loaded raise AttributeError(f"module {__name__!r} has no attribute {name!r}") ================================================ FILE: langextract/providers/builtin_registry.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Built-in provider registration configuration. This module defines the registration details for all built-in providers, using patterns from the centralized patterns module. """ from typing import TypedDict from langextract.providers import patterns class ProviderConfig(TypedDict): """Configuration for a provider registration.""" patterns: tuple[str, ...] 
target: str priority: int # Built-in provider configurations using centralized patterns BUILTIN_PROVIDERS: list[ProviderConfig] = [ { 'patterns': patterns.GEMINI_PATTERNS, 'target': 'langextract.providers.gemini:GeminiLanguageModel', 'priority': patterns.GEMINI_PRIORITY, }, { 'patterns': patterns.OLLAMA_PATTERNS, 'target': 'langextract.providers.ollama:OllamaLanguageModel', 'priority': patterns.OLLAMA_PRIORITY, }, { 'patterns': patterns.OPENAI_PATTERNS, 'target': 'langextract.providers.openai:OpenAILanguageModel', 'priority': patterns.OPENAI_PRIORITY, }, ] ================================================ FILE: langextract/providers/gemini.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Gemini provider for LangExtract.""" # pylint: disable=duplicate-code from __future__ import annotations import concurrent.futures import dataclasses from typing import Any, Final, Iterator, Sequence from absl import logging from langextract.core import base_model from langextract.core import data from langextract.core import exceptions from langextract.core import schema from langextract.core import types as core_types from langextract.providers import gemini_batch from langextract.providers import patterns from langextract.providers import router from langextract.providers import schemas _DEFAULT_MODEL_ID = 'gemini-2.5-flash' _DEFAULT_LOCATION = 'us-central1' _MIME_TYPE_JSON = 'application/json' _API_CONFIG_KEYS: Final[set[str]] = { 'response_mime_type', 'response_schema', 'safety_settings', 'system_instruction', 'tools', 'stop_sequences', 'candidate_count', } @router.register( *patterns.GEMINI_PATTERNS, priority=patterns.GEMINI_PRIORITY, ) @dataclasses.dataclass(init=False) class GeminiLanguageModel(base_model.BaseLanguageModel): # pylint: disable=too-many-instance-attributes """Language model inference using Google's Gemini API with structured output.""" model_id: str = _DEFAULT_MODEL_ID api_key: str | None = None vertexai: bool = False credentials: Any | None = None project: str | None = None location: str | None = None http_options: Any | None = None gemini_schema: schemas.gemini.GeminiSchema | None = None format_type: data.FormatType = data.FormatType.JSON temperature: float = 0.0 max_workers: int = 10 fence_output: bool = False _extra_kwargs: dict[str, Any] = dataclasses.field( default_factory=dict, repr=False, compare=False ) @classmethod def get_schema_class(cls) -> type[schema.BaseSchema] | None: """Return the GeminiSchema class for structured output support. Returns: The GeminiSchema class that supports strict schema constraints. 
""" return schemas.gemini.GeminiSchema def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None: """Apply a schema instance to this provider. Args: schema_instance: The schema instance to apply, or None to clear. """ super().apply_schema(schema_instance) if isinstance(schema_instance, schemas.gemini.GeminiSchema): self.gemini_schema = schema_instance def __init__( self, model_id: str = _DEFAULT_MODEL_ID, api_key: str | None = None, vertexai: bool = False, credentials: Any | None = None, project: str | None = None, location: str | None = None, http_options: Any | None = None, gemini_schema: schemas.gemini.GeminiSchema | None = None, format_type: data.FormatType = data.FormatType.JSON, temperature: float = 0.0, max_workers: int = 10, fence_output: bool = False, **kwargs, ) -> None: """Initialize the Gemini language model. Args: model_id: The Gemini model ID to use. api_key: API key for Gemini service. vertexai: Whether to use Vertex AI instead of API key authentication. credentials: Optional Google auth credentials for Vertex AI. project: Google Cloud project ID for Vertex AI. location: Vertex AI location (e.g., 'global', 'us-central1'). http_options: Optional HTTP options for the client (e.g., for VPC endpoints). gemini_schema: Optional schema for structured output. format_type: Output format (JSON or YAML). temperature: Sampling temperature. max_workers: Maximum number of parallel API calls. fence_output: Whether to wrap output in markdown fences (ignored, Gemini handles this based on schema). **kwargs: Additional Gemini API parameters. Only allowlisted keys are forwarded to the API (response_schema, response_mime_type, tools, safety_settings, stop_sequences, candidate_count, system_instruction). See https://ai.google.dev/api/generate-content for details. """ try: # pylint: disable=import-outside-toplevel from google import genai except ImportError as e: raise exceptions.InferenceConfigError( 'google-genai is required for Gemini. 
Install it with: pip install' ' google-genai' ) from e self.model_id = model_id self.api_key = api_key self.vertexai = vertexai self.credentials = credentials self.project = project self.location = location self.http_options = http_options self.gemini_schema = gemini_schema self.format_type = format_type self.temperature = temperature self.max_workers = max_workers self.fence_output = fence_output # Extract batch config before we filter kwargs into _extra_kwargs batch_cfg_dict = kwargs.pop('batch', None) self._batch_cfg = gemini_batch.BatchConfig.from_dict(batch_cfg_dict) if not self.api_key and not self.vertexai: raise exceptions.InferenceConfigError( 'Gemini models require either:\n - An API key via api_key parameter' ' or LANGEXTRACT_API_KEY env var\n - Vertex AI configuration with' ' vertexai=True, project, and location' ) if self.vertexai and (not self.project or not self.location): raise exceptions.InferenceConfigError( 'Vertex AI mode requires both project and location parameters' ) if self.api_key and self.vertexai: logging.warning( 'Both API key and Vertex AI configuration provided. ' 'API key will take precedence for authentication.' ) self._client = genai.Client( api_key=self.api_key, vertexai=vertexai, credentials=credentials, project=project, location=location, http_options=http_options, ) super().__init__( constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) ) self._extra_kwargs = { k: v for k, v in (kwargs or {}).items() if k in _API_CONFIG_KEYS } def _validate_schema_config(self) -> None: """Validate that schema configuration is compatible with format type. Raises: InferenceConfigError: If gemini_schema is set but format_type is not JSON. """ if self.gemini_schema and self.format_type != data.FormatType.JSON: raise exceptions.InferenceConfigError( 'Gemini structured output only supports JSON format. ' 'Set format_type=JSON or use_schema_constraints=False.' 
) def _process_single_prompt( self, prompt: str, config: dict ) -> core_types.ScoredOutput: """Process a single prompt and return a ScoredOutput.""" try: # Apply stored kwargs that weren't already set in config for key, value in self._extra_kwargs.items(): if key not in config and value is not None: config[key] = value if self.gemini_schema: self._validate_schema_config() config.setdefault('response_mime_type', 'application/json') config.setdefault('response_schema', self.gemini_schema.schema_dict) response = self._client.models.generate_content( model=self.model_id, contents=prompt, config=config ) return core_types.ScoredOutput(score=1.0, output=response.text) except Exception as e: raise exceptions.InferenceRuntimeError( f'Gemini API error: {str(e)}', original=e ) from e def infer( self, batch_prompts: Sequence[str], **kwargs ) -> Iterator[Sequence[core_types.ScoredOutput]]: """Runs inference on a list of prompts via Gemini's API. Args: batch_prompts: A list of string prompts. **kwargs: Additional generation params (temperature, top_p, top_k, etc.) Yields: Lists of ScoredOutputs. 
""" merged_kwargs = self.merge_kwargs(kwargs) config = { 'temperature': merged_kwargs.get('temperature', self.temperature), } for key in ('max_output_tokens', 'top_p', 'top_k'): if key in merged_kwargs: config[key] = merged_kwargs[key] handled_keys = {'temperature', 'max_output_tokens', 'top_p', 'top_k'} for key, value in merged_kwargs.items(): if ( key not in handled_keys and key in _API_CONFIG_KEYS and value is not None ): config[key] = value # Use batch API if threshold met if self._batch_cfg and self._batch_cfg.enabled: if len(batch_prompts) >= self._batch_cfg.threshold: try: if self.gemini_schema: self._validate_schema_config() schema_dict = ( self.gemini_schema.schema_dict if self.gemini_schema else None ) # Remove schema fields from config for batch API - they're handled via schema_dict batch_config = dict(config) batch_config.pop('response_mime_type', None) batch_config.pop('response_schema', None) # Extract top-level fields that don't belong in generationConfig system_instruction = batch_config.pop('system_instruction', None) safety_settings = batch_config.pop('safety_settings', None) outputs = gemini_batch.infer_batch( client=self._client, model_id=self.model_id, prompts=batch_prompts, schema_dict=schema_dict, gen_config=batch_config, cfg=self._batch_cfg, system_instruction=system_instruction, safety_settings=safety_settings, project=self.project, location=self.location, ) except exceptions.InferenceRuntimeError: raise except Exception as e: raise exceptions.InferenceRuntimeError( f'Gemini Batch API error: {e}', original=e ) from e for text in outputs: yield [core_types.ScoredOutput(score=1.0, output=text)] return else: logging.info( 'Gemini batch mode enabled but prompt count (%d) is below the' ' threshold (%d); using real-time API. 
Submit at least %d prompts' ' to trigger batch mode.', len(batch_prompts), self._batch_cfg.threshold, self._batch_cfg.threshold, ) # Use parallel processing for batches larger than 1 if len(batch_prompts) > 1 and self.max_workers > 1: with concurrent.futures.ThreadPoolExecutor( max_workers=min(self.max_workers, len(batch_prompts)) ) as executor: future_to_index = { executor.submit( self._process_single_prompt, prompt, config.copy() ): i for i, prompt in enumerate(batch_prompts) } results: list[core_types.ScoredOutput | None] = [None] * len( batch_prompts ) for future in concurrent.futures.as_completed(future_to_index): index = future_to_index[future] try: results[index] = future.result() except Exception as e: raise exceptions.InferenceRuntimeError( f'Parallel inference error: {str(e)}', original=e ) from e for result in results: if result is None: raise exceptions.InferenceRuntimeError( 'Failed to process one or more prompts' ) yield [result] else: # Sequential processing for single prompt or worker for prompt in batch_prompts: result = self._process_single_prompt(prompt, config.copy()) yield [result] # pylint: disable=duplicate-code ================================================ FILE: langextract/providers/gemini_batch.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Gemini Batch API helper module for LangExtract. This module provides batch inference support using the google-genai SDK. 
It handles: - File-based batch submission for all batch sizes - Job polling and result extraction - Schema-based structured output - Order preservation across batch processing """ from __future__ import annotations from collections.abc import Iterator, Sequence import concurrent.futures import dataclasses import enum import hashlib import json import logging as std_logging import os import re import tempfile import time from typing import Any, Callable, Protocol import uuid from absl import logging from google import genai from google.api_core import exceptions as google_exceptions from google.cloud import storage from langextract.core import exceptions _MIME_TYPE_JSON = "application/json" _DEFAULT_LOCATION = "us-central1" _EXT_JSON = ".json" _EXT_JSONL = ".jsonl" _KEY_IDX = "idx-" _CACHE_PREFIX = "cache" _UNSET = object() @dataclasses.dataclass(slots=True, frozen=True) class BatchConfig: """Define and validate Gemini Batch API configuration. Attributes: enabled: Whether batch mode is enabled. threshold: Minimum prompts to trigger batch processing. poll_interval: Seconds between job status checks. timeout: Maximum seconds to wait for job completion. max_prompts_per_job: Max prompts allowed in one batch job. ignore_item_errors: If True, continue on per-item errors. enable_caching: If True, use GCS-based caching for inference results. retention_days: Days to keep GCS data (default 30). None for permanent. 
""" enabled: bool = False threshold: int = 50 poll_interval: int = 30 timeout: int = 3600 max_prompts_per_job: int = 20000 ignore_item_errors: bool = False enable_caching: bool | None = _UNSET # type: ignore retention_days: int | None = _UNSET # type: ignore on_job_create: Callable[[Any], None] | None = None def __post_init__(self): """Validate numeric knobs early.""" validations = [ (self.threshold >= 1, "batch.threshold must be >= 1"), (self.poll_interval > 0, "batch.poll_interval must be > 0"), (self.timeout > 0, "batch.timeout must be > 0"), (self.timeout > 0, "batch.timeout must be > 0"), (self.max_prompts_per_job > 0, "batch.max_prompts_per_job must be > 0"), ] for is_valid, error_msg in validations: if not is_valid: raise ValueError(error_msg) if self.enabled: if self.enable_caching is _UNSET: raise ValueError( "batch.enable_caching must be explicitly set when batch is enabled" ) if self.retention_days is _UNSET: raise ValueError( "batch.retention_days must be explicitly set when batch is enabled" " (use None for permanent)" ) if self.retention_days is not None and self.retention_days <= 0: raise ValueError( "batch.retention_days must be > 0 or None (for permanent). " "0 (immediate delete) is not allowed." 
) @classmethod def from_dict(cls, d: dict | None) -> BatchConfig: """Create BatchConfig from dictionary, using defaults for missing keys.""" if d is None: return cls() valid_keys = {f.name for f in dataclasses.fields(cls)} filtered_dict = {k: v for k, v in d.items() if k in valid_keys} unknown = sorted(set(d.keys()) - valid_keys) if unknown: logging.warning( "Ignoring unknown batch config keys: %s", ", ".join(unknown) ) cfg = cls(**filtered_dict) if cfg.on_job_create is None: object.__setattr__(cfg, "on_job_create", _default_job_create_callback) return cfg _TERMINAL_FAIL = frozenset({ genai.types.JobState.JOB_STATE_FAILED, genai.types.JobState.JOB_STATE_CANCELLED, genai.types.JobState.JOB_STATE_EXPIRED, }) _TERMINAL_OK = frozenset({ genai.types.JobState.JOB_STATE_SUCCEEDED, genai.types.JobState.JOB_STATE_PAUSED, }) def _default_job_create_callback(job: Any) -> None: """Default callback to log batch job details.""" logging.info("Batch job created successfully: %s", job.name) logging.info("Job State: %s", job.state) # Extract project and job ID for console URL try: # job.name format: projects/{project}/locations/{location}/batchPredictionJobs/{job_id} parts = job.name.split("/") if len(parts) >= 6: job_id = parts[-1] location = parts[3] project = parts[1] logging.info( "Job Console URL:" " https://console.cloud.google.com/vertex-ai/locations/%s/batch-predictions/%s?project=%s", location, job_id, project, ) except Exception: pass def _snake_to_camel(key: str) -> str: """Convert snake_case to camelCase for REST API compatibility.""" parts = key.split("_") return parts[0] + "".join(p.title() for p in parts[1:]) def _is_vertexai_client(client) -> bool: """Check if client is configured for Vertex AI with explicit identity check. Args: client: The genai.Client instance to check. Returns: True if client.vertexai is explicitly True, False otherwise. 
""" return getattr(client, "vertexai", False) is True def _get_project_location( client: genai.Client, project: str | None = None, location: str | None = None, ) -> tuple[str | None, str]: """Extract project and location from client or arguments.""" if project: proj = project else: # Try to get from client (if available in future versions) or env. proj = getattr(client, "project", None) or os.getenv("GOOGLE_CLOUD_PROJECT") if location: loc = location else: loc = getattr(client, "location", None) or _DEFAULT_LOCATION return proj, loc def _get_bucket_name(project: str | None, location: str) -> str: """Generate consistent GCS bucket name for batch operations.""" base = f"langextract-{project}-{location}-batch".lower() return re.sub(r"[^a-z0-9._-]", "-", base) def _ensure_bucket_lifecycle( bucket: storage.Bucket, retention_days: int | None ) -> None: """Ensure bucket has a lifecycle rule to delete objects after retention_days. This is a best-effort optimization to reduce storage costs. It checks if a rule with the exact age exists, and if not, adds it. It does NOT remove existing rules. Args: bucket: The GCS bucket to configure. retention_days: Number of days to keep objects. If None, no rule is added. 
""" if retention_days is None or retention_days <= 0: return # Check if rule already exists for rule in bucket.lifecycle_rules: if ( rule.get("action", {}).get("type") == "Delete" and rule.get("condition", {}).get("age") == retention_days ): return # Add new rule bucket.add_lifecycle_delete_rule(age=retention_days) try: bucket.patch() logging.info( "Added lifecycle rule to bucket %s: delete after %d days", bucket.name, retention_days, ) except Exception as e: logging.warning( "Failed to update lifecycle rule for bucket %s: %s", bucket.name, e ) def _build_request( prompt: str, schema_dict: dict | None, gen_config: dict | None, system_instruction: str | None = None, safety_settings: Sequence[Any] | None = None, ) -> dict: """Build a batch request in REST format for file-based submission. Constructs a properly formatted request dictionary for batch processing. Per the Gemini Batch API documentation, each request in the JSONL file can include its own generationConfig with schema and generation parameters, as well as top-level systemInstruction and safetySettings. Args: prompt: The text prompt to send to the model. schema_dict: Optional JSON schema for structured output. gen_config: Optional generation configuration parameters. system_instruction: Optional system instruction text. safety_settings: Optional safety settings sequence. Returns: A dictionary formatted for REST API file-based submission, containing: * contents: The prompt content. * systemInstruction: Optional system instructions. * safetySettings: Optional safety settings. * generationConfig: Optional generation configuration and schema. 
""" request = {"contents": [{"role": "user", "parts": [{"text": prompt}]}]} if system_instruction: request["systemInstruction"] = {"parts": [{"text": system_instruction}]} if safety_settings: request["safetySettings"] = safety_settings if schema_dict or gen_config: generation_config = {} if schema_dict: generation_config["responseMimeType"] = _MIME_TYPE_JSON generation_config["responseSchema"] = schema_dict if gen_config: for k, v in gen_config.items(): generation_config[_snake_to_camel(k)] = v request["generationConfig"] = generation_config return request def _submit_file( client: genai.Client, model_id: str, requests: Sequence[dict], display: str, retention_days: int | None, project: str | None = None, location: str | None = None, ) -> genai.types.BatchJob: """Submit a file-based batch job to Vertex AI using GCS storage. Batch processing is only supported with Vertex AI because it requires GCS for file upload. Creates JSONL file, uploads to auto-created bucket, and submits job for async processing. Args: client: google.genai.Client instance configured for Vertex AI (must have client.vertexai=True). model_id: Model identifier (e.g., "gemini-2.5-flash"). requests: List of request dictionaries with embedded configuration. Each request contains contents and optional generationConfig (including schema and generation parameters). display: Display name for the batch job, used for identification and as part of the GCS blob name. retention_days: Days to keep GCS data. If set, applies lifecycle rule. project: Optional GCP project ID. If not provided, will attempt to determine from client or environment. location: Optional GCP region/location. If not provided, will attempt to determine from client or use default. Returns: BatchJob object that can be polled for completion status. Raises: ValueError: If client is not configured for Vertex AI. 
""" path = None try: with tempfile.NamedTemporaryFile( "w", suffix=_EXT_JSONL, delete=False, encoding="utf-8" ) as f: path = f.name for idx, req in enumerate(requests): # We use a simple "idx-{N}" key format to track the original order # of prompts, as batch processing may return results out of order. line = {"key": f"{_KEY_IDX}{idx}", "request": req} f.write(json.dumps(line, ensure_ascii=False) + "\n") project, location = _get_project_location(client, project, location) bucket_name = _get_bucket_name(project, location) blob_name = f"batch-input/{display}-{uuid.uuid4().hex}.jsonl" storage_client = storage.Client(project=project) try: bucket = storage_client.create_bucket(bucket_name, location=location) logging.info("Created GCS bucket: %s", bucket_name) except google_exceptions.Conflict: bucket = storage_client.bucket(bucket_name) logging.info("Using existing GCS bucket: %s", bucket_name) if retention_days: _ensure_bucket_lifecycle(bucket, retention_days) blob = bucket.blob(blob_name) blob.upload_from_filename(path) gcs_uri = f"gs://{bucket.name}/{blob.name}" # Create batch job (config and schema are in per-request generationConfig) job = client.batches.create( model=model_id, src=gcs_uri, config={"display_name": display} ) return job finally: if path: try: os.unlink(path) except OSError: pass class GCSBatchCache: """GCS-based cache for batch inference results.""" def __init__(self, bucket_name: str, project: str | None = None): self.bucket_name = bucket_name self.project = project self._client = storage.Client(project=project) self._bucket = self._client.bucket(bucket_name) def _compute_hash(self, key_data: dict) -> str: """Compute SHA256 hash of the canonicalized request data.""" canonical_json = json.dumps(key_data, sort_keys=True, ensure_ascii=False) return hashlib.sha256(canonical_json.encode("utf-8")).hexdigest() def _get_single(self, key_hash: str) -> str | None: """Fetch single item from GCS.""" blob = 
self._bucket.blob(f"{_CACHE_PREFIX}/{key_hash}{_EXT_JSON}") try: data = json.loads(blob.download_as_text()) return data.get("text") except google_exceptions.NotFound: return None except Exception as e: logging.warning("Cache read error for %s: %s", key_hash, e) return None def get_multi(self, key_data_list: Sequence[dict]) -> dict[int, str]: """Fetch multiple items from GCS in parallel. Returns: Dict mapping index in key_data_list to cached text. """ results = {} # Limit max_workers to 10 to match default HTTP connection pool size. with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: future_to_idx = {} for idx, key_data in enumerate(key_data_list): key_hash = self._compute_hash(key_data) future = executor.submit(self._get_single, key_hash) future_to_idx[future] = idx for future in concurrent.futures.as_completed(future_to_idx): idx = future_to_idx[future] text = future.result() if text is not None: results[idx] = text return results def set_multi(self, items: Sequence[tuple[dict, str]]) -> None: """Upload multiple items to GCS in parallel. Args: items: List of (key_data, result_text) tuples. 
""" def _upload(text: str, key_data: dict): key_hash = self._compute_hash(key_data) blob = self._bucket.blob(f"{_CACHE_PREFIX}/{key_hash}{_EXT_JSON}") try: blob.upload_from_string( json.dumps({"text": text}, ensure_ascii=False), content_type=_MIME_TYPE_JSON, ) except Exception as e: logging.warning( "Cache write error for %s: %s", key_hash, e, exc_info=True ) def _json_default(obj): if dataclasses.is_dataclass(obj): return dataclasses.asdict(obj) if isinstance(obj, enum.Enum): return obj.value raise TypeError(f"Object of type {type(obj)} is not JSON serializable") with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: for key_data, text in items: # If text is not a string, try to serialize it if not isinstance(text, str): try: text = json.dumps(text, default=_json_default, ensure_ascii=False) except Exception as e: logging.warning("Serialization error: %s", e) continue executor.submit(_upload, text, key_data) def iter_items(self) -> Iterator[tuple[str, str]]: """Iterate over all items in the cache. Yields: Tuple of (key_hash, text_content). """ blobs = self._bucket.list_blobs(prefix=f"{_CACHE_PREFIX}/") for blob in blobs: if not blob.name.endswith(_EXT_JSON): continue try: key_hash = blob.name.split("/")[-1].replace(_EXT_JSON, "") data = json.loads(blob.download_as_text()) text = data.get("text") if text is not None: yield key_hash, text except (json.JSONDecodeError, Exception) as e: logging.warning("Failed to read cache item %s: %s", blob.name, e) class _TextResponse(Protocol): """Protocol for inline response objects with text attribute.""" text: str def _safe_get_nested(data: dict, *keys) -> Any: """Safely traverse nested dictionaries/lists. Args: data: The dict to traverse. *keys: Keys/indices to access. Use integers for list indices. Returns: The value at the path, or None if any key doesn't exist. 
""" current = data for key in keys: if current is None: return None if isinstance(key, int): if not isinstance(current, list) or len(current) <= key: return None current = current[key] else: if not isinstance(current, dict): return None current = current.get(key) return current def _extract_text(resp: _TextResponse | dict[str, Any] | None) -> str | None: """Extract text from Vertex AI batch API response. Args: resp: Response object (inline) or dict (file) containing text. Returns: Extracted text string, or None if not found or invalid. """ if resp is None: return None if hasattr(resp, "text"): text = getattr(resp, "text", None) return text if isinstance(text, str) else None if not isinstance(resp, dict): return None # Vertex AI format: {"candidates": [{"content": {"parts": [{"text": "..."}]}}]} text = _safe_get_nested(resp, "candidates", 0, "content", "parts", 0, "text") return text if isinstance(text, str) else None def _poll_completion( client: genai.Client, job: genai.types.BatchJob, cfg: BatchConfig ) -> genai.types.BatchJob: """Poll batch job until completion or timeout. Args: client: google.genai.Client instance for polling job status. job: Batch job object returned from client.batches.create(). cfg: Batch configuration including timeout and poll_interval. Returns: Completed batch job object. Raises: RuntimeError: If the job enters a failed terminal state. TimeoutError: If the job does not complete within cfg.timeout. 
""" start = time.time() name = job.name while True: job = client.batches.get(name=name) state = job.state if state in _TERMINAL_OK: return job if state in _TERMINAL_FAIL: error_details = job.error or "(no error details)" raise exceptions.InferenceRuntimeError( f"Batch job failed: state={state.name}, name={name}, " f"error={error_details}" ) if time.time() - start > cfg.timeout: try: client.batches.cancel(name=name) except Exception as e: logging.warning("Failed to cancel timed-out batch job %s: %s", name, e) raise exceptions.InferenceRuntimeError( f"Batch job timed out after {cfg.timeout}s: {name}" ) time.sleep(cfg.poll_interval) logging.info("Batch job is running... (State: %s)", state.name) def _parse_batch_line( line: str, outputs: dict[int, str], cfg: BatchConfig ) -> None: """Parse a single line from batch output JSONL.""" try: obj = json.loads(line) except json.JSONDecodeError: return error = obj.get("error") if error and not cfg.ignore_item_errors: code = error.get("code") if isinstance(error, dict) else None if code not in (None, 0): raise exceptions.InferenceRuntimeError(f"Batch item error: {error}") resp = obj.get("response", {}) text = _extract_text(resp) or "" key = obj.get("key", "") try: # Extract the original index from the key (e.g., "idx-5" -> 5) idx = int(str(key).rsplit(_KEY_IDX, maxsplit=1)[-1]) except (ValueError, IndexError): idx = max(outputs.keys(), default=-1) + 1 outputs[idx] = text def _extract_from_file( client: genai.Client, job: genai.types.BatchJob, cfg: BatchConfig, expected_count: int, ) -> list[str]: """Extract text outputs from file-based batch results, preserving order. Reads results from GCS output directory. Args: client: google.genai.Client instance for downloading result file. job: Completed batch job object with result location. cfg: Batch configuration including error handling settings. expected_count: Number of prompts submitted (for order preservation). Returns: List of text outputs corresponding 1:1 to input prompts. 
Missing results are padded with empty strings. Raises: RuntimeError: If job is missing result location or item has error. """ if not _is_vertexai_client(client): raise ValueError("Batch API is only supported with Vertex AI.") outputs_by_idx: dict[int, str] = {} if not job.dest: raise exceptions.InferenceRuntimeError("Vertex AI batch job missing dest") gcs_uri = getattr(job.dest, "gcs_uri", None) or getattr( job.dest, "gcs_output_directory", None ) if not gcs_uri: raise exceptions.InferenceRuntimeError( "Vertex AI batch job missing output GCS URI" ) if not gcs_uri.startswith("gs://"): raise exceptions.InferenceRuntimeError(f"Invalid GCS URI format: {gcs_uri}") bucket_name, _, prefix = gcs_uri[5:].partition("/") project = getattr(client, "project", None) or os.getenv( "GOOGLE_CLOUD_PROJECT" ) storage_client = storage.Client(project=project) bucket = storage_client.bucket(bucket_name) # Vertex AI may write multiple output files. blobs = list(bucket.list_blobs(prefix=prefix)) if not blobs: raise exceptions.InferenceRuntimeError( f"No output files found in {gcs_uri}" ) logging.info("Batch API: Downloading results from %s", gcs_uri) logging.info("Batch API: Found %d output files", len(blobs)) for blob in blobs: if not blob.name.endswith(_EXT_JSONL): continue # Stream file line by line to avoid loading entire file into memory. with blob.open("r", encoding="utf-8") as f: for line in f: if not line.strip(): continue _parse_batch_line(line, outputs_by_idx, cfg) logging.info("Batch API: Parsed %d results", len(outputs_by_idx)) return [outputs_by_idx.get(i, "") for i in range(expected_count)] def infer_batch( client: genai.Client, model_id: str, prompts: Sequence[str], schema_dict: dict | None, gen_config: dict, cfg: BatchConfig, system_instruction: str | None = None, safety_settings: Sequence[Any] | None = None, project: str | None = None, location: str | None = None, ) -> list[str]: """Execute batch inference on multiple prompts using the Vertex AI Batch API. 
This function provides file-based batch processing via Vertex AI. It: - Uploads prompts to GCS (Google Cloud Storage) - Submits batch job to Vertex AI - Polls for job completion - Extracts and returns results Args: client: google.genai.Client instance configured for Vertex AI (must have client.vertexai=True). model_id: Model identifier (e.g., "gemini-2.5-flash"). prompts: Sequence of prompts to process in batch. schema_dict: Optional JSON schema for structured output. When provided, enables JSON mode with the specified schema constraints. gen_config: Generation configuration parameters (temperature, top_p, etc.). cfg: Batch configuration including thresholds, timeouts, and error handling. system_instruction: Optional system instruction text. safety_settings: Optional safety settings sequence. project: Google Cloud project ID (optional, overrides client/env). location: Vertex AI location (optional, overrides client/env). Returns: List of text outputs corresponding 1:1 to input prompts. Missing results are padded with empty strings. Raises: RuntimeError: If batch job fails or individual items have errors (when cfg.ignore_item_errors is False). TimeoutError: If batch job doesn't complete within cfg.timeout seconds. """ if not prompts: return [] if not _is_vertexai_client(client): raise ValueError( "Batch API is only supported with Vertex AI. To use batch mode, create" " your client with: genai.Client(vertexai=True, project='YOUR_PROJECT'," " location='us-central1'). For Google AI API keys, batch mode is not" " currently supported." 
) # Suppress verbose HTTP logs from underlying libraries std_logging.getLogger("google.auth.transport.requests").setLevel( std_logging.WARNING ) std_logging.getLogger("urllib3.connectionpool").setLevel(std_logging.WARNING) std_logging.getLogger("httpx").setLevel(std_logging.WARNING) std_logging.getLogger("httpcore").setLevel(std_logging.WARNING) # Force disable httpx propagation or handlers if level setting fails std_logging.getLogger("httpx").disabled = True logging.info("Batch API: Processing %d prompts", len(prompts)) display_base = f"langextract-batch-{int(time.time())}" project, location = _get_project_location(client, project, location) bucket_name = _get_bucket_name(project, location) cache = GCSBatchCache(bucket_name, project) if cfg.enable_caching else None if cache: logging.info( "Batch API: Using GCS bucket:" " https://console.cloud.google.com/storage/browser/%s", bucket_name, ) prompts_to_process: list[tuple[int, str]] = [] cached_results: dict[int, str] = {} if cache: key_data_list = [] for prompt in prompts: key_data_list.append({ "model_id": model_id, "prompt": prompt, "system_instruction": system_instruction, "gen_config": gen_config, "safety_settings": safety_settings, "schema": schema_dict, }) cached_results = cache.get_multi(key_data_list) for idx, prompt in enumerate(prompts): if idx not in cached_results: prompts_to_process.append((idx, prompt)) else: prompts_to_process = list(enumerate(prompts)) if not prompts_to_process: logging.info("Batch API: All %d prompts found in cache", len(prompts)) return [cached_results[i] for i in range(len(prompts))] logging.info( "Batch API: %d cached, %d to submit", len(cached_results), len(prompts_to_process), ) def _process_batch( batch_items: Sequence[tuple[int, str]], display: str ) -> dict[int, str]: """Submit batch job, poll completion, and extract results. Returns: Dict mapping original index to result text. 
""" batch_prompts = [p for _, p in batch_items] requests = [ _build_request( p, schema_dict, gen_config, system_instruction, safety_settings ) for p in batch_prompts ] job = _submit_file( client, model_id, requests, display, cfg.retention_days, project, location, ) if cfg.on_job_create: try: cfg.on_job_create(job) except Exception as e: logging.warning("Batch job creation callback failed: %s", e) job = _poll_completion(client, job, cfg) logging.info("Batch job completed successfully.") results = _extract_from_file( client, job, cfg, expected_count=len(batch_prompts) ) # Map results back to original indices mapped_results = {} for (orig_idx, _), result in zip(batch_items, results): mapped_results[orig_idx] = result return mapped_results new_results: dict[int, str] = {} if ( cfg.max_prompts_per_job and len(prompts_to_process) > cfg.max_prompts_per_job ): chunk_size = cfg.max_prompts_per_job for chunk_num, i in enumerate( range(0, len(prompts_to_process), chunk_size) ): chunk_items = prompts_to_process[i : i + chunk_size] chunk_results = _process_batch( chunk_items, f"{display_base}-part-{chunk_num}" ) new_results.update(chunk_results) else: new_results = _process_batch(prompts_to_process, display_base) if cache: upload_list = [] for idx, text in new_results.items(): prompt = prompts[idx] key_data = { "model_id": model_id, "prompt": prompt, "system_instruction": system_instruction, "gen_config": gen_config, "safety_settings": safety_settings, "schema": schema_dict, } upload_list.append((key_data, text)) cache.set_multi(upload_list) final_outputs = [] for i in range(len(prompts)): if i in cached_results: final_outputs.append(cached_results[i]) else: final_outputs.append(new_results.get(i, "")) return final_outputs ================================================ FILE: langextract/providers/ollama.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Ollama provider for LangExtract. This provider enables using local Ollama models with LangExtract's extract() function. No API key is required since Ollama runs locally on your machine. Usage with extract(): import langextract as lx from langextract.data import ExampleData, Extraction # Create an example for few-shot learning example = ExampleData( text="Marie Curie was a pioneering physicist and chemist.", extractions=[ Extraction( extraction_class="person", extraction_text="Marie Curie", attributes={"name": "Marie Curie", "field": "physics and chemistry"} ) ] ) # Basic usage with Ollama result = lx.extract( text_or_documents="Isaac Asimov was a prolific science fiction writer.", model_id="gemma2:2b", prompt_description="Extract the person's name and field", examples=[example], ) Direct provider instantiation (when model ID conflicts with other providers): from langextract.providers.ollama import OllamaLanguageModel # Create Ollama provider directly model = OllamaLanguageModel( model_id="gemma2:2b", model_url="http://localhost:11434", # optional, uses default if not specified ) # Use with extract by passing the model instance result = lx.extract( text_or_documents="Your text here", model=model, # Pass the model instance directly prompt_description="Extract information", examples=[example], ) Using pre-configured FormatHandler for manual control: from langextract.providers.ollama import OLLAMA_FORMAT_HANDLER # Use the 
pre-configured Ollama FormatHandler result = lx.extract( text_or_documents="Your text here", model_id="gemma2:2b", prompt_description="Extract information", examples=[example], resolver_params={'format_handler': OLLAMA_FORMAT_HANDLER} ) Supported model ID formats: - Standard Ollama: llama3.2:1b, gemma2:2b, mistral:7b, qwen2.5:7b, etc. - Hugging Face style: meta-llama/Llama-3.2-1B-Instruct, google/gemma-2b, etc. Prerequisites: 1. Install Ollama: https://ollama.ai 2. Pull the model: ollama pull gemma2:2b 3. Ollama server will start automatically when you use extract() """ # pylint: disable=duplicate-code from __future__ import annotations import dataclasses from typing import Any, Iterator, Mapping, Sequence from urllib.parse import urljoin from urllib.parse import urlparse import warnings import requests # Import from core modules directly from langextract.core import base_model from langextract.core import data from langextract.core import exceptions from langextract.core import format_handler as fh from langextract.core import schema from langextract.core import types as core_types from langextract.providers import patterns from langextract.providers import router # Ollama defaults _OLLAMA_DEFAULT_MODEL_URL = 'http://localhost:11434' _DEFAULT_TEMPERATURE = 0.1 _DEFAULT_TIMEOUT = 120 _DEFAULT_KEEP_ALIVE = 5 * 60 # 5 minutes _DEFAULT_NUM_CTX = 2048 # Pre-configured FormatHandler for consistent Ollama configuration # use_wrapper=True creates {"extractions": [...]} vs just [...] # Ollama's JSON mode expects a dictionary root, not a bare list OLLAMA_FORMAT_HANDLER = fh.FormatHandler( format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key=None, use_fences=False, strict_fences=False, ) @router.register( *patterns.OLLAMA_PATTERNS, priority=patterns.OLLAMA_PRIORITY, ) @dataclasses.dataclass(init=False) class OllamaLanguageModel(base_model.BaseLanguageModel): """Language model inference class using Ollama based host. 
Timeout can be set via constructor or passed through lx.extract(): lx.extract(..., language_model_params={"timeout": 300}) Authentication is supported for proxied Ollama instances: lx.extract(..., language_model_params={"api_key": "sk-..."}) """ _model: str _model_url: str format_type: core_types.FormatType = core_types.FormatType.JSON _constraint: schema.Constraint = dataclasses.field( default_factory=schema.Constraint, repr=False, compare=False ) _extra_kwargs: dict[str, Any] = dataclasses.field( default_factory=dict, repr=False, compare=False ) # Authentication _api_key: str | None = None _auth_scheme: str = 'Bearer' _auth_header: str = 'Authorization' @classmethod def get_schema_class(cls) -> type[schema.BaseSchema] | None: """Return the FormatModeSchema class for JSON output support. Returns: The FormatModeSchema class that enables JSON mode (non-strict). """ return schema.FormatModeSchema def __repr__(self) -> str: """Return string representation with redacted API key.""" api_key_display = '[REDACTED]' if self._api_key else None return ( f'{self.__class__.__name__}(' f'model={self._model!r}, ' f'model_url={self._model_url!r}, ' f'format_type={self.format_type!r}, ' f'api_key={api_key_display})' ) def __init__( self, model_id: str, model_url: str = _OLLAMA_DEFAULT_MODEL_URL, base_url: str | None = None, # Alias for model_url format_type: core_types.FormatType | None = None, structured_output_format: str | None = None, # Deprecated constraint: schema.Constraint = schema.Constraint(), timeout: int | None = None, **kwargs, ) -> None: """Initialize the Ollama language model. Args: model_id: The Ollama model ID to use. model_url: URL for Ollama server (legacy parameter). base_url: Alternative parameter name for Ollama server URL. format_type: Output format (JSON or YAML). Defaults to JSON. structured_output_format: DEPRECATED - use format_type instead. constraint: Schema constraints. timeout: Request timeout in seconds. Defaults to 120. 
**kwargs: Additional parameters. """ self._requests = requests # Handle deprecated structured_output_format parameter if structured_output_format is not None: warnings.warn( "'structured_output_format' is deprecated and will be removed in " "v2.0.0. Use 'format_type' instead.", FutureWarning, stacklevel=2, ) if format_type is None: format_type = ( core_types.FormatType.JSON if structured_output_format == 'json' else core_types.FormatType.YAML ) fmt = kwargs.pop('format', None) if format_type is None and fmt in ('json', 'yaml'): format_type = ( core_types.FormatType.JSON if fmt == 'json' else core_types.FormatType.YAML ) if format_type is None: format_type = core_types.FormatType.JSON self._model = model_id self._model_url = base_url or model_url or _OLLAMA_DEFAULT_MODEL_URL self.format_type = format_type self._constraint = constraint self._api_key = kwargs.pop('api_key', None) self._auth_scheme = kwargs.pop('auth_scheme', 'Bearer') self._auth_header = kwargs.pop('auth_header', 'Authorization') if self._api_key: host = urlparse(self._model_url).hostname if host in ('localhost', '127.0.0.1', '::1'): warnings.warn( 'API key provided for localhost Ollama instance. ' "Native Ollama doesn't require authentication. " 'This is typically only needed for proxied instances.', UserWarning, ) super().__init__(constraint=constraint) if timeout is not None: kwargs['timeout'] = timeout self._extra_kwargs = kwargs or {} def infer( self, batch_prompts: Sequence[str], **kwargs ) -> Iterator[Sequence[core_types.ScoredOutput]]: """Runs inference on a list of prompts via Ollama's API. Args: batch_prompts: A list of string prompts. **kwargs: Additional generation params. Yields: Lists of ScoredOutputs. 
""" combined_kwargs = self.merge_kwargs(kwargs) for prompt in batch_prompts: try: response = self._ollama_query( prompt=prompt, model=self._model, structured_output_format='json' if self.format_type == core_types.FormatType.JSON else 'yaml', model_url=self._model_url, **combined_kwargs, ) yield [core_types.ScoredOutput(score=1.0, output=response['response'])] except Exception as e: raise exceptions.InferenceRuntimeError( f'Ollama API error: {str(e)}', original=e ) from e def _ollama_query( self, prompt: str, model: str | None = None, temperature: float | None = None, seed: int | None = None, top_k: int | None = None, top_p: float | None = None, max_output_tokens: int | None = None, structured_output_format: str | None = None, system: str = '', raw: bool = False, model_url: str | None = None, timeout: int | None = None, keep_alive: int | None = None, num_threads: int | None = None, num_ctx: int | None = None, stop: str | list[str] | None = None, **kwargs, ) -> Mapping[str, Any]: """Sends a prompt to an Ollama model and returns the generated response. Note: This is a low-level method. Constructor timeout is only used when calling through infer(). Direct calls use the timeout parameter here. This function makes an HTTP POST request to the `/api/generate` endpoint of an Ollama server. It can optionally load the specified model first, generate a response (with or without streaming), then return a parsed JSON response. Args: prompt: The text prompt to send to the model. model: The name of the model to use. Defaults to self._model. temperature: Sampling temperature. Higher values produce more diverse output. seed: Seed for reproducible generation. If None, random seed is used. top_k: The top-K parameter for sampling. top_p: The top-P (nucleus) sampling parameter. max_output_tokens: Maximum tokens to generate. If None, the model's default is used. structured_output_format: If set to "json" or a JSON schema dict, requests structured outputs from the model. 
See Ollama documentation for details. system: A system prompt to override any system-level instructions. raw: If True, bypasses any internal prompt templating; you provide the entire raw prompt. model_url: The base URL for the Ollama server. Defaults to self._model_url. timeout: Timeout (in seconds) for the HTTP request. Defaults to 120. keep_alive: How long (in seconds) the model remains loaded after generation completes. num_threads: Number of CPU threads to use. If None, Ollama uses a default heuristic. num_ctx: Number of context tokens allowed. If None, uses model's default or config. stop: Stop sequences to halt generation. Can be a string or list of strings. **kwargs: Additional parameters passed through. Returns: A mapping (dictionary-like) containing the server's JSON response. For non-streaming calls, the `"response"` key typically contains the entire generated text. Raises: InferenceConfigError: If the server returns a 404 (model not found). InferenceRuntimeError: For any other HTTP errors, timeouts, or request exceptions. 
""" model = model or self._model model_url = model_url or self._model_url if structured_output_format is None and self.format_type is not None: structured_output_format = ( 'json' if self.format_type == core_types.FormatType.JSON else 'yaml' ) options: dict[str, Any] = {} if keep_alive is not None: options['keep_alive'] = keep_alive else: options['keep_alive'] = _DEFAULT_KEEP_ALIVE if seed is not None: options['seed'] = seed if temperature is not None: options['temperature'] = temperature else: options['temperature'] = _DEFAULT_TEMPERATURE if top_k is not None: options['top_k'] = top_k if top_p is not None: options['top_p'] = top_p if num_threads is not None: options['num_thread'] = num_threads if max_output_tokens is not None: options['num_predict'] = max_output_tokens if num_ctx is not None: options['num_ctx'] = num_ctx else: options['num_ctx'] = _DEFAULT_NUM_CTX reserved_top_level = { 'model', 'prompt', 'system', 'stop', 'format', 'stream', 'raw', } for key, value in kwargs.items(): if value is None: continue if key in reserved_top_level: continue if key not in options: options[key] = value api_url = urljoin( model_url if model_url.endswith('/') else model_url + '/', 'api/generate', ) payload: dict[str, Any] = { 'model': model, 'prompt': prompt, 'system': system, 'stream': False, 'raw': raw, 'options': options, } if structured_output_format is not None: payload['format'] = structured_output_format if stop is not None: payload['stop'] = stop request_timeout = timeout if timeout is not None else _DEFAULT_TIMEOUT headers = { 'Content-Type': 'application/json', 'Accept': 'application/json', } if self._api_key: if self._auth_scheme: headers[self._auth_header] = f'{self._auth_scheme} {self._api_key}' else: headers[self._auth_header] = self._api_key try: response = self._requests.post( api_url, headers=headers, json=payload, timeout=request_timeout, ) except self._requests.exceptions.RequestException as e: if isinstance(e, self._requests.exceptions.ReadTimeout): msg = 
( f'Ollama Model timed out (timeout={request_timeout},' f' num_threads={num_threads})' ) raise exceptions.InferenceRuntimeError( msg, original=e, provider='Ollama' ) from e raise exceptions.InferenceRuntimeError( f'Ollama request failed: {str(e)}', original=e, provider='Ollama' ) from e response.encoding = 'utf-8' if response.status_code == 200: return response.json() if response.status_code == 404: raise exceptions.InferenceConfigError( f"Can't find Ollama {model}. Try: ollama run {model}" ) else: msg = f'Bad status code from Ollama: {response.status_code}' raise exceptions.InferenceRuntimeError(msg, provider='Ollama') ================================================ FILE: langextract/providers/openai.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""OpenAI provider for LangExtract.""" # pylint: disable=duplicate-code from __future__ import annotations import concurrent.futures import dataclasses from typing import Any, Iterator, Sequence from langextract.core import base_model from langextract.core import data from langextract.core import exceptions from langextract.core import schema from langextract.core import types as core_types from langextract.providers import patterns from langextract.providers import router @router.register( *patterns.OPENAI_PATTERNS, priority=patterns.OPENAI_PRIORITY, ) @dataclasses.dataclass(init=False) class OpenAILanguageModel(base_model.BaseLanguageModel): """Language model inference using OpenAI's API with structured output.""" model_id: str = 'gpt-4o-mini' api_key: str | None = None base_url: str | None = None organization: str | None = None format_type: data.FormatType = data.FormatType.JSON temperature: float | None = None max_workers: int = 10 _client: Any = dataclasses.field(default=None, repr=False, compare=False) _extra_kwargs: dict[str, Any] = dataclasses.field( default_factory=dict, repr=False, compare=False ) @property def requires_fence_output(self) -> bool: """OpenAI JSON mode returns raw JSON without fences.""" if self.format_type == data.FormatType.JSON: return False return super().requires_fence_output def __init__( self, model_id: str = 'gpt-4o-mini', api_key: str | None = None, base_url: str | None = None, organization: str | None = None, format_type: data.FormatType = data.FormatType.JSON, temperature: float | None = None, max_workers: int = 10, **kwargs, ) -> None: """Initialize the OpenAI language model. Args: model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o'). api_key: API key for OpenAI service. base_url: Base URL for OpenAI service. organization: Optional OpenAI organization ID. format_type: Output format (JSON or YAML). temperature: Sampling temperature. max_workers: Maximum number of parallel API calls. 
**kwargs: Ignored extra parameters so callers can pass a superset of arguments shared across back-ends without raising ``TypeError``. """ # Lazy import: OpenAI package required try: # pylint: disable=import-outside-toplevel import openai except ImportError as e: raise exceptions.InferenceConfigError( 'OpenAI provider requires openai package. ' 'Install with: pip install langextract[openai]' ) from e self.model_id = model_id self.api_key = api_key self.base_url = base_url self.organization = organization self.format_type = format_type self.temperature = temperature self.max_workers = max_workers if not self.api_key: raise exceptions.InferenceConfigError('API key not provided.') # Initialize the OpenAI client self._client = openai.OpenAI( api_key=self.api_key, base_url=self.base_url, organization=self.organization, ) super().__init__( constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE) ) self._extra_kwargs = kwargs or {} def _normalize_reasoning_params(self, config: dict) -> dict: """Normalize reasoning parameters for API compatibility. Converts flat 'reasoning_effort' to nested 'reasoning' structure. Merges with existing reasoning dict if present. """ result = config.copy() if 'reasoning_effort' in result: effort = result.pop('reasoning_effort') reasoning = result.get('reasoning', {}) or {} reasoning.setdefault('effort', effort) result['reasoning'] = reasoning return result def _process_single_prompt( self, prompt: str, config: dict ) -> core_types.ScoredOutput: """Process a single prompt and return a ScoredOutput.""" try: normalized_config = self._normalize_reasoning_params(config) system_message = '' if self.format_type == data.FormatType.JSON: system_message = ( 'You are a helpful assistant that responds in JSON format.' ) elif self.format_type == data.FormatType.YAML: system_message = ( 'You are a helpful assistant that responds in YAML format.' 
) messages = [{'role': 'user', 'content': prompt}] if system_message: messages.insert(0, {'role': 'system', 'content': system_message}) api_params = { 'model': self.model_id, 'messages': messages, 'n': 1, } temp = normalized_config.get('temperature', self.temperature) if temp is not None: api_params['temperature'] = temp if self.format_type == data.FormatType.JSON: api_params.setdefault('response_format', {'type': 'json_object'}) if (v := normalized_config.get('max_output_tokens')) is not None: api_params['max_tokens'] = v if (v := normalized_config.get('top_p')) is not None: api_params['top_p'] = v for key in [ 'frequency_penalty', 'presence_penalty', 'seed', 'stop', 'logprobs', 'top_logprobs', 'reasoning', 'response_format', ]: if (v := normalized_config.get(key)) is not None: api_params[key] = v response = self._client.chat.completions.create(**api_params) # Extract the response text using the v1.x response format output_text = response.choices[0].message.content return core_types.ScoredOutput(score=1.0, output=output_text) except Exception as e: raise exceptions.InferenceRuntimeError( f'OpenAI API error: {str(e)}', original=e ) from e def infer( self, batch_prompts: Sequence[str], **kwargs ) -> Iterator[Sequence[core_types.ScoredOutput]]: """Runs inference on a list of prompts via OpenAI's API. Args: batch_prompts: A list of string prompts. **kwargs: Additional generation params (temperature, top_p, etc.) Yields: Lists of ScoredOutputs. 
""" merged_kwargs = self.merge_kwargs(kwargs) config = {} temp = merged_kwargs.get('temperature', self.temperature) if temp is not None: config['temperature'] = temp if 'max_output_tokens' in merged_kwargs: config['max_output_tokens'] = merged_kwargs['max_output_tokens'] if 'top_p' in merged_kwargs: config['top_p'] = merged_kwargs['top_p'] for key in [ 'frequency_penalty', 'presence_penalty', 'seed', 'stop', 'logprobs', 'top_logprobs', 'reasoning_effort', 'reasoning', 'response_format', ]: if key in merged_kwargs: config[key] = merged_kwargs[key] # Use parallel processing for batches larger than 1 if len(batch_prompts) > 1 and self.max_workers > 1: with concurrent.futures.ThreadPoolExecutor( max_workers=min(self.max_workers, len(batch_prompts)) ) as executor: future_to_index = { executor.submit( self._process_single_prompt, prompt, config.copy() ): i for i, prompt in enumerate(batch_prompts) } results: list[core_types.ScoredOutput | None] = [None] * len( batch_prompts ) for future in concurrent.futures.as_completed(future_to_index): index = future_to_index[future] try: results[index] = future.result() except Exception as e: raise exceptions.InferenceRuntimeError( f'Parallel inference error: {str(e)}', original=e ) from e for result in results: if result is None: raise exceptions.InferenceRuntimeError( 'Failed to process one or more prompts' ) yield [result] else: # Sequential processing for single prompt or worker for prompt in batch_prompts: result = self._process_single_prompt(prompt, config.copy()) yield [result] # pylint: disable=duplicate-code ================================================ FILE: langextract/providers/patterns.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Centralized pattern definitions for built-in providers. This module defines all patterns and priorities for built-in providers in one place to avoid duplication. """ # Gemini provider patterns GEMINI_PATTERNS = (r'^gemini',) GEMINI_PRIORITY = 10 # OpenAI provider patterns OPENAI_PATTERNS = ( r'^gpt-4', r'^gpt4\.', r'^gpt-5', r'^gpt5\.', ) OPENAI_PRIORITY = 10 # Ollama provider patterns OLLAMA_PATTERNS = ( # Standard Ollama naming patterns r'^gemma', # gemma2:2b, gemma2:9b, etc. r'^llama', # llama3.2:1b, llama3.1:8b, etc. r'^mistral', # mistral:7b, mistral-nemo:12b, etc. r'^mixtral', # mixtral:8x7b, mixtral:8x22b, etc. r'^phi', # phi3:3.8b, phi3:14b, etc. r'^qwen', # qwen2.5:0.5b to 72b r'^deepseek', # deepseek-coder-v2, etc. r'^command-r', # command-r:35b, command-r-plus:104b r'^starcoder', # starcoder2:3b, starcoder2:7b, etc. r'^codellama', # codellama:7b, codellama:13b, etc. r'^codegemma', # codegemma:2b, codegemma:7b r'^tinyllama', # tinyllama:1.1b r'^wizardcoder', # wizardcoder:7b, wizardcoder:13b, etc. r'^gpt-oss', # Open source GPT variants # HuggingFace model patterns r'^meta-llama/[Ll]lama', r'^google/gemma', r'^mistralai/[Mm]istral', r'^mistralai/[Mm]ixtral', r'^microsoft/phi', r'^Qwen/', r'^deepseek-ai/', r'^bigcode/starcoder', r'^codellama/', r'^TinyLlama/', r'^WizardLM/', ) OLLAMA_PRIORITY = 10 ================================================ FILE: langextract/providers/router.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Runtime registry that maps model-ID patterns to provider classes. This module provides a lazy registration system for LLM providers, allowing providers to be registered without importing their dependencies until needed. """ # pylint: disable=duplicate-code from __future__ import annotations import dataclasses import functools import importlib import re import typing from absl import logging from langextract.core import base_model from langextract.core import exceptions TLanguageModel = typing.TypeVar( "TLanguageModel", bound=base_model.BaseLanguageModel ) @dataclasses.dataclass(frozen=True, slots=True) class _Entry: """Registry entry for a provider.""" patterns: tuple[re.Pattern[str], ...] 
loader: typing.Callable[[], type[base_model.BaseLanguageModel]] priority: int _entries: list[_Entry] = [] _entry_keys: set[tuple[str, tuple[str, ...], int]] = ( set() ) # (provider_id, patterns, priority) def _add_entry( *, provider_id: str, patterns: tuple[re.Pattern[str], ...], loader: typing.Callable[[], type[base_model.BaseLanguageModel]], priority: int, ) -> None: """Add an entry to the registry with deduplication.""" key = (provider_id, tuple(p.pattern for p in patterns), priority) if key in _entry_keys: logging.debug( "Skipping duplicate registration for %s with patterns %s at" " priority %d", provider_id, [p.pattern for p in patterns], priority, ) return _entry_keys.add(key) _entries.append(_Entry(patterns=patterns, loader=loader, priority=priority)) logging.debug( "Registered provider %s with patterns %s at priority %d", provider_id, [p.pattern for p in patterns], priority, ) def register_lazy( *patterns: str | re.Pattern[str], target: str, priority: int = 0 ) -> None: """Register a provider lazily using string import path. Args: *patterns: One or more regex patterns to match model IDs. target: Import path in format "module.path:ClassName". priority: Priority for resolution (higher wins on conflicts). """ compiled = tuple(re.compile(p) if isinstance(p, str) else p for p in patterns) def _loader() -> type[base_model.BaseLanguageModel]: module_path, class_name = target.rsplit(":", 1) module = importlib.import_module(module_path) return getattr(module, class_name) _add_entry( provider_id=target, patterns=compiled, loader=_loader, priority=priority, ) def register( *patterns: str | re.Pattern[str], priority: int = 0 ) -> typing.Callable[[type[TLanguageModel]], type[TLanguageModel]]: """Decorator to register a provider class directly. Args: *patterns: One or more regex patterns to match model IDs. priority: Priority for resolution (higher wins on conflicts). Returns: Decorator function that registers the class. 
""" compiled = tuple(re.compile(p) if isinstance(p, str) else p for p in patterns) def _decorator(cls: type[TLanguageModel]) -> type[TLanguageModel]: def _loader() -> type[base_model.BaseLanguageModel]: return cls provider_id = f"{cls.__module__}:{cls.__name__}" _add_entry( provider_id=provider_id, patterns=compiled, loader=_loader, priority=priority, ) return cls return _decorator @functools.lru_cache(maxsize=128) def resolve(model_id: str) -> type[base_model.BaseLanguageModel]: """Resolve a model ID to a provider class. Args: model_id: The model identifier to resolve. Returns: The provider class that handles this model ID. Raises: ValueError: If no provider is registered for the model ID. """ # Providers should be loaded by the caller (e.g., factory.create_model) # Router doesn't load providers to avoid circular dependencies sorted_entries = sorted(_entries, key=lambda e: e.priority, reverse=True) for entry in sorted_entries: if any(pattern.search(model_id) for pattern in entry.patterns): return entry.loader() available_patterns = [str(p.pattern) for e in _entries for p in e.patterns] raise exceptions.InferenceConfigError( f"No provider registered for model_id={model_id!r}. " f"Available patterns: {available_patterns}\n" "Tip: You can explicitly specify a provider using 'config' parameter " "with factory.ModelConfig and a provider class." ) @functools.lru_cache(maxsize=128) def resolve_provider(provider_name: str) -> type[base_model.BaseLanguageModel]: """Resolve a provider name to a provider class. This allows explicit provider selection by name or class name. Args: provider_name: The provider name (e.g., "gemini", "openai") or class name (e.g., "GeminiLanguageModel"). Returns: The provider class. Raises: ValueError: If no provider matches the name. 
""" # Providers should be loaded by the caller (e.g., factory.create_model) # Router doesn't load providers to avoid circular dependencies for entry in _entries: for pattern in entry.patterns: if pattern.pattern == f"^{re.escape(provider_name)}$": return entry.loader() for entry in _entries: try: provider_class = entry.loader() class_name = provider_class.__name__ if provider_name.lower() in class_name.lower(): return provider_class except (ImportError, AttributeError): continue try: pattern = re.compile(f"^{provider_name}$", re.IGNORECASE) for entry in _entries: for entry_pattern in entry.patterns: if pattern.pattern == entry_pattern.pattern: return entry.loader() except re.error: pass raise exceptions.InferenceConfigError( f"No provider found matching: {provider_name!r}. " "Available providers can be listed with list_providers()" ) def clear() -> None: """Clear all registered providers. Mainly for testing.""" global _entries # pylint: disable=global-statement _entries = [] _entry_keys.clear() # Also clear dedup keys to allow re-registration resolve.cache_clear() resolve_provider.cache_clear() def list_providers() -> list[tuple[tuple[str, ...], int]]: """List all registered providers with their patterns and priorities. Returns: List of (patterns, priority) tuples for debugging. """ return [ (tuple(p.pattern for p in entry.patterns), entry.priority) for entry in _entries ] def list_entries() -> list[tuple[list[str], int]]: """List all registered patterns and priorities. Mainly for debugging. Returns: List of (patterns, priority) tuples. """ return [([p.pattern for p in e.patterns], e.priority) for e in _entries] ================================================ FILE: langextract/providers/schemas/__init__.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provider-specific schema implementations.""" from __future__ import annotations from langextract.providers.schemas import gemini GeminiSchema = gemini.GeminiSchema # Backward compat __all__ = ["GeminiSchema"] ================================================ FILE: langextract/providers/schemas/gemini.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Gemini provider schema implementation.""" # pylint: disable=duplicate-code from __future__ import annotations from collections.abc import Sequence import dataclasses from typing import Any import warnings from langextract.core import data from langextract.core import format_handler as fh from langextract.core import schema @dataclasses.dataclass class GeminiSchema(schema.BaseSchema): """Schema implementation for Gemini structured output. Converts ExampleData objects into an OpenAPI/JSON-schema definition that Gemini can interpret via 'response_schema'. 
""" _schema_dict: dict[str, Any] @property def schema_dict(self) -> dict[str, Any]: """Returns the schema dictionary.""" return self._schema_dict @schema_dict.setter def schema_dict(self, schema_dict: dict[str, Any]) -> None: """Sets the schema dictionary.""" self._schema_dict = schema_dict def to_provider_config(self) -> dict[str, Any]: """Convert schema to Gemini-specific configuration. Returns: Dictionary with response_schema and response_mime_type for Gemini API. """ return { "response_schema": self._schema_dict, "response_mime_type": "application/json", } @property def requires_raw_output(self) -> bool: """Gemini outputs raw JSON via response_mime_type.""" return True def validate_format(self, format_handler: fh.FormatHandler) -> None: """Validate Gemini's format requirements. Gemini requires: - No fence markers (outputs raw JSON via response_mime_type) - Wrapper with EXTRACTIONS_KEY (built into response_schema) """ # Check for fence usage with raw JSON output if format_handler.use_fences: warnings.warn( "Gemini outputs native JSON via" " response_mime_type='application/json'. Using fence_output=True may" " cause parsing issues. Set fence_output=False.", UserWarning, stacklevel=3, ) # Verify wrapper is enabled with correct key if ( not format_handler.use_wrapper or format_handler.wrapper_key != data.EXTRACTIONS_KEY ): warnings.warn( "Gemini's response_schema expects" f" wrapper_key='{data.EXTRACTIONS_KEY}'. Current settings:" f" use_wrapper={format_handler.use_wrapper}," f" wrapper_key='{format_handler.wrapper_key}'", UserWarning, stacklevel=3, ) @classmethod def from_examples( cls, examples_data: Sequence[data.ExampleData], attribute_suffix: str = data.ATTRIBUTE_SUFFIX, ) -> GeminiSchema: """Creates a GeminiSchema from example extractions. Builds a JSON-based schema with a top-level "extractions" array. Each element in that array is an object containing the extraction class name and an accompanying "_attributes" object for its attributes. 
Args: examples_data: A sequence of ExampleData objects containing extraction classes and attributes. attribute_suffix: String appended to each class name to form the attributes field name (defaults to "_attributes"). Returns: A GeminiSchema with internal dictionary represents the JSON constraint. """ # Track attribute types for each category extraction_categories: dict[str, dict[str, set[type]]] = {} for example in examples_data: for extraction in example.extractions: category = extraction.extraction_class if category not in extraction_categories: extraction_categories[category] = {} if extraction.attributes: for attr_name, attr_value in extraction.attributes.items(): if attr_name not in extraction_categories[category]: extraction_categories[category][attr_name] = set() extraction_categories[category][attr_name].add(type(attr_value)) extraction_properties: dict[str, dict[str, Any]] = {} for category, attrs in extraction_categories.items(): extraction_properties[category] = {"type": "string"} attributes_field = f"{category}{attribute_suffix}" attr_properties = {} # Default property for categories without attributes if not attrs: attr_properties["_unused"] = {"type": "string"} else: for attr_name, attr_types in attrs.items(): # List attributes become arrays if list in attr_types: attr_properties[attr_name] = { "type": "array", "items": {"type": "string"}, # type: ignore[dict-item] } else: attr_properties[attr_name] = {"type": "string"} extraction_properties[attributes_field] = { "type": "object", "properties": attr_properties, "nullable": True, } extraction_schema = { "type": "object", "properties": extraction_properties, } schema_dict = { "type": "object", "properties": { data.EXTRACTIONS_KEY: {"type": "array", "items": extraction_schema} }, "required": [data.EXTRACTIONS_KEY], } return cls(_schema_dict=schema_dict) ================================================ FILE: langextract/py.typed ================================================ 
================================================ FILE: langextract/registry.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Compatibility shim for langextract.registry imports. This module redirects to langextract.plugins for backward compatibility. Will be removed in v2.0.0. """ from __future__ import annotations import warnings from langextract import plugins def __getattr__(name: str): """Redirect to plugins module with deprecation warning.""" warnings.warn( "`langextract.registry` is deprecated and will be removed in v2.0.0; " "use `langextract.plugins` instead.", FutureWarning, stacklevel=2, ) return getattr(plugins, name) ================================================ FILE: langextract/resolver.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Library for resolving LLM output. 
In the context of this module, a "resolver" is a component designed to parse and transform the textual output of an LLM into structured data. """ from __future__ import annotations import abc import collections from collections.abc import Iterator, Mapping, Sequence import difflib import functools import itertools import operator from typing import Final from absl import logging from langextract.core import data from langextract.core import exceptions from langextract.core import format_handler as fh from langextract.core import schema from langextract.core import tokenizer as tokenizer_lib _FUZZY_ALIGNMENT_MIN_THRESHOLD = 0.75 # Default suffix for extraction index keys (e.g., "entity_index") DEFAULT_INDEX_SUFFIX = "_index" # Suffix for index fields in extraction sorting ALIGNMENT_PARAM_KEYS: Final[frozenset[str]] = frozenset({ "enable_fuzzy_alignment", "fuzzy_alignment_threshold", "accept_match_lesser", "suppress_parse_errors", }) class AbstractResolver(abc.ABC): """Resolves LLM text outputs into structured data.""" # TODO: Review value and requirements for abstract class. def __init__( self, fence_output: bool = True, constraint: schema.Constraint = schema.Constraint(), format_type: data.FormatType = data.FormatType.JSON, ): """Initializes the BaseResolver. Delimiters are used for parsing text blocks, and are used primarily for models that do not have constrained-decoding support. Args: fence_output: Whether to expect/generate fenced output (```json or ```yaml). When True, the model is prompted to generate fenced output and the resolver expects it. When False, raw JSON/YAML is expected. If your model utilizes schema constraints, this can generally be set to False unless the constraint also accounts for code fence delimiters. constraint: Applies constraint when decoding the output. Defaults to no constraint. format_type: The format type for the output (JSON or YAML). 
""" self._fence_output = fence_output self._constraint = constraint self._format_type = format_type @property def fence_output(self) -> bool: """Returns whether fenced output is expected.""" return self._fence_output @fence_output.setter def fence_output(self, fence_output: bool) -> None: """Sets whether fenced output is expected. Args: fence_output: Whether to expect fenced output. """ self._fence_output = fence_output @property def format_type(self) -> data.FormatType: """Returns the format type.""" return self._format_type @format_type.setter def format_type(self, new_format_type: data.FormatType) -> None: """Sets a new format type.""" self._format_type = new_format_type @abc.abstractmethod def resolve( self, input_text: str, **kwargs, ) -> Sequence[data.Extraction]: """Run resolve function on input text. Args: input_text: The input text to be processed. **kwargs: Additional arguments for subclass implementations. Returns: Annotated text in the form of Extractions. """ @abc.abstractmethod def align( self, extractions: Sequence[data.Extraction], source_text: str, token_offset: int, char_offset: int | None = None, enable_fuzzy_alignment: bool = True, fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, accept_match_lesser: bool = True, **kwargs, ) -> Iterator[data.Extraction]: """Aligns extractions with source text, setting token/char intervals and alignment status. Uses exact matching first (difflib), then fuzzy alignment fallback if enabled. Alignment Status Results: - MATCH_EXACT: Perfect token-level match - MATCH_LESSER: Partial exact match (extraction longer than matched text) - MATCH_FUZZY: Best overlap window meets threshold (≥ fuzzy_alignment_threshold) - None: No alignment found Args: extractions: Annotated extractions to align with the source text. source_text: The text in which to align the extractions. token_offset: The token_offset corresponding to the starting token index of the chunk. 
char_offset: The char_offset corresponding to the starting character index of the chunk. enable_fuzzy_alignment: Whether to use fuzzy alignment when exact matching fails. fuzzy_alignment_threshold: Minimum token overlap ratio for fuzzy alignment (0-1). accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER status). **kwargs: Additional keyword arguments for provider-specific alignment. Yields: Aligned extractions with updated token intervals and alignment status. """ class ResolverParsingError(exceptions.LangExtractError): """Error raised when content cannot be parsed as the given format.""" class Resolver(AbstractResolver): """Resolver for YAML/JSON-based information extraction. By default, extractions are returned in the order they appear in the model output. To enable index-based sorting, set extraction_index_suffix to a value like "_index" (the DEFAULT_INDEX_SUFFIX constant). This will sort extractions by fields ending with that suffix (e.g., "entity_index"). Uses FormatHandler for parsing model output into extractions. """ def __init__( self, format_handler: fh.FormatHandler | None = None, extraction_index_suffix: str | None = None, **kwargs, # Collect legacy parameters ): """Constructor. Args: format_handler: The format handler that knows how to parse output. extraction_index_suffix: Suffix identifying index keys that determine the ordering of extractions. **kwargs: Legacy parameters (fence_output, format_type, etc.) for backward compatibility. These will be used to create a FormatHandler if one is not provided. Support for these parameters will be removed in v2.0.0. 
""" constraint = kwargs.pop("constraint", None) extraction_attributes_suffix = kwargs.pop( "extraction_attributes_suffix", None ) if format_handler is None: if kwargs or extraction_attributes_suffix is not None: handler_kwargs = dict(kwargs) if extraction_attributes_suffix is not None: handler_kwargs["attribute_suffix"] = extraction_attributes_suffix format_handler = fh.FormatHandler.from_kwargs(**handler_kwargs) for param in [ "fence_output", "format_type", "strict_fences", "require_extractions_key", "attribute_suffix", ]: kwargs.pop(param, None) else: format_handler = fh.FormatHandler() if kwargs: raise TypeError( f"got an unexpected keyword argument '{list(kwargs.keys())[0]}'" ) constraint = constraint or schema.Constraint() super().__init__( fence_output=format_handler.use_fences, format_type=format_handler.format_type, constraint=constraint, ) self.format_handler = format_handler self.extraction_index_suffix = extraction_index_suffix self._constraint = constraint def resolve( self, input_text: str, suppress_parse_errors: bool = False, **kwargs, ) -> Sequence[data.Extraction]: """Runs resolve function on text with YAML/JSON extraction data. Args: input_text: The input text to be processed. suppress_parse_errors: Log errors and continue pipeline. **kwargs: Additional keyword arguments. Returns: Annotated text in the form of a sequence of data.Extraction objects. Raises: ResolverParsingError: If the content within the string cannot be parsed due to formatting errors, or if the parsed content is not as expected. 
""" logging.debug("Starting resolver process for input text.") logging.debug("Input Text: %s", input_text) try: constraint = getattr(self, "_constraint", schema.Constraint()) strict = getattr(constraint, "strict", False) extraction_data = self.format_handler.parse_output( input_text, strict=strict ) logging.debug("Parsed content: %s", extraction_data) except exceptions.FormatError as e: if suppress_parse_errors: logging.exception( "Failed to parse input_text: %s, error: %s", input_text, e ) return [] raise ResolverParsingError(str(e)) from e processed_extractions = self.extract_ordered_extractions(extraction_data) logging.debug("Completed the resolver process.") return processed_extractions def align( self, extractions: Sequence[data.Extraction], source_text: str, token_offset: int, char_offset: int | None = None, enable_fuzzy_alignment: bool = True, fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, accept_match_lesser: bool = True, tokenizer_inst: tokenizer_lib.Tokenizer | None = None, **kwargs, ) -> Iterator[data.Extraction]: """Aligns annotated extractions with source text. This uses WordAligner which is based on Python's difflib SequenceMatcher to match tokens in the source text with tokens from the annotated extractions. If the extraction order is significantly different from the source text order, difflib may skip some matches, leaving certain extractions unmatched. Args: extractions: Annotated extractions. source_text: The text chunk in which to align the extractions. token_offset: The starting token index of the chunk. char_offset: The starting character index of the chunk. enable_fuzzy_alignment: Whether to enable fuzzy alignment fallback. fuzzy_alignment_threshold: Minimum overlap ratio required for fuzzy alignment. accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER status). tokenizer_inst: Optional tokenizer instance. **kwargs: Additional parameters. Yields: Iterator on aligned extractions. 
""" logging.debug("Starting alignment process for provided chunk text.") if not extractions: logging.debug( "No extractions found in the annotated text; exiting alignment" " process." ) return else: extractions_group = [extractions] aligner = WordAligner() aligned_yaml_extractions = aligner.align_extractions( extractions_group, source_text, token_offset, char_offset or 0, enable_fuzzy_alignment=enable_fuzzy_alignment, fuzzy_alignment_threshold=fuzzy_alignment_threshold, accept_match_lesser=accept_match_lesser, tokenizer_impl=tokenizer_inst, ) logging.debug( "Aligned extractions count: %d", sum(len(group) for group in aligned_yaml_extractions), ) for extraction in itertools.chain(*aligned_yaml_extractions): logging.debug("Yielding aligned extraction: %s", extraction) yield extraction logging.debug("Completed alignment process for the provided source_text.") def string_to_extraction_data( self, input_string: str, ) -> Sequence[Mapping[str, fh.ExtractionValueType]]: """Parses a YAML or JSON-formatted string into extraction data. This method is kept for backward compatibility with tests. It delegates to the FormatHandler for actual parsing. Args: input_string: A string containing YAML or JSON content. Returns: Sequence[Mapping[str, fh.ExtractionValueType]]: A sequence of parsed objects. Raises: ResolverParsingError: If the content within the string cannot be parsed. ValueError: If the input is invalid or does not contain expected format. 
""" if not input_string or not isinstance(input_string, str): logging.error("Input string must be a non-empty string.") raise ValueError("Input string must be a non-empty string.") try: constraint = getattr(self, "_constraint", schema.Constraint()) strict = getattr(constraint, "strict", False) return self.format_handler.parse_output(input_string, strict=strict) except exceptions.FormatError as e: raise ResolverParsingError(str(e)) from e except Exception as e: logging.exception("Failed to parse content.") raise ResolverParsingError("Failed to parse content.") from e def extract_ordered_extractions( self, extraction_data: Sequence[Mapping[str, fh.ExtractionValueType]], ) -> Sequence[data.Extraction]: """Extracts and orders extraction data based on their associated indexes. This function processes a list of dictionaries, each containing pairs of extraction class keys and their corresponding values, along with optionally associated index keys (identified by the index_suffix). It sorts these pairs by their indices in ascending order and excludes pairs without an index key, returning a list of lists of tuples (extraction_class: str, extraction_text: str). Args: extraction_data: A list of dictionaries. Each dictionary contains pairs of extraction class keys and their values, along with optional index keys. Returns: Extractions sorted by the index attribute or by order of appearance. If two extractions have the same index, their group order dictates the sorting order. Raises: ValueError: If the extraction text is not a string or integer, or if the index is not an integer. 
""" logging.debug("Starting to extract and order extractions from data.") if not extraction_data: logging.debug("Received empty extraction data.") processed_extractions = [] extraction_index = 0 index_suffix = self.extraction_index_suffix attributes_suffix = self.format_handler.attribute_suffix for group_index, group in enumerate(extraction_data): for extraction_class, extraction_value in group.items(): if index_suffix and extraction_class.endswith(index_suffix): if not isinstance(extraction_value, int): logging.error( "Index must be an integer. Found: %s", type(extraction_value), ) raise ValueError("Index must be an integer.") continue if attributes_suffix and extraction_class.endswith(attributes_suffix): if not isinstance(extraction_value, (dict, type(None))): logging.error( "Attributes must be a dict or None. Found: %s", type(extraction_value), ) raise ValueError( "Extraction value must be a dict or None for attributes." ) continue if not isinstance(extraction_value, (str, int, float)): logging.error( "Extraction text must be a string, integer, or float. Found: %s", type(extraction_value), ) raise ValueError( "Extraction text must be a string, integer, or float." ) if not isinstance(extraction_value, str): extraction_value = str(extraction_value) if index_suffix: index_key = extraction_class + index_suffix extraction_index = group.get(index_key, None) if extraction_index is None: logging.debug( "No index value for %s. 
Skipping extraction.", extraction_class ) continue else: extraction_index += 1 attributes = None if attributes_suffix: attributes_key = extraction_class + attributes_suffix attributes = group.get(attributes_key, None) processed_extractions.append( data.Extraction( extraction_class=extraction_class, extraction_text=extraction_value, extraction_index=extraction_index, group_index=group_index, attributes=attributes, ) ) processed_extractions.sort(key=operator.attrgetter("extraction_index")) logging.debug("Completed extraction and ordering of extractions.") return processed_extractions class WordAligner: """Aligns words between two sequences of tokens using Python's difflib.""" def __init__(self): """Initialize the WordAligner with difflib SequenceMatcher.""" self.matcher = difflib.SequenceMatcher(autojunk=False) self.source_tokens: Sequence[str] | None = None self.extraction_tokens: Sequence[str] | None = None def _set_seqs( self, source_tokens: Sequence[str] | Iterator[str], extraction_tokens: Sequence[str] | Iterator[str], ): """Sets the source and extraction tokens for alignment. Args: source_tokens: A nonempty sequence or iterator of word-level tokens from source text. extraction_tokens: A nonempty sequence or iterator of extraction tokens in order for matching to the source. """ if isinstance(source_tokens, Iterator): source_tokens = list(source_tokens) if isinstance(extraction_tokens, Iterator): extraction_tokens = list(extraction_tokens) if not source_tokens or not extraction_tokens: raise ValueError("Source tokens and extraction tokens cannot be empty.") self.source_tokens = source_tokens self.extraction_tokens = extraction_tokens self.matcher.set_seqs(a=source_tokens, b=extraction_tokens) def _get_matching_blocks(self) -> Sequence[tuple[int, int, int]]: """Utilizes difflib SequenceMatcher and returns matching blocks of tokens. Returns: Sequence of matching blocks between source_tokens (S) and extraction_tokens (E). 
Each block (i, j, n) conforms to: S[i:i+n] == E[j:j+n], guaranteed to be monotonically increasing in j. Final entry is a dummy with value (len(S), len(E), 0). """ if self.source_tokens is None or self.extraction_tokens is None: raise ValueError( "Source tokens and extraction tokens must be set before getting" " matching blocks." ) return self.matcher.get_matching_blocks() def _fuzzy_align_extraction( self, extraction: data.Extraction, source_tokens: list[str], tokenized_text: tokenizer_lib.TokenizedText, token_offset: int, char_offset: int, fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, tokenizer_impl: tokenizer_lib.Tokenizer | None = None, ) -> data.Extraction | None: """Fuzzy-align an extraction using difflib.SequenceMatcher on tokens. The algorithm scans every candidate window in `source_tokens` and selects the window with the highest SequenceMatcher `ratio`. It uses an efficient token-count intersection as a fast pre-check to discard windows that cannot meet the alignment threshold. A match is accepted when the ratio is ≥ `fuzzy_alignment_threshold`. This only runs on unmatched extractions, which is usually a small subset of the total extractions. Args: extraction: The extraction to align. source_tokens: The tokens from the source text. tokenized_text: The tokenized source text. token_offset: The token offset of the current chunk. char_offset: The character offset of the current chunk. fuzzy_alignment_threshold: The minimum ratio for a fuzzy match. tokenizer_impl: Optional tokenizer instance. Returns: The aligned data.Extraction if successful, None otherwise. 
""" extraction_tokens = list( _tokenize_with_lowercase( extraction.extraction_text, tokenizer_inst=tokenizer_impl ) ) # Work with lightly stemmed tokens so pluralisation doesn't block alignment extraction_tokens_norm = [_normalize_token(t) for t in extraction_tokens] if not extraction_tokens: return None logging.debug( "Fuzzy aligning %r (%d tokens)", extraction.extraction_text, len(extraction_tokens), ) best_ratio = 0.0 best_span: tuple[int, int] | None = None # (start_idx, window_size) len_e = len(extraction_tokens) max_window = len(source_tokens) extraction_counts = collections.Counter(extraction_tokens_norm) min_overlap = int(len_e * fuzzy_alignment_threshold) matcher = difflib.SequenceMatcher(autojunk=False, b=extraction_tokens_norm) for window_size in range(len_e, max_window + 1): if window_size > len(source_tokens): break # Initialize for sliding window window_deque = collections.deque(source_tokens[0:window_size]) window_counts = collections.Counter( [_normalize_token(t) for t in window_deque] ) for start_idx in range(len(source_tokens) - window_size + 1): # Optimization: check if enough overlapping tokens exist before expensive # sequence matching. This is an upper bound on the match count. 
if (extraction_counts & window_counts).total() >= min_overlap: window_tokens_norm = [_normalize_token(t) for t in window_deque] matcher.set_seq1(window_tokens_norm) matches = sum(size for _, _, size in matcher.get_matching_blocks()) if len_e > 0: ratio = matches / len_e else: ratio = 0.0 if ratio > best_ratio: best_ratio = ratio best_span = (start_idx, window_size) # Slide the window to the right if start_idx + window_size < len(source_tokens): # Remove the leftmost token from the count old_token = window_deque.popleft() old_token_norm = _normalize_token(old_token) window_counts[old_token_norm] -= 1 if window_counts[old_token_norm] == 0: del window_counts[old_token_norm] # Add the new rightmost token to the deque and count new_token = source_tokens[start_idx + window_size] window_deque.append(new_token) new_token_norm = _normalize_token(new_token) window_counts[new_token_norm] += 1 if best_span and best_ratio >= fuzzy_alignment_threshold: start_idx, window_size = best_span try: extraction.token_interval = tokenizer_lib.TokenInterval( start_index=start_idx + token_offset, end_index=start_idx + window_size + token_offset, ) start_token = tokenized_text.tokens[start_idx] end_token = tokenized_text.tokens[start_idx + window_size - 1] extraction.char_interval = data.CharInterval( start_pos=char_offset + start_token.char_interval.start_pos, end_pos=char_offset + end_token.char_interval.end_pos, ) extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY return extraction except IndexError: logging.exception( "Index error while setting intervals during fuzzy alignment." 
) return None return None def align_extractions( self, extraction_groups: Sequence[Sequence[data.Extraction]], source_text: str, token_offset: int = 0, char_offset: int = 0, delim: str = "\u241F", # Unicode Symbol for unit separator enable_fuzzy_alignment: bool = True, fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD, accept_match_lesser: bool = True, tokenizer_impl: tokenizer_lib.Tokenizer | None = None, ) -> Sequence[Sequence[data.Extraction]]: """Aligns extractions with their positions in the source text. This method takes a sequence of extractions and the source text, aligning each extraction with its corresponding position in the source text. It returns a sequence of extractions along with token intervals indicating the start and end positions of each extraction in the source text. If an extraction cannot be aligned, its token interval is set to None. Args: extraction_groups: A sequence of sequences, where each inner sequence contains an Extraction object. source_text: The source text against which extractions are to be aligned. token_offset: The offset to add to the start and end indices of the token intervals. char_offset: The offset to add to the start and end positions of the character intervals. delim: Token used to separate multi-token extractions. enable_fuzzy_alignment: Whether to use fuzzy alignment when exact matching fails. fuzzy_alignment_threshold: Minimum token overlap ratio for fuzzy alignment (0-1). accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER status). tokenizer_impl: Optional tokenizer instance. Returns: A sequence of extractions aligned with the source text, including token intervals. """ logging.debug( "WordAligner: Starting alignment of extractions with the source text." 
" Extraction groups to align: %s", extraction_groups, ) if not extraction_groups: logging.info("No extraction groups provided; returning empty list.") return [] source_tokens = list( _tokenize_with_lowercase(source_text, tokenizer_inst=tokenizer_impl) ) delim_len = len( list(_tokenize_with_lowercase(delim, tokenizer_inst=tokenizer_impl)) ) if delim_len != 1: raise ValueError(f"Delimiter {delim!r} must be a single token.") logging.debug("Using delimiter %r for extraction alignment", delim) extraction_tokens = list( _tokenize_with_lowercase( f" {delim} ".join( extraction.extraction_text for extraction in itertools.chain(*extraction_groups) ), tokenizer_inst=tokenizer_impl, ) ) self._set_seqs(source_tokens, extraction_tokens) index_to_extraction_group = {} extraction_index = 0 for group_index, group in enumerate(extraction_groups): logging.debug( "Processing extraction group %d with %d extractions.", group_index, len(group), ) for extraction in group: # Validate delimiter doesn't appear in extraction text if delim in extraction.extraction_text: raise ValueError( f"Delimiter {delim!r} appears inside extraction text" f" {extraction.extraction_text!r}. This would corrupt alignment" " mapping." 
) index_to_extraction_group[extraction_index] = (extraction, group_index) extraction_text_tokens = list( _tokenize_with_lowercase( extraction.extraction_text, tokenizer_inst=tokenizer_impl ) ) extraction_index += len(extraction_text_tokens) + delim_len aligned_extraction_groups: list[list[data.Extraction]] = [ [] for _ in extraction_groups ] tokenized_text = ( tokenizer_impl.tokenize(source_text) if tokenizer_impl else tokenizer_lib.tokenize(source_text) ) # Track which extractions were aligned in the exact matching phase aligned_extractions = [] exact_matches = 0 lesser_matches = 0 # Exact matching phase for i, j, n in self._get_matching_blocks()[:-1]: extraction, _ = index_to_extraction_group.get(j, (None, None)) if extraction is None: logging.debug( "No clean start index found for extraction index=%d iterating" " Difflib matching_blocks", j, ) continue extraction.token_interval = tokenizer_lib.TokenInterval( start_index=i + token_offset, end_index=i + n + token_offset, ) try: start_token = tokenized_text.tokens[i] end_token = tokenized_text.tokens[i + n - 1] extraction.char_interval = data.CharInterval( start_pos=char_offset + start_token.char_interval.start_pos, end_pos=char_offset + end_token.char_interval.end_pos, ) except IndexError as e: raise IndexError( "Failed to align extraction with source text. Extraction token" f" interval {extraction.token_interval} does not match source text" f" tokens {tokenized_text.tokens}." 
) from e extraction_text_len = len( list( _tokenize_with_lowercase( extraction.extraction_text, tokenizer_inst=tokenizer_impl ) ) ) if extraction_text_len < n: raise ValueError( "Delimiter prevents blocks greater than extraction length: " f"extraction_text_len={extraction_text_len}, block_size={n}" ) if extraction_text_len == n: extraction.alignment_status = data.AlignmentStatus.MATCH_EXACT exact_matches += 1 aligned_extractions.append(extraction) else: # Partial match (extraction longer than matched text) if accept_match_lesser: extraction.alignment_status = data.AlignmentStatus.MATCH_LESSER lesser_matches += 1 aligned_extractions.append(extraction) else: # Reset intervals when not accepting lesser matches extraction.token_interval = None extraction.char_interval = None extraction.alignment_status = None # Collect unaligned extractions unaligned_extractions = [] for extraction, _ in index_to_extraction_group.values(): if extraction not in aligned_extractions: unaligned_extractions.append(extraction) # Apply fuzzy alignment to remaining extractions if enable_fuzzy_alignment and unaligned_extractions: logging.debug( "Starting fuzzy alignment for %d unaligned extractions", len(unaligned_extractions), ) for extraction in unaligned_extractions: aligned_extraction = self._fuzzy_align_extraction( extraction, source_tokens, tokenized_text, token_offset, char_offset, fuzzy_alignment_threshold, tokenizer_impl=tokenizer_impl, ) if aligned_extraction: aligned_extractions.append(aligned_extraction) logging.debug( "Fuzzy alignment successful for extraction: %s", extraction.extraction_text, ) for extraction, group_index in index_to_extraction_group.values(): aligned_extraction_groups[group_index].append(extraction) logging.debug( "Final aligned extraction groups: %s", aligned_extraction_groups ) return aligned_extraction_groups def _tokenize_with_lowercase( text: str, tokenizer_inst: tokenizer_lib.Tokenizer | None = None, ) -> Iterator[str]: """Extract and lowercase tokens from 
the input text into words. This function utilizes the tokenizer module to tokenize text and yields lowercased words. Args: text (str): The text to be tokenized. tokenizer_inst: Optional tokenizer instance. Yields: Iterator[str]: An iterator over tokenized words. """ if tokenizer_inst is not None: tokenized_pb2 = tokenizer_inst.tokenize(text) else: tokenized_pb2 = tokenizer_lib.tokenize(text) original_text = tokenized_pb2.text for token in tokenized_pb2.tokens: start = token.char_interval.start_pos end = token.char_interval.end_pos token_str = original_text[start:end] token_str = token_str.lower() yield token_str @functools.lru_cache(maxsize=10000) def _normalize_token(token: str) -> str: """Lowercases and applies light pluralisation stemming.""" token = token.lower() if len(token) > 3 and token.endswith("s") and not token.endswith("ss"): token = token[:-1] return token ================================================ FILE: langextract/schema.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Schema compatibility layer. This module provides backward compatibility for the schema module. New code should import from langextract.core.schema instead. 
""" from __future__ import annotations # Re-export core schema items with deprecation warnings import warnings from langextract._compat import schema def __getattr__(name: str): """Handle imports with appropriate warnings.""" core_items = { "BaseSchema": ("langextract.core.schema", "BaseSchema"), "Constraint": ("langextract.core.schema", "Constraint"), "ConstraintType": ("langextract.core.schema", "ConstraintType"), "EXTRACTIONS_KEY": ("langextract.core.data", "EXTRACTIONS_KEY"), "ATTRIBUTE_SUFFIX": ("langextract.core.data", "ATTRIBUTE_SUFFIX"), "FormatModeSchema": ("langextract.core.schema", "FormatModeSchema"), } if name in core_items: mod, attr = core_items[name] warnings.warn( f"`langextract.schema.{name}` has moved to `{mod}.{attr}`. Please" " update your imports. This compatibility layer will be removed in" " v2.0.0.", FutureWarning, stacklevel=2, ) module = __import__(mod, fromlist=[attr]) return getattr(module, attr) elif name == "GeminiSchema": return schema.__getattr__(name) raise AttributeError(f"module 'langextract.schema' has no attribute '{name}'") ================================================ FILE: langextract/tokenizer.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Compatibility shim for langextract.tokenizer imports. This module provides backward compatibility for code that imports from langextract.tokenizer. All functionality has moved to langextract.core.tokenizer. 
""" from __future__ import annotations # Re-export everything from core.tokenizer for backward compatibility # pylint: disable=unused-wildcard-import from langextract.core.tokenizer import * ================================================ FILE: langextract/visualization.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utility functions for visualizing LangExtract extractions in notebooks. Example ------- >>> import langextract as lx >>> doc = lx.extract(...) 
>>> lx.visualize(doc) """ from __future__ import annotations import dataclasses import enum import html import itertools import json import pathlib import textwrap from langextract import io from langextract.core import data # Fallback if IPython is not present try: from IPython import get_ipython # type: ignore[import-not-found] from IPython.display import HTML # type: ignore[import-not-found] except ImportError: def get_ipython(): # type: ignore[no-redef] return None HTML = None # pytype: disable=annotation-type-mismatch def _is_jupyter() -> bool: """Check if we're in a Jupyter/IPython environment that can display HTML.""" try: if get_ipython is None: return False ip = get_ipython() if ip is None: return False # Simple check: if we're in IPython and NOT in a plain terminal return ip.__class__.__name__ != 'TerminalInteractiveShell' except Exception: return False _PALETTE: list[str] = [ '#D2E3FC', # Light Blue (Primary Container) '#C8E6C9', # Light Green (Tertiary Container) '#FEF0C3', # Light Yellow (Primary Color) '#F9DEDC', # Light Red (Error Container) '#FFDDBE', # Light Orange (Tertiary Container) '#EADDFF', # Light Purple (Secondary/Tertiary Container) '#C4E9E4', # Light Teal (Teal Container) '#FCE4EC', # Light Pink (Pink Container) '#E8EAED', # Very Light Grey (Neutral Highlight) '#DDE8E8', # Pale Cyan (Cyan Container) ] _VISUALIZATION_CSS = textwrap.dedent("""\ """) def _assign_colors(extractions: list[data.Extraction]) -> dict[str, str]: """Assigns a background colour to each extraction class. Args: extractions: list of extractions. Returns: Mapping from extraction_class to a hex colour string. 
""" classes = {e.extraction_class for e in extractions if e.char_interval} color_map: dict[str, str] = {} palette_cycle = itertools.cycle(_PALETTE) for cls in sorted(classes): color_map[cls] = next(palette_cycle) return color_map def _filter_valid_extractions( extractions: list[data.Extraction], ) -> list[data.Extraction]: """Filters extractions to only include those with valid char intervals.""" return [ e for e in extractions if ( e.char_interval and e.char_interval.start_pos is not None and e.char_interval.end_pos is not None ) ] class TagType(enum.Enum): """Enum for span boundary tag types.""" START = 'start' END = 'end' @dataclasses.dataclass(frozen=True) class SpanPoint: """Represents a span boundary point for HTML generation. Attributes: position: Character position in the text. tag_type: Type of span boundary (START or END). span_idx: Index of the span for HTML data-idx attribute. extraction: The extraction data associated with this span. """ position: int tag_type: TagType span_idx: int extraction: data.Extraction def _build_highlighted_text( text: str, extractions: list[data.Extraction], color_map: dict[str, str], ) -> str: """Returns text with highlights inserted, supporting nesting. Args: text: Original document text. extractions: List of extraction objects with char_intervals. color_map: Mapping of extraction_class to colour. 
""" points = [] span_lengths = {} for index, extraction in enumerate(extractions): if ( not extraction.char_interval or extraction.char_interval.start_pos is None or extraction.char_interval.end_pos is None or extraction.char_interval.start_pos >= extraction.char_interval.end_pos ): continue start_pos = extraction.char_interval.start_pos end_pos = extraction.char_interval.end_pos points.append(SpanPoint(start_pos, TagType.START, index, extraction)) points.append(SpanPoint(end_pos, TagType.END, index, extraction)) span_lengths[index] = end_pos - start_pos def sort_key(point: SpanPoint): """Sorts span boundary points for proper HTML nesting. Sorts by position first, then handles ties at the same position to ensure proper HTML nesting. At the same position: 1. End tags come before start tags (to close before opening) 2. Among end tags: shorter spans close first 3. Among start tags: longer spans open first Args: point: SpanPoint containing position, tag_type, span_idx, and extraction. Returns: Sort key tuple ensuring proper nesting order. 
""" span_length = span_lengths.get(point.span_idx, 0) if point.tag_type == TagType.END: return (point.position, 0, span_length) else: # point.tag_type == TagType.START return (point.position, 1, -span_length) points.sort(key=sort_key) html_parts: list[str] = [] cursor = 0 for point in points: if point.position > cursor: html_parts.append(html.escape(text[cursor : point.position])) if point.tag_type == TagType.START: colour = color_map.get(point.extraction.extraction_class, '#ffff8d') highlight_class = ' lx-current-highlight' if point.span_idx == 0 else '' span_html = ( f'' ) html_parts.append(span_html) else: # point.tag_type == TagType.END html_parts.append('') cursor = point.position if cursor < len(text): html_parts.append(html.escape(text[cursor:])) return ''.join(html_parts) def _build_legend_html(color_map: dict[str, str]) -> str: """Builds legend HTML showing extraction classes and their colors.""" if not color_map: return '' legend_items = [] for extraction_class, colour in color_map.items(): legend_items.append( '{html.escape(extraction_class)}' ) return ( '
Highlights Legend:' f' {" ".join(legend_items)}
' ) def _format_attributes(attributes: dict | None) -> str: """Formats attributes as a single-line string.""" if not attributes: return '{}' valid_attrs = { key: value for key, value in attributes.items() if value not in (None, '', 'null') } if not valid_attrs: return '{}' attrs_parts = [] for key, value in valid_attrs.items(): # Clean up array formatting for better readability if isinstance(value, list): value_str = ', '.join(str(v) for v in value) else: value_str = str(value) attrs_parts.append( f'{html.escape(str(key))}: {html.escape(value_str)}
' ) return '{' + ', '.join(attrs_parts) + '}' def _prepare_extraction_data( text: str, extractions: list[data.Extraction], color_map: dict[str, str], context_chars: int = 150, ) -> list[dict]: """Prepares JavaScript data for extractions.""" extraction_data = [] for i, extraction in enumerate(extractions): # Assertions to inform pytype about the invariants guaranteed by _filter_valid_extractions assert ( extraction.char_interval is not None ), 'char_interval must be non-None for valid extractions' assert ( extraction.char_interval.start_pos is not None ), 'start_pos must be non-None for valid extractions' assert ( extraction.char_interval.end_pos is not None ), 'end_pos must be non-None for valid extractions' start_pos = extraction.char_interval.start_pos end_pos = extraction.char_interval.end_pos context_start = max(0, start_pos - context_chars) context_end = min(len(text), end_pos + context_chars) before_text = text[context_start:start_pos] extraction_text = text[start_pos:end_pos] after_text = text[end_pos:context_end] colour = color_map.get(extraction.extraction_class, '#ffff8d') # Build attributes display attributes_html = ( '
class:' f' {html.escape(extraction.extraction_class)}
' ) attributes_html += ( '
attributes:' f' {_format_attributes(extraction.attributes)}
' ) extraction_data.append({ 'index': i, 'class': extraction.extraction_class, 'text': extraction.extraction_text, 'color': colour, 'startPos': start_pos, 'endPos': end_pos, 'beforeText': html.escape(before_text), 'extractionText': html.escape(extraction_text), 'afterText': html.escape(after_text), 'attributesHtml': attributes_html, }) return extraction_data def _build_visualization_html( text: str, extractions: list[data.Extraction], color_map: dict[str, str], animation_speed: float = 1.0, show_legend: bool = True, ) -> str: """Builds the complete visualization HTML.""" if not extractions: return ( '

No extractions to' ' animate.

' ) # Sort extractions by position for proper HTML nesting. def _extraction_sort_key(extraction): """Sort by position, then by span length descending for proper nesting.""" start = extraction.char_interval.start_pos end = extraction.char_interval.end_pos span_length = end - start return (start, -span_length) # longer spans first sorted_extractions = sorted(extractions, key=_extraction_sort_key) highlighted_text = _build_highlighted_text( text, sorted_extractions, color_map ) extraction_data = _prepare_extraction_data( text, sorted_extractions, color_map ) legend_html = _build_legend_html(color_map) if show_legend else '' js_data = json.dumps(extraction_data) # Prepare pos_info_str safely for pytype for the f-string below first_extraction = extractions[0] assert ( first_extraction.char_interval and first_extraction.char_interval.start_pos is not None and first_extraction.char_interval.end_pos is not None ), 'first extraction must have valid char_interval with start_pos and end_pos' pos_info_str = f'[{first_extraction.char_interval.start_pos}-{first_extraction.char_interval.end_pos}]' html_content = textwrap.dedent(f"""
{legend_html}
{highlighted_text}
Entity 1/{len(extractions)} | Pos {pos_info_str}
""") return html_content def visualize( data_source: data.AnnotatedDocument | str | pathlib.Path, *, animation_speed: float = 1.0, show_legend: bool = True, gif_optimized: bool = True, ) -> HTML | str: """Visualises extraction data as animated highlighted HTML. Args: data_source: Either an AnnotatedDocument or path to a JSONL file. animation_speed: Animation speed in seconds between extractions. show_legend: If ``True``, appends a colour legend mapping extraction classes to colours. gif_optimized: If ``True``, applies GIF-optimized styling with larger fonts, better contrast, and improved dimensions for video capture. Returns: An :class:`IPython.display.HTML` object if IPython is available, otherwise the generated HTML string. """ # Load document if it's a file path if isinstance(data_source, (str, pathlib.Path)): file_path = pathlib.Path(data_source) if not file_path.exists(): raise FileNotFoundError(f'JSONL file not found: {file_path}') documents = list(io.load_annotated_documents_jsonl(file_path)) if not documents: raise ValueError(f'No documents found in JSONL file: {file_path}') annotated_doc = documents[0] # Use first document else: annotated_doc = data_source if not annotated_doc or annotated_doc.text is None: raise ValueError('annotated_doc must contain text to visualise.') if annotated_doc.extractions is None: raise ValueError('annotated_doc must contain extractions to visualise.') # Filter valid extractions - show ALL of them valid_extractions = _filter_valid_extractions(annotated_doc.extractions) if not valid_extractions: empty_html = ( '

No valid extractions to' ' animate.

' ) full_html = _VISUALIZATION_CSS + empty_html if HTML is not None and _is_jupyter(): return HTML(full_html) return full_html color_map = _assign_colors(valid_extractions) visualization_html = _build_visualization_html( annotated_doc.text, valid_extractions, color_map, animation_speed, show_legend, ) full_html = _VISUALIZATION_CSS + visualization_html # Apply GIF optimizations if requested if gif_optimized: full_html = full_html.replace( 'class="lx-animated-wrapper"', 'class="lx-animated-wrapper lx-gif-optimized"', ) if HTML is not None and _is_jupyter(): return HTML(full_html) return full_html ================================================ FILE: pyproject.toml ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
[build-system] requires = ["setuptools>=67.0.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "langextract" version = "1.1.1" description = "LangExtract: A library for extracting structured data from language models" readme = "README.md" requires-python = ">=3.10" license = "Apache-2.0" authors = [ {name = "Akshay Goel", email = "goelak@google.com"} ] dependencies = [ "absl-py>=1.0.0", "aiohttp>=3.8.0", "async_timeout>=4.0.0", "exceptiongroup>=1.1.0", "google-genai>=1.39.0", "google-cloud-storage>=2.14.0", "ml-collections>=0.1.0", "more-itertools>=8.0.0", "numpy>=1.20.0", "pandas>=1.3.0", "pydantic>=1.8.0", "python-dotenv>=0.19.0", "PyYAML>=6.0", "regex>=2023.0.0", "requests>=2.25.0", "tqdm>=4.64.0", "typing-extensions>=4.0.0" ] [project.urls] "Homepage" = "https://github.com/google/langextract" "Repository" = "https://github.com/google/langextract" "Documentation" = "https://github.com/google/langextract/blob/main/README.md" "Bug Tracker" = "https://github.com/google/langextract/issues" "Changelog" = "https://github.com/google/langextract/releases" "DOI" = "https://doi.org/10.5281/zenodo.17015089" [project.optional-dependencies] openai = ["openai>=1.50.0"] all = ["openai>=1.50.0"] dev = [ "pyink~=24.3.0", "isort>=5.13.0", "pylint>=3.0.0", "pytype>=2024.10.11", "tox>=4.0.0", "import-linter>=2.0", "pre-commit>=3.5.0", "types-regex>=2023.0.0" ] test = [ "pytest>=7.4.0", "tomli>=2.0.0" ] notebook = [ "ipython>=7.0.0", "notebook>=6.0.0" ] [tool.setuptools] packages = [ "langextract", "langextract._compat", "langextract.core", "langextract.providers", "langextract.providers.schemas" ] include-package-data = true [tool.setuptools.package-data] langextract = ["py.typed"] # Provider discovery mechanism for built-in and third-party providers [project.entry-points."langextract.providers"] gemini = "langextract.providers.gemini:GeminiLanguageModel" ollama = "langextract.providers.ollama:OllamaLanguageModel" openai = 
"langextract.providers.openai:OpenAILanguageModel" [tool.setuptools.exclude-package-data] "*" = [ "docs*", "tests*", "kokoro*", "*.gif", "*.svg", ] [tool.pytest.ini_options] testpaths = ["tests"] python_files = "*_test.py" python_classes = "Test*" python_functions = "test_*" # Show extra test summary info addopts = "-ra" markers = [ "live_api: marks tests as requiring live API access", "requires_pip: marks tests that perform pip install/uninstall operations", "integration: marks integration tests that test multiple components together", ] [tool.pyink] # Configuration for Google's style guide line-length = 80 unstable = true pyink-indentation = 2 pyink-use-majority-quotes = true [tool.isort] # Configuration for Google's style guide profile = "google" line_length = 80 force_sort_within_sections = true # Allow multiple imports on one line for these modules single_line_exclusions = ["typing", "typing_extensions", "collections.abc"] [tool.importlinter] root_package = "langextract" [[tool.importlinter.contracts]] name = "Providers must not import inference" type = "forbidden" source_modules = ["langextract.providers"] forbidden_modules = ["langextract.inference"] [[tool.importlinter.contracts]] name = "Core must not import providers" type = "forbidden" source_modules = ["langextract.core"] forbidden_modules = ["langextract.providers"] [[tool.importlinter.contracts]] name = "Core must not import high-level modules" type = "forbidden" source_modules = ["langextract.core"] forbidden_modules = [ "langextract.annotation", "langextract.chunking", "langextract.prompting", "langextract.resolver", ] ================================================ FILE: scripts/create_provider_plugin.py ================================================ #!/usr/bin/env python3 # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Create a new LangExtract provider plugin with all boilerplate code. This script automates steps 1-6 of the provider creation checklist: 1. Setup Package Structure 2. Configure Entry Point 3. Implement Provider 4. Add Schema Support (optional) 5. Create and run tests 6. Generate documentation For detailed documentation, see: https://github.com/google/langextract/blob/main/langextract/providers/README.md Usage: python create_provider_plugin.py MyProvider python create_provider_plugin.py MyProvider --with-schema python create_provider_plugin.py MyProvider --patterns "^mymodel" "^custom" """ import argparse import os from pathlib import Path import re import subprocess import sys import textwrap def create_directory_structure(package_name: str, force: bool = False) -> Path: """Step 1: Setup Package Structure.""" print("\n" + "=" * 60) print("STEP 1: Setup Package Structure") print("=" * 60) base_dir = Path(f"langextract-{package_name}") package_dir = base_dir / f"langextract_{package_name}" if base_dir.exists() and any(base_dir.iterdir()) and not force: print(f"ERROR: {base_dir} already exists and is not empty.") print("Use --force to overwrite or choose a different package name.") sys.exit(1) base_dir.mkdir(parents=True, exist_ok=True) package_dir.mkdir(parents=True, exist_ok=True) print(f"✓ Created directory: {base_dir}/") print(f"✓ Created package: {package_dir}/") print("✅ Step 1 complete: Package structure created") return base_dir def create_pyproject_toml( base_dir: Path, provider_name: str, package_name: str ) -> None: """Step 2: Configure Entry Point.""" 
print("\n" + "=" * 60) print("STEP 2: Configure Entry Point") print("=" * 60) content = textwrap.dedent(f"""\ [build-system] requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "langextract-{package_name}" version = "0.1.0" description = "LangExtract provider plugin for {provider_name}" readme = "README.md" requires-python = ">=3.10" license = {{text = "Apache-2.0"}} dependencies = [ "langextract>=1.0.0", # Add your provider's SDK dependencies here ] [project.entry-points."langextract.providers"] {package_name} = "langextract_{package_name}.provider:{provider_name}LanguageModel" [tool.setuptools.packages.find] where = ["."] include = ["langextract_{package_name}*"] """) (base_dir / "pyproject.toml").write_text(content, encoding="utf-8") print("✓ Created pyproject.toml with entry point configuration") print("✅ Step 2 complete: Entry point configured") def create_provider( base_dir: Path, provider_name: str, package_name: str, patterns: list[str], with_schema: bool, ) -> None: """Step 3: Implement Provider.""" print("\n" + "=" * 60) print("STEP 3: Implement Provider") print("=" * 60) package_dir = base_dir / f"langextract_{package_name}" patterns_str = ", ".join(f"r'{p}'" for p in patterns) env_var_safe = re.sub(r"[^A-Z0-9]+", "_", package_name.upper()) + "_API_KEY" schema_imports = ( f""" from langextract_{package_name}.schema import {provider_name}Schema""" if with_schema else "" ) schema_init = ( """ self.response_schema = kwargs.get('response_schema') self.structured_output = kwargs.get('structured_output', False)""" if with_schema else "" ) schema_methods = f""" @classmethod def get_schema_class(cls): \"\"\"Tell LangExtract about our schema support.\"\"\" from langextract_{package_name}.schema import {provider_name}Schema return {provider_name}Schema def apply_schema(self, schema_instance): \"\"\"Apply or clear schema configuration.\"\"\" super().apply_schema(schema_instance) if schema_instance: config = 
schema_instance.to_provider_config() self.response_schema = config.get('response_schema') self.structured_output = config.get('structured_output', False) else: self.response_schema = None self.structured_output = False""" if with_schema else "" schema_infer = ( """ api_params = {} if self.response_schema: api_params['response_schema'] = self.response_schema # result = self.client.generate(prompt, **api_params)""" if with_schema else """ # result = self.client.generate(prompt, **kwargs)""" ) provider_content = textwrap.dedent(f'''\ """Provider implementation for {provider_name}.""" import os import langextract as lx{schema_imports} @lx.providers.registry.register({patterns_str}, priority=10) class {provider_name}LanguageModel(lx.inference.BaseLanguageModel): """LangExtract provider for {provider_name}. This provider handles model IDs matching: {patterns} """ def __init__(self, model_id: str, api_key: str = None, **kwargs): """Initialize the {provider_name} provider. Args: model_id: The model identifier. api_key: API key for authentication. **kwargs: Additional provider-specific parameters. """ super().__init__() self.model_id = model_id self.api_key = api_key or os.environ.get('{env_var_safe}'){schema_init} # self.client = YourClient(api_key=self.api_key) self._extra_kwargs = kwargs{schema_methods} def infer(self, batch_prompts, **kwargs): """Run inference on a batch of prompts. Args: batch_prompts: List of prompts to process. **kwargs: Additional inference parameters. Yields: Lists of ScoredOutput objects, one per prompt. """ for prompt in batch_prompts:{schema_infer} result = f"Mock response for: {{prompt[:50]}}..." 
yield [lx.inference.ScoredOutput(score=1.0, output=result)] ''') (package_dir / "provider.py").write_text(provider_content, encoding="utf-8") print("✓ Created provider.py with mock implementation") # Create __init__.py init_content = textwrap.dedent(f'''\ """LangExtract provider plugin for {provider_name}.""" from langextract_{package_name}.provider import {provider_name}LanguageModel __all__ = ['{provider_name}LanguageModel'] __version__ = "0.1.0" ''') (package_dir / "__init__.py").write_text(init_content, encoding="utf-8") print("✓ Created __init__.py with exports") print("✅ Step 3 complete: Provider implementation created") def create_schema( base_dir: Path, provider_name: str, package_name: str ) -> None: """Step 4: Add Schema Support.""" print("\n" + "=" * 60) print("STEP 4: Add Schema Support (Optional)") print("=" * 60) package_dir = base_dir / f"langextract_{package_name}" schema_content = textwrap.dedent(f'''\ """Schema implementation for {provider_name} provider.""" import langextract as lx from langextract import schema class {provider_name}Schema(lx.schema.BaseSchema): """Schema implementation for {provider_name} structured output.""" def __init__(self, schema_dict: dict): """Initialize the schema with a dictionary.""" self._schema_dict = schema_dict @property def schema_dict(self) -> dict: """Return the schema dictionary.""" return self._schema_dict @classmethod def from_examples(cls, examples_data, attribute_suffix="_attributes"): """Build schema from example extractions. Args: examples_data: Sequence of ExampleData objects. attribute_suffix: Suffix for attribute fields. Returns: A configured {provider_name}Schema instance. 
""" extraction_types = {{}} for example in examples_data: for extraction in example.extractions: class_name = extraction.extraction_class if class_name not in extraction_types: extraction_types[class_name] = set() if extraction.attributes: extraction_types[class_name].update(extraction.attributes.keys()) schema_dict = {{ "type": "object", "properties": {{ "extractions": {{ "type": "array", "items": {{"type": "object"}} }} }}, "required": ["extractions"] }} return cls(schema_dict) def to_provider_config(self) -> dict: """Convert to provider-specific configuration. Returns: Dictionary of provider-specific configuration. """ return {{ "response_schema": self._schema_dict, "structured_output": True }} @property def supports_strict_mode(self) -> bool: """Whether this schema guarantees valid structured output. Returns: True if the provider enforces valid JSON output. """ return False # Set to True only if your provider guarantees valid JSON ''') (package_dir / "schema.py").write_text(schema_content, encoding="utf-8") print("✓ Created schema.py with BaseSchema implementation") print("✅ Step 4 complete: Schema support added") def create_test_script( base_dir: Path, provider_name: str, package_name: str, patterns: list[str], with_schema: bool, ) -> None: """Step 5: Create and run tests.""" print("\n" + "=" * 60) print("STEP 5: Create Tests") print("=" * 60) patterns_literal = "[" + ", ".join(repr(p) for p in patterns) + "]" provider_cls_name = f"{provider_name}LanguageModel" test_content = textwrap.dedent(f'''\ #!/usr/bin/env python3 """Test script for {provider_name} provider (Step 5 checklist).""" import re import sys import langextract as lx from langextract.providers import registry try: from langextract_{package_name} import {provider_cls_name} except ImportError: print("ERROR: Plugin not installed. 
Run: pip install -e .") sys.exit(1) lx.providers.load_plugins_once() PROVIDER_CLS_NAME = "{provider_cls_name}" PATTERNS = {patterns_literal} def _example_id(pattern: str) -> str: \"\"\"Generate test model ID from pattern.\"\"\" base = re.sub(r'^\\^', '', pattern) m = re.match(r"[A-Za-z0-9._-]+", base) base = m.group(0) if m else (base or "model") return f"{{base}}-test" sample_ids = [_example_id(p) for p in PATTERNS] sample_ids.append("unknown-model") print("Testing {provider_name} Provider - Step 5 Checklist:") print("-" * 50) # 1 & 2. Provider registration + pattern matching via resolve() print("1–2. Provider registration & pattern matching") for model_id in sample_ids: try: provider_class = registry.resolve(model_id) ok = provider_class.__name__ == PROVIDER_CLS_NAME status = "✓" if (ok or model_id == "unknown-model") else "✗" note = "expected" if ok else ("expected (no provider)" if model_id == "unknown-model" else "unexpected provider") print(f" {{status}} {{model_id}} -> {{provider_class.__name__ if ok else 'resolved'}} {{note}}") except Exception as e: if model_id == "unknown-model": print(f" ✓ {{model_id}}: No provider found (expected)") else: print(f" ✗ {{model_id}}: resolve() failed: {{e}}") # 3. Inference sanity check print("\\n3. Test inference with sample prompts") try: model_id = sample_ids[0] if sample_ids[0] != "unknown-model" else (_example_id(PATTERNS[0]) if PATTERNS else "test-model") provider = {provider_cls_name}(model_id=model_id) prompts = ["Test prompt 1", "Test prompt 2"] results = list(provider.infer(prompts)) print(f" ✓ Inference returned {{len(results)}} results") for i, result in enumerate(results): try: out = result[0].output if result and result[0] else None print(f" ✓ Result {{i+1}}: {{(out or '')[:60]}}...") except Exception: print(f" ✗ Result {{i+1}}: Unexpected result shape: {{result}}") except Exception as e: print(f" ✗ ERROR: {{e}}") ''') if with_schema: test_content += textwrap.dedent(f""" # 4. 
Test schema creation and application print("\\n4. Test schema creation and application") try: from langextract_{package_name}.schema import {provider_name}Schema from langextract import data examples = [ data.ExampleData( text="Test text", extractions=[ data.Extraction( extraction_class="entity", extraction_text="test", attributes={{"type": "example"}} ) ] ) ] schema = {provider_name}Schema.from_examples(examples) print(f" ✓ Schema created (keys={{list(schema.schema_dict.keys())}})") schema_class = {provider_cls_name}.get_schema_class() print(f" ✓ Provider schema class: {{schema_class.__name__}}") provider = {provider_cls_name}(model_id=_example_id(PATTERNS[0]) if PATTERNS else "test-model") provider.apply_schema(schema) print(f" ✓ Schema applied: response_schema={{provider.response_schema is not None}} structured={{getattr(provider, 'structured_output', False)}}") except Exception as e: print(f" ✗ ERROR: {{e}}") """) test_content += textwrap.dedent(f""" # 5. Test factory integration print("\\n5. 
Test factory integration") try: from langextract import factory config = factory.ModelConfig( model_id=_example_id(PATTERNS[0]) if PATTERNS else "test-model", provider="{provider_cls_name}" ) model = factory.create_model(config) print(f" ✓ Factory created: {{type(model).__name__}}") except Exception as e: print(f" ✗ ERROR: {{e}}") print("\\n" + "-" * 50) print("✅ Testing complete!") """) (base_dir / "test_plugin.py").write_text(test_content, encoding="utf-8") print("✓ Created test_plugin.py with comprehensive tests") print("✅ Step 5 complete: Test suite created") def create_readme( base_dir: Path, provider_name: str, package_name: str, patterns: list[str] ) -> None: """Create README documentation.""" print("\n" + "=" * 60) print("STEP 6: Documentation") print("=" * 60) def _display(p: str) -> str: """Strip leading ^ from pattern for display.""" return p[1:] if p.startswith("^") else p env_var_safe = re.sub(r"[^A-Z0-9]+", "_", package_name.upper()) + "_API_KEY" supported = "\n".join( f"- `{_display(p)}*`: Models matching pattern {p}" for p in patterns ) readme_content = textwrap.dedent(f"""\ # LangExtract {provider_name} Provider A provider plugin for LangExtract that supports {provider_name} models. ## Installation ```bash pip install -e . ``` ## Supported Model IDs {supported} ## Environment Variables - `{env_var_safe}`: API key for authentication ## Usage ```python import langextract as lx result = lx.extract( text="Your document here", model_id="{_display(patterns[0]) if patterns else package_name}-model", prompt_description="Extract entities", examples=[...] ) ``` ## Development 1. Install in development mode: `pip install -e .` 2. Run tests: `python test_plugin.py` 3. Build package: `python -m build` 4. 
Publish to PyPI: `twine upload dist/*` ## License Apache License 2.0 """) (base_dir / "README.md").write_text(readme_content, encoding="utf-8") print("✓ Created README.md with usage examples") def create_gitignore(base_dir: Path) -> None: """Create .gitignore file with Python-specific entries.""" gitignore_content = textwrap.dedent("""\ # Python __pycache__/ *.py[cod] *$py.class *.so # Distribution / packaging build/ dist/ *.egg-info/ .eggs/ *.egg # Virtual environments .env .venv env/ venv/ ENV/ # Testing & coverage .pytest_cache/ .tox/ htmlcov/ .coverage .coverage.* # Type checking .mypy_cache/ .dmypy.json dmypy.json .pytype/ # IDEs .idea/ .vscode/ *.swp *.swo # OS-specific .DS_Store Thumbs.db # Logs *.log # Temp files *.tmp *.bak *.backup """) (base_dir / ".gitignore").write_text(gitignore_content, encoding="utf-8") print("✓ Created .gitignore file with Python-specific entries") def create_license(base_dir: Path) -> None: """Create LICENSE file.""" license_content = textwrap.dedent("""\ # LICENSE TODO: Add your license here. This is a placeholder license file for your provider plugin. Please replace this with your actual license before distribution. 
Common options include: - Apache License 2.0 - MIT License - BSD License - GPL License - Proprietary/Commercial License """) (base_dir / "LICENSE").write_text(license_content, encoding="utf-8") print("✓ Created LICENSE file") print("✅ Step 6 complete: Documentation created") def install_and_test(base_dir: Path) -> bool: """Install the plugin and run tests.""" print("\n" + "=" * 60) print("Installing and testing the plugin...") print("=" * 60) os.chdir(base_dir) print("\nInstalling plugin...") result = subprocess.run( [sys.executable, "-m", "pip", "install", "-e", "."], capture_output=True, text=True, check=False, ) if result.returncode: print(f"Installation failed: {result.stderr}") return False print("✓ Plugin installed successfully") print("\nRunning tests...") result = subprocess.run( [sys.executable, "test_plugin.py"], capture_output=True, text=True, check=False, ) print(result.stdout) if result.returncode: print(f"Tests failed: {result.stderr}") return False return True def parse_arguments(): """Parse command line arguments. Returns: Parsed arguments from argparse. 
""" parser = argparse.ArgumentParser( description="Create a new LangExtract provider plugin", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=textwrap.dedent(""" Examples: python create_provider_plugin.py MyProvider python create_provider_plugin.py MyProvider --with-schema python create_provider_plugin.py MyProvider --patterns "^mymodel" "^custom" python create_provider_plugin.py MyProvider --package-name my_custom_name """), ) parser.add_argument( "provider_name", help="Name of your provider (e.g., MyProvider, CustomLLM)", ) parser.add_argument( "--patterns", nargs="+", default=None, help="Regex patterns for model IDs (default: based on provider name)", ) parser.add_argument( "--package-name", default=None, help="Package name (default: lowercase provider name)", ) parser.add_argument( "--with-schema", action="store_true", help="Include schema support (Step 4)", ) parser.add_argument( "--no-install", action="store_true", help="Skip installation and testing" ) parser.add_argument( "--force", action="store_true", help="Overwrite existing plugin directory if it exists", ) return parser.parse_args() def validate_patterns(patterns: list[str]) -> None: """Validate regex patterns. Args: patterns: List of regex patterns to validate. Raises: SystemExit: If any pattern is invalid. """ for p in patterns: try: re.compile(p) except re.error as e: print(f"ERROR: Invalid regex pattern '{p}': {e}") sys.exit(1) def print_summary( provider_name: str, package_name: str, patterns: list[str], with_schema: bool, ) -> None: """Print configuration summary. Args: provider_name: Name of the provider. package_name: Package name. patterns: List of model ID patterns. with_schema: Whether to include schema support. 
""" print("\n" + "=" * 60) print("LANGEXTRACT PROVIDER PLUGIN GENERATOR") print("=" * 60) print(f"Provider Name: {provider_name}") print(f"Package Name: langextract-{package_name}") print(f"Model Patterns: {patterns}") print(f"Include Schema: {with_schema}") print("\nFor documentation, see:") print( "https://github.com/google/langextract/blob/main/langextract/providers/README.md" ) def create_plugin( args: argparse.Namespace, package_name: str, patterns: list[str] ) -> Path: """Create the plugin with all necessary files. Args: args: Parsed command line arguments. package_name: Package name. patterns: List of model ID patterns. Returns: Path to the created plugin directory. """ base_dir = create_directory_structure(package_name, force=args.force) create_pyproject_toml(base_dir, args.provider_name, package_name) create_provider( base_dir, args.provider_name, package_name, patterns, args.with_schema ) if args.with_schema: create_schema(base_dir, args.provider_name, package_name) create_test_script( base_dir, args.provider_name, package_name, patterns, args.with_schema ) create_readme(base_dir, args.provider_name, package_name, patterns) create_gitignore(base_dir) create_license(base_dir) return base_dir def print_completion_summary(with_schema: bool) -> None: """Print completion summary. Args: with_schema: Whether schema support was included. 
""" print("\n" + "=" * 60) print("SUMMARY: Steps 1-6 Completed") print("=" * 60) print("✅ Package structure created") print("✅ Entry point configured") print("✅ Provider implemented") if with_schema: print("✅ Schema support added") print("✅ Tests created") print("✅ Documentation generated") def main(): """Main entry point for the provider plugin generator.""" args = parse_arguments() package_name = args.package_name or args.provider_name.lower() patterns = args.patterns if args.patterns else [f"^{package_name}"] validate_patterns(patterns) print_summary(args.provider_name, package_name, patterns, args.with_schema) base_dir = create_plugin(args, package_name, patterns) print_completion_summary(args.with_schema) if not args.no_install: success = install_and_test(base_dir) if success: print("\n✅ Plugin created, installed, and tested successfully!") print(f"\nYour plugin is ready at: {base_dir.absolute()}") print("\nNext steps:") print(" 1. Replace mock inference with actual API calls") print(" 2. Update documentation with real examples") print(" 3. Build package: python -m build") print(" 4. Publish to PyPI: twine upload dist/*") else: print( "\n⚠️ Plugin created but tests failed. Please check the" " implementation." ) sys.exit(1) else: print(f"\nPlugin created at: {base_dir.absolute()}") print("\nTo install and test:") print(f" cd {base_dir}") print(" pip install -e .") print(" python test_plugin.py") if __name__ == "__main__": main() ================================================ FILE: scripts/validate_community_providers.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #!/usr/bin/env python3 """Validation for COMMUNITY_PROVIDERS.md plugin registry table.""" import os from pathlib import Path import re import re as regex_module import sys from typing import Dict, List, Tuple HEADER_ANCHOR = '| Plugin Name | PyPI Package |' END_MARKER = '' # GitHub username/org and repo patterns GH_NAME = r'[-a-zA-Z0-9]+' # usernames/orgs allow hyphens GH_REPO = r'[-a-zA-Z0-9._]+' # repos allow ., _ GH_USER_LINK = rf'\[@{GH_NAME}\]\(https://github\.com/{GH_NAME}\)' GH_MULTI_USER = rf'^{GH_USER_LINK}(,\s*{GH_USER_LINK})*$' # Markdown link to a GitHub repo GH_REPO_LINK = rf'^\[[^\]]+\]\(https://github\.com/{GH_NAME}/{GH_REPO}\)$' # Issue link must point to LangExtract repository (issues only) LANGEXTRACT_ISSUE_LINK = ( r'^\[[^\]]+\]\(https://github\.com/google/langextract/issues/\d+\)$' ) # PEP 503-ish normalized name (loose): lowercase letters/digits with - _ . 
separators PYPI_NORMALIZED = r'`[a-z0-9]([\-_.]?[a-z0-9]+)*`' MIN_DESC_LEN = 10 def normalize_pypi(name: str) -> str: """PEP 503 normalization for PyPI package names.""" return regex_module.sub(r'[-_.]+', '-', name.strip().lower()) def find_table_bounds(lines: List[str]) -> Tuple[int, int]: start = end = -1 for i, line in enumerate(lines): if HEADER_ANCHOR in line: start = i elif start >= 0 and END_MARKER in line: end = i break return start, end def parse_row(line: str) -> List[str]: # assumes caller trimmed line parts = [c.strip() for c in line.split('|')[1:-1]] return parts def validate(filepath: Path) -> bool: errors: List[str] = [] warnings: List[str] = [] content = filepath.read_text(encoding='utf-8') lines = content.splitlines() start, end = find_table_bounds(lines) if start < 0: errors.append('Could not find plugin registry table header.') print_report(errors, warnings) return False if end < 0: errors.append( 'Could not find end marker: .' ) print_report(errors, warnings) return False rows: List[Dict] = [] seen_names = set() seen_pkgs = set() for i in range(start + 2, end): raw = lines[i].strip() if not raw: continue if not raw.startswith('|') or not raw.endswith('|'): errors.append( f"Line {i+1}: Not a valid table row (must start and end with '|')." ) continue cols = parse_row(raw) if len(cols) != 6: errors.append(f'Line {i+1}: Expected 6 columns, found {len(cols)}.') continue plugin, pypi, maint, repo, desc, issue_link = cols # Basic presence checks if not plugin: errors.append(f'Line {i+1}: Plugin Name is required.') if not re.fullmatch(PYPI_NORMALIZED, pypi): errors.append( f'Line {i+1}: PyPI package must be backticked and normalized (e.g.,' ' `langextract-provider-foo`).' ) elif pypi and not pypi.strip('`').lower().startswith('langextract-'): errors.append( f'Line {i+1}: PyPI package should start with `langextract-` for' ' discoverability.' 
) if not re.fullmatch(GH_MULTI_USER, maint): errors.append( f'Line {i+1}: Maintainer must be one or more GitHub handles as links ' '(e.g., [@alice](https://github.com/alice) or comma-separated).' ) if not re.fullmatch(GH_REPO_LINK, repo): errors.append( f'Line {i+1}: GitHub Repo must be a Markdown link to a GitHub' ' repository.' ) if not desc or len(desc) < MIN_DESC_LEN: errors.append( f'Line {i+1}: Description must be at least {MIN_DESC_LEN} characters.' ) # Issue link is required and must point to LangExtract repo if not issue_link: errors.append(f'Line {i+1}: Issue Link is required.') elif not re.fullmatch(LANGEXTRACT_ISSUE_LINK, issue_link): errors.append( f'Line {i+1}: Issue Link must point to a LangExtract issue (e.g.,' ' [#123](https://github.com/google/langextract/issues/123)).' ) rows.append({ 'line': i + 1, 'plugin': plugin, 'pypi': pypi.strip('`').lower() if pypi else '', }) # Duplicate checks (case-insensitive and PEP 503 normalized) for r in rows: pn_key = r['plugin'].strip().casefold() pk_key = normalize_pypi(r['pypi']) if r['pypi'] else None if pn_key in seen_names: errors.append(f"Line {r['line']}: Duplicate Plugin Name '{r['plugin']}'.") seen_names.add(pn_key) if pk_key and pk_key in seen_pkgs: errors.append(f"Line {r['line']}: Duplicate PyPI Package '{r['pypi']}'.") if pk_key: seen_pkgs.add(pk_key) # Required alphabetical sorting check sorted_by_name = sorted(rows, key=lambda r: r['plugin'].casefold()) if [r['plugin'] for r in rows] != [r['plugin'] for r in sorted_by_name]: errors.append('Registry rows must be alphabetically sorted by Plugin Name.') # Guardrail: discourage leaving only the example entry if len(rows) == 1 and rows[0]['plugin'].lower().startswith('example'): warnings.append( 'The registry currently contains only the example row. Add real' ' providers above the marker.' 
) print_report(errors, warnings) return not errors def print_report(errors: List[str], warnings: List[str]) -> None: if errors: print('❌ Validation failed:') for e in errors: print(f' • {e}') if warnings: print('⚠️ Warnings:') for w in warnings: print(f' • {w}') if not errors and not warnings: print('✅ Table format validation passed!') if __name__ == '__main__': path = Path('COMMUNITY_PROVIDERS.md') if len(sys.argv) > 1: path = Path(sys.argv[1]) if not path.exists(): print(f'❌ Error: File not found: {path}') sys.exit(1) ok = validate(path) sys.exit(0 if ok else 1) ================================================ FILE: tests/.pylintrc ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Test-specific Pylint configuration # Inherits from parent ../.pylintrc and adds test-specific relaxations [MASTER] # Python will merge with parent; no need to repeat plugins. [MESSAGES CONTROL] # Additional disables for test code only disable= # --- Test-specific relaxations --- duplicate-code, # Test fixtures often have similar patterns too-many-lines, # Large test files are common missing-module-docstring, # Tests don't need module docs missing-class-docstring, # Test classes are self-explanatory missing-function-docstring, # Test method names describe intent line-too-long, # Golden strings and test data invalid-name, # setUp, tearDown, maxDiff, etc. 
protected-access, # Tests often access private members use-dict-literal, # Parametrized tests benefit from dict() bad-indentation, # pyink 2-space style conflicts with pylint unused-argument, # Mock callbacks often have unused args import-error, # Test dependencies may not be installed too-many-positional-arguments # Test methods can have many args [DESIGN] # Relax complexity limits for tests max-args = 10 # Fixtures often take many params max-locals = 25 # Complex test setups max-statements = 75 # Detailed test scenarios max-branches = 15 # Multiple test conditions [BASIC] # Allow common test naming patterns good-names=i,j,k,ex,Run,_,id,ok,fd,fp,maxDiff,setUp,tearDown # Include test-specific naming patterns method-rgx=[a-z_][a-z0-9_]{2,50}$|test[A-Z_][a-zA-Z0-9]*$|assert[A-Z][a-zA-Z0-9]*$ ================================================ FILE: tests/annotation_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from collections.abc import Sequence import dataclasses import inspect import textwrap from typing import Type from unittest import mock from absl.testing import absltest from absl.testing import parameterized from langextract import annotation from langextract import prompting from langextract import resolver as resolver_lib from langextract.core import data from langextract.core import exceptions from langextract.core import tokenizer from langextract.core import types from langextract.providers import gemini class AnnotatorTest(absltest.TestCase): def setUp(self): super().setUp() self.mock_language_model = self.enter_context( mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, prompt_template=prompting.PromptTemplateStructured(description=""), ) def assert_char_interval_match_source( self, source_text: str, extractions: Sequence[data.Extraction] ): """Case-insensitive assertion that char_interval matches source text. For each extraction, this function extracts the substring from the source text using the extraction's char_interval and asserts that it matches the extraction's text. Note the Alignment process between tokens is also case-insensitive. Args: source_text: The original source text. extractions: A sequence of extractions to check. 
""" for extraction in extractions: if extraction.alignment_status == data.AlignmentStatus.MATCH_EXACT: assert ( extraction.char_interval is not None ), "char_interval should not be None for AlignmentStatus.MATCH_EXACT" char_int = extraction.char_interval start = char_int.start_pos end = char_int.end_pos self.assertIsNotNone(start, "start_pos should not be None") self.assertIsNotNone(end, "end_pos should not be None") extracted = source_text[start:end] self.assertEqual( extracted.lower(), extraction.extraction_text.lower(), f"Extraction '{extraction.extraction_text}' does not match" f" extracted '{extracted}' using char_interval {char_int}", ) def test_annotate_text_single_chunk(self): text = ( "Patient Jane Doe, ID 67890, received 10mg of Lisinopril daily for" " hypertension diagnosed on 2023-03-15." ) self.mock_language_model.infer.return_value = [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - patient: "Jane Doe" patient_index: 1 patient_id: "67890" patient_id_index: 4 dosage: "10mg" dosage_index: 6 medication: "Lisinopril" medication_index: 8 frequency: "daily" frequency_index: 9 condition: "hypertension" condition_index: 11 diagnosis_date: "2023-03-15" diagnosis_date_index: 13 ```"""), ) ]] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) expected_annotated_text = data.AnnotatedDocument( text=text, extractions=[ data.Extraction( extraction_class="patient", extraction_index=1, extraction_text="Jane Doe", group_index=0, token_interval=tokenizer.TokenInterval( start_index=1, end_index=3 ), char_interval=data.CharInterval(start_pos=8, end_pos=16), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="patient_id", extraction_index=4, extraction_text="67890", group_index=0, token_interval=tokenizer.TokenInterval( start_index=5, end_index=6 ), char_interval=data.CharInterval(start_pos=21, end_pos=26), 
alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="dosage", extraction_index=6, extraction_text="10mg", group_index=0, token_interval=tokenizer.TokenInterval( start_index=8, end_index=10 ), char_interval=data.CharInterval(start_pos=37, end_pos=41), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="medication", extraction_index=8, extraction_text="Lisinopril", group_index=0, token_interval=tokenizer.TokenInterval( start_index=11, end_index=12 ), char_interval=data.CharInterval(start_pos=45, end_pos=55), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="frequency", extraction_index=9, extraction_text="daily", group_index=0, token_interval=tokenizer.TokenInterval( start_index=12, end_index=13 ), char_interval=data.CharInterval(start_pos=56, end_pos=61), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="condition", extraction_index=11, extraction_text="hypertension", group_index=0, token_interval=tokenizer.TokenInterval( start_index=14, end_index=15 ), char_interval=data.CharInterval(start_pos=66, end_pos=78), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="diagnosis_date", extraction_index=13, extraction_text="2023-03-15", group_index=0, token_interval=tokenizer.TokenInterval( start_index=17, end_index=22 ), char_interval=data.CharInterval(start_pos=92, end_pos=102), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), ], ) actual_annotated_text = self.annotator.annotate_text( text, resolver=resolver ) self.assertDataclassEqual(expected_annotated_text, actual_annotated_text) self.assert_char_interval_match_source( text, actual_annotated_text.extractions ) self.mock_language_model.infer.assert_called_once_with( batch_prompts=[f"\n\nQ: {text}\nA: "], ) def test_annotate_text_without_index_suffix(self): text = ( "Patient Jane Doe, ID 67890, received 10mg of Lisinopril daily 
for" " hypertension diagnosed on 2023-03-15." ) self.mock_language_model.infer.return_value = [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - patient: "Jane Doe" patient_id: "67890" dosage: "10mg" medication: "Lisinopril" frequency: "daily" condition: "hypertension" diagnosis_date: "2023-03-15" ```"""), ) ]] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=None, ) expected_annotated_text = data.AnnotatedDocument( text=text, extractions=[ data.Extraction( extraction_class="patient", extraction_index=1, extraction_text="Jane Doe", group_index=0, token_interval=tokenizer.TokenInterval( start_index=1, end_index=3 ), char_interval=data.CharInterval(start_pos=8, end_pos=16), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="patient_id", extraction_index=2, extraction_text="67890", group_index=0, token_interval=tokenizer.TokenInterval( start_index=5, end_index=6 ), char_interval=data.CharInterval(start_pos=21, end_pos=26), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="dosage", extraction_index=3, extraction_text="10mg", group_index=0, token_interval=tokenizer.TokenInterval( start_index=8, end_index=10 ), char_interval=data.CharInterval(start_pos=37, end_pos=41), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="medication", extraction_index=4, extraction_text="Lisinopril", group_index=0, token_interval=tokenizer.TokenInterval( start_index=11, end_index=12 ), char_interval=data.CharInterval(start_pos=45, end_pos=55), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="frequency", extraction_index=5, extraction_text="daily", group_index=0, token_interval=tokenizer.TokenInterval( start_index=12, end_index=13 ), char_interval=data.CharInterval(start_pos=56, end_pos=61), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), 
data.Extraction( extraction_class="condition", extraction_index=6, extraction_text="hypertension", group_index=0, token_interval=tokenizer.TokenInterval( start_index=14, end_index=15 ), char_interval=data.CharInterval(start_pos=66, end_pos=78), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="diagnosis_date", extraction_index=7, extraction_text="2023-03-15", group_index=0, token_interval=tokenizer.TokenInterval( start_index=17, end_index=22 ), char_interval=data.CharInterval(start_pos=92, end_pos=102), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), ], ) actual_annotated_text = self.annotator.annotate_text( text, resolver=resolver ) self.assertDataclassEqual(expected_annotated_text, actual_annotated_text) self.assert_char_interval_match_source( text, actual_annotated_text.extractions ) self.mock_language_model.infer.assert_called_once_with( batch_prompts=[f"\n\nQ: {text}\nA: "], ) def test_annotate_text_with_attributes_suffix(self): text = ( "Patient Jane Doe, ID 67890, received 10mg of Lisinopril daily for" " hypertension diagnosed on 2023-03-15." 
) self.mock_language_model.infer.return_value = [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - patient: "Jane Doe" patient_attributes: status: "IDENTIFIABLE" patient_id: "67890" patient_id_attributes: type: "UNIQUE_IDENTIFIER" dosage: "10mg" dosage_attributes: frequency: "DAILY" medication: "Lisinopril" medication_attributes: class: "ANTIHYPERTENSIVE" frequency: "daily" frequency_attributes: time: "DAILY" condition: "hypertension" condition_attributes: type: "CHRONIC" diagnosis_date: "2023-03-15" diagnosis_date_attributes: status: "RELEVANT" ```"""), ) ]] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=None, extraction_attributes_suffix=data.ATTRIBUTE_SUFFIX, ) expected_annotated_text = data.AnnotatedDocument( text=text, extractions=[ data.Extraction( extraction_class="patient", extraction_index=1, extraction_text="Jane Doe", group_index=0, token_interval=tokenizer.TokenInterval( start_index=1, end_index=3 ), char_interval=data.CharInterval(start_pos=8, end_pos=16), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={ "status": "IDENTIFIABLE", }, ), data.Extraction( extraction_class="patient_id", extraction_index=2, extraction_text="67890", group_index=0, token_interval=tokenizer.TokenInterval( start_index=5, end_index=6 ), char_interval=data.CharInterval(start_pos=21, end_pos=26), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={"type": "UNIQUE_IDENTIFIER"}, ), data.Extraction( extraction_class="dosage", extraction_index=3, extraction_text="10mg", group_index=0, token_interval=tokenizer.TokenInterval( start_index=8, end_index=10 ), char_interval=data.CharInterval(start_pos=37, end_pos=41), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={"frequency": "DAILY"}, ), data.Extraction( extraction_class="medication", extraction_index=4, extraction_text="Lisinopril", group_index=0, token_interval=tokenizer.TokenInterval( start_index=11, 
end_index=12 ), char_interval=data.CharInterval(start_pos=45, end_pos=55), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={"class": "ANTIHYPERTENSIVE"}, ), data.Extraction( extraction_class="frequency", extraction_index=5, extraction_text="daily", group_index=0, token_interval=tokenizer.TokenInterval( start_index=12, end_index=13 ), char_interval=data.CharInterval(start_pos=56, end_pos=61), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={"time": "DAILY"}, ), data.Extraction( extraction_class="condition", extraction_index=6, extraction_text="hypertension", group_index=0, token_interval=tokenizer.TokenInterval( start_index=14, end_index=15 ), char_interval=data.CharInterval(start_pos=66, end_pos=78), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={"type": "CHRONIC"}, ), data.Extraction( extraction_class="diagnosis_date", extraction_index=7, extraction_text="2023-03-15", group_index=0, token_interval=tokenizer.TokenInterval( start_index=17, end_index=22 ), char_interval=data.CharInterval(start_pos=92, end_pos=102), alignment_status=data.AlignmentStatus.MATCH_EXACT, attributes={"status": "RELEVANT"}, ), ], ) actual_annotated_text = self.annotator.annotate_text( text, resolver=resolver, ) self.assertDataclassEqual(expected_annotated_text, actual_annotated_text) self.assert_char_interval_match_source( text, actual_annotated_text.extractions ) self.mock_language_model.infer.assert_called_once_with( batch_prompts=[f"\n\nQ: {text}\nA: "], ) def test_annotate_text_multiple_chunks(self): self.mock_language_model.infer.side_effect = [ [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - medication: "Aspirin" medication_index: 4 reason: "headache" reason_index: 8 ```"""), ) ]], [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - condition: "fever" condition_index: 2 ```"""), ) ]], ] # Simulating tokenization for text broken into two chunks: # 
Chunk 1: 'Patient takes one Aspirin for headaches.' # Chunk 2: 'Pt has fever.' text = "Patient takes one Aspirin for headaches. Pt has fever." # Indexes Aligned with Tokens # ------------------------------------------------------------------------- # Index | 0 1 2 3 4 5 6 7 8 9 10 # Token | Patient takes one Aspirin for headaches . Pt has fever . resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) expected_annotated_text = data.AnnotatedDocument( text=text, extractions=[ data.Extraction( extraction_class="medication", extraction_index=4, extraction_text="Aspirin", group_index=0, token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=18, end_pos=25), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="reason", extraction_index=8, extraction_text="headache", group_index=0, ), data.Extraction( extraction_class="condition", extraction_index=2, extraction_text="fever", group_index=0, token_interval=tokenizer.TokenInterval( start_index=9, end_index=10 ), char_interval=data.CharInterval(start_pos=48, end_pos=53), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), ], ) actual_annotated_text = self.annotator.annotate_text( text, max_char_buffer=40, batch_length=1, resolver=resolver, enable_fuzzy_alignment=False, ) self.assertDataclassEqual(expected_annotated_text, actual_annotated_text) self.assert_char_interval_match_source( text, actual_annotated_text.extractions ) self.mock_language_model.infer.assert_has_calls([ mock.call( batch_prompts=[ "\n\nQ: Patient takes one Aspirin for headaches.\nA: " ], enable_fuzzy_alignment=False, ), mock.call( batch_prompts=["\n\nQ: Pt has fever.\nA: "], enable_fuzzy_alignment=False, ), ]) def test_annotate_text_no_extractions(self): text = "Text without extractions." 
self.mock_language_model.infer.return_value = [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: [] ```"""), ) ]] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) expected_annotated_text = data.AnnotatedDocument(text=text, extractions=[]) actual_annotated_text = self.annotator.annotate_text( text, resolver=resolver ) self.assertDataclassEqual(expected_annotated_text, actual_annotated_text) self.mock_language_model.infer.assert_called_once_with( batch_prompts=[f"\n\nQ: {text}\nA: "], ) class AnnotatorMultipleDocumentTest(parameterized.TestCase): _FIXED_DOCUMENT_CONTENT = "Patient reports migraine." _LLM_INFERENCE = textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - PATIENT: "Patient" PATIENT_index: 0 - SYMPTOM: "migraine" SYMPTOM_index: 2 ```""") _ANNOTATED_DOCUMENT = data.AnnotatedDocument( document_id="", extractions=[ data.Extraction( extraction_class="PATIENT", extraction_text="Patient", token_interval=tokenizer.TokenInterval( start_index=0, end_index=1 ), char_interval=data.CharInterval(start_pos=0, end_pos=7), alignment_status=data.AlignmentStatus.MATCH_EXACT, extraction_index=0, group_index=0, ), data.Extraction( extraction_class="SYMPTOM", extraction_text="migraine", token_interval=tokenizer.TokenInterval( start_index=2, end_index=3 ), char_interval=data.CharInterval(start_pos=16, end_pos=24), alignment_status=data.AlignmentStatus.MATCH_EXACT, extraction_index=2, group_index=1, ), ], text="Patient reports migraine.", ) @parameterized.named_parameters( dict( testcase_name="single_document", documents=[ {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, ], expected_result=[ dataclasses.replace( _ANNOTATED_DOCUMENT, document_id="doc1", ), ], ), dict( testcase_name="multiple_documents", documents=[ {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc2"}, ], 
expected_result=[ dataclasses.replace( _ANNOTATED_DOCUMENT, document_id="doc1", ), dataclasses.replace( _ANNOTATED_DOCUMENT, document_id="doc2", ), ], ), dict( testcase_name="zero_documents", documents=[], expected_result=[], ), dict( testcase_name="multiple_documents_same_batch", documents=[ {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc2"}, ], expected_result=[ dataclasses.replace( _ANNOTATED_DOCUMENT, document_id="doc1", ), dataclasses.replace( _ANNOTATED_DOCUMENT, document_id="doc2", ), ], batch_length=10, ), ) def test_annotate_documents( self, documents: Sequence[dict[str, str]], expected_result: Sequence[data.AnnotatedDocument], batch_length: int = 1, ): mock_language_model = self.enter_context( mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) ) # Define a side effect function so return length based on batch length. def mock_infer_side_effect(batch_prompts, **kwargs): for _ in batch_prompts: yield [ types.ScoredOutput( score=1.0, output=self._LLM_INFERENCE, ) ] mock_language_model.infer.side_effect = mock_infer_side_effect annotator = annotation.Annotator( language_model=mock_language_model, prompt_template=prompting.PromptTemplateStructured(description=""), ) document_objects = [ data.Document( text=doc["text"], document_id=doc["document_id"], ) for doc in documents ] actual_annotations = list( annotator.annotate_documents( document_objects, resolver=resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), max_char_buffer=200, batch_length=batch_length, debug=False, ) ) self.assertLen(actual_annotations, len(expected_result)) for actual_annotation, expected_annotation in zip( actual_annotations, expected_result ): self.assertDataclassEqual(expected_annotation, actual_annotation) self.assertGreaterEqual(mock_language_model.infer.call_count, 0) @parameterized.named_parameters( dict( 
testcase_name="same_document_id_contiguous", documents=[ {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, ], expected_exception=exceptions.InvalidDocumentError, ), dict( testcase_name="same_document_id_separated", documents=[ {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc2"}, {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"}, ], expected_exception=exceptions.InvalidDocumentError, ), ) def test_annotate_documents_exceptions( self, documents: Sequence[dict[str, str]], expected_exception: Type[exceptions.InvalidDocumentError], batch_length: int = 1, ): mock_language_model = self.enter_context( mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) ) mock_language_model.infer.return_value = [ [ types.ScoredOutput( score=1.0, output=self._LLM_INFERENCE, ) ] ] annotator = annotation.Annotator( language_model=mock_language_model, prompt_template=prompting.PromptTemplateStructured(description=""), ) document_objects = [ data.Document(text=doc["text"], document_id=doc["document_id"]) for doc in documents ] with self.assertRaises(expected_exception): list( annotator.annotate_documents( document_objects, max_char_buffer=200, batch_length=batch_length, debug=False, ) ) class AnnotatorMultiPassTest(absltest.TestCase): """Tests for multi-pass extraction functionality.""" def setUp(self): super().setUp() self.mock_language_model = self.enter_context( mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, prompt_template=prompting.PromptTemplateStructured(description=""), ) def test_multipass_extraction_non_overlapping(self): """Test multi-pass extraction with non-overlapping extractions.""" text = "Patient John Smith has diabetes and takes insulin daily." 
self.mock_language_model.infer.side_effect = [ [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - patient: "John Smith" patient_index: 1 - condition: "diabetes" condition_index: 4 ```"""), ) ]], [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - medication: "insulin" medication_index: 7 - frequency: "daily" frequency_index: 8 ```"""), ) ]], ] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) result = self.annotator.annotate_text( text, resolver=resolver, extraction_passes=2, debug=False ) self.assertLen(result.extractions, 4) extraction_classes = [e.extraction_class for e in result.extractions] self.assertCountEqual( extraction_classes, ["patient", "condition", "medication", "frequency"] ) self.assertEqual(self.mock_language_model.infer.call_count, 2) def test_multipass_extraction_overlapping(self): """Test multi-pass extraction with overlapping extractions (first pass wins).""" text = "Dr. Smith prescribed aspirin." # Mock overlapping extractions - both passes find "Smith" but differently self.mock_language_model.infer.side_effect = [ [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - doctor: "Dr. Smith" doctor_index: 0 ```"""), ) ]], [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - patient: "Smith" patient_index: 1 - medication: "aspirin" medication_index: 2 ```"""), ) ]], ] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) result = self.annotator.annotate_text( text, resolver=resolver, extraction_passes=2, debug=False ) self.assertLen(result.extractions, 2) extraction_classes = [e.extraction_class for e in result.extractions] self.assertCountEqual(extraction_classes, ["doctor", "medication"]) # Verify "Dr. 
Smith" from first pass is kept, not "Smith" from second pass doctor_extraction = next( e for e in result.extractions if e.extraction_class == "doctor" ) self.assertEqual(doctor_extraction.extraction_text, "Dr. Smith") def test_multipass_extraction_single_pass(self): """Test that extraction_passes=1 behaves like normal single-pass extraction.""" text = "Patient has fever." self.mock_language_model.infer.return_value = [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - patient: "Patient" patient_index: 0 - condition: "fever" condition_index: 2 ```"""), ) ]] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) result = self.annotator.annotate_text( text, resolver=resolver, extraction_passes=1, debug=False # Single pass ) self.assertLen(result.extractions, 2) self.assertEqual(self.mock_language_model.infer.call_count, 1) def test_multipass_extraction_empty_passes(self): """Test multi-pass extraction when some passes return no extractions.""" text = "Test text." 
self.mock_language_model.infer.side_effect = [ [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - test: "Test" test_index: 0 ```"""), ) ]], [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: [] ```"""), ) ]], ] resolver = resolver_lib.Resolver( format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) result = self.annotator.annotate_text( text, resolver=resolver, extraction_passes=2, debug=False ) self.assertLen(result.extractions, 1) self.assertEqual(result.extractions[0].extraction_class, "test") class MultiPassHelperFunctionsTest(parameterized.TestCase): """Tests for multi-pass helper functions.""" @parameterized.named_parameters( dict( testcase_name="empty_list", all_extractions=[], expected_count=0, expected_classes=[], ), dict( testcase_name="single_pass", all_extractions=[[ data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 5) ), data.Extraction( "class2", "text2", char_interval=data.CharInterval(10, 15) ), ]], expected_count=2, expected_classes=["class1", "class2"], ), dict( testcase_name="non_overlapping_passes", all_extractions=[ [ data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 5) ) ], [ data.Extraction( "class2", "text2", char_interval=data.CharInterval(10, 15) ) ], ], expected_count=2, expected_classes=["class1", "class2"], ), dict( testcase_name="overlapping_passes_first_wins", all_extractions=[ [ data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 10) ) ], [ data.Extraction( "class2", "text2", char_interval=data.CharInterval(5, 15) ), # Overlaps data.Extraction( "class3", "text3", char_interval=data.CharInterval(20, 25) ), # No overlap ], ], expected_count=2, expected_classes=[ "class1", "class3", ], # class2 excluded due to overlap ), ) def test_merge_non_overlapping_extractions( self, all_extractions, expected_count, expected_classes ): """Test merging 
extractions from multiple passes.""" result = annotation._merge_non_overlapping_extractions(all_extractions) self.assertLen(result, expected_count) if expected_classes: extraction_classes = [e.extraction_class for e in result] self.assertCountEqual(extraction_classes, expected_classes) @parameterized.named_parameters( dict( testcase_name="overlapping_intervals", ext1=data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 10) ), ext2=data.Extraction( "class2", "text2", char_interval=data.CharInterval(5, 15) ), expected=True, ), dict( testcase_name="non_overlapping_intervals", ext1=data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 5) ), ext2=data.Extraction( "class2", "text2", char_interval=data.CharInterval(10, 15) ), expected=False, ), dict( testcase_name="adjacent_intervals", ext1=data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 5) ), ext2=data.Extraction( "class2", "text2", char_interval=data.CharInterval(5, 10) ), expected=False, ), dict( testcase_name="none_interval_first", ext1=data.Extraction("class1", "text1", char_interval=None), ext2=data.Extraction( "class2", "text2", char_interval=data.CharInterval(5, 15) ), expected=False, ), dict( testcase_name="none_interval_second", ext1=data.Extraction( "class1", "text1", char_interval=data.CharInterval(0, 5) ), ext2=data.Extraction("class2", "text2", char_interval=None), expected=False, ), dict( testcase_name="both_none_intervals", ext1=data.Extraction("class1", "text1", char_interval=None), ext2=data.Extraction("class2", "text2", char_interval=None), expected=False, ), ) def test_extractions_overlap(self, ext1, ext2, expected): """Test overlap detection between extractions.""" result = annotation._extractions_overlap(ext1, ext2) self.assertEqual(result, expected) class AnnotateDocumentsGeneratorTest(absltest.TestCase): """Tests that annotate_documents uses 'yield from' for proper delegation.""" def setUp(self): super().setUp() self.mock_language_model = 
self.enter_context( mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) ) def mock_infer(batch_prompts, **_): """Return medication extractions based on prompt content.""" for prompt in batch_prompts: if "Ibuprofen" in prompt: text = textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - medication: "Ibuprofen" medication_index: 4 ```""") elif "Cefazolin" in prompt: text = textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - medication: "Cefazolin" medication_index: 4 ```""") else: text = f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```" yield [types.ScoredOutput(score=1.0, output=text)] self.mock_language_model.infer.side_effect = mock_infer self.annotator = annotation.Annotator( language_model=self.mock_language_model, prompt_template=prompting.PromptTemplateStructured(description=""), ) def test_yields_documents_not_generators(self): """Verifies annotate_documents yields AnnotatedDocument, not generators.""" docs = [ data.Document( text="Patient took 400 mg PO Ibuprofen q4h for two days.", document_id="doc1", ), data.Document( text="Patient was given 250 mg IV Cefazolin TID for one week.", document_id="doc2", ), ] results = list( self.annotator.annotate_documents( docs, resolver=resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), show_progress=False, debug=False, ) ) self.assertLen(results, 2) self.assertFalse( any(inspect.isgenerator(item) for item in results), msg="Must use 'yield from' to delegate, not 'yield'", ) meds_doc1 = { e.extraction_text for e in results[0].extractions if e.extraction_class == "medication" } meds_doc2 = { e.extraction_text for e in results[1].extractions if e.extraction_class == "medication" } self.assertIn("Ibuprofen", meds_doc1) self.assertNotIn("Cefazolin", meds_doc1) self.assertIn("Cefazolin", meds_doc2) self.assertNotIn("Ibuprofen", meds_doc2) class CrossChunkContextTest(absltest.TestCase): """Tests for cross-chunk context window 
feature.""" def setUp(self): super().setUp() self.mock_language_model = self.enter_context( mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) ) self.annotator = annotation.Annotator( language_model=self.mock_language_model, prompt_template=prompting.PromptTemplateStructured(description=""), ) def test_context_window_includes_previous_chunk_text(self): """Verifies that context_window_chars passes previous chunk text.""" # Chunk 1: "Dr. Sarah Johnson is a cardiologist." # Chunk 2: "She specializes in heart surgery." text = ( "Dr. Sarah Johnson is a cardiologist. She specializes in heart surgery." ) self.mock_language_model.infer.side_effect = [ [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - person: "Dr. Sarah Johnson" ```"""), ) ]], [[ types.ScoredOutput( score=1.0, output=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - specialization: "heart surgery" ```"""), ) ]], ] resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML) _ = self.annotator.annotate_text( text, max_char_buffer=40, batch_length=1, resolver=resolver, context_window_chars=30, enable_fuzzy_alignment=False, ) calls = self.mock_language_model.infer.call_args_list self.assertLen(calls, 2) first_prompt = calls[0].kwargs["batch_prompts"][0] context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX self.assertNotIn(context_prefix, first_prompt) second_prompt = calls[1].kwargs["batch_prompts"][0] self.assertIn(context_prefix, second_prompt) self.assertIn("cardiologist", second_prompt) def test_no_context_included_when_disabled(self): """Verifies that no context is included when context_window_chars=None.""" text = ( "Dr. Sarah Johnson is a cardiologist. She specializes in heart surgery." 
) self.mock_language_model.infer.side_effect = [ [[ types.ScoredOutput( score=1.0, output=f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```" ) ]], [[ types.ScoredOutput( score=1.0, output=f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```" ) ]], ] resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML) _ = self.annotator.annotate_text( text, max_char_buffer=40, batch_length=1, resolver=resolver, context_window_chars=None, # Disabled enable_fuzzy_alignment=False, ) calls = self.mock_language_model.infer.call_args_list self.assertLen(calls, 2) context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX first_prompt = calls[0].kwargs["batch_prompts"][0] second_prompt = calls[1].kwargs["batch_prompts"][0] self.assertNotIn(context_prefix, first_prompt) self.assertNotIn(context_prefix, second_prompt) def test_context_window_per_document_isolation(self): """Verifies context is tracked per document, not across documents.""" docs = [ data.Document(text="Doc1 chunk1. Doc1 chunk2.", document_id="doc1"), data.Document(text="Doc2 chunk1. Doc2 chunk2.", document_id="doc2"), ] empty_response = [[ types.ScoredOutput( score=1.0, output=f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```" ) ]] self.mock_language_model.infer.side_effect = [ empty_response, # Doc1 chunk1 empty_response, # Doc1 chunk2 empty_response, # Doc2 chunk1 empty_response, # Doc2 chunk2 ] resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML) _ = list( self.annotator.annotate_documents( docs, resolver=resolver, max_char_buffer=15, batch_length=1, context_window_chars=20, # Large enough to capture "Doc1 chunk1." 
show_progress=False, ) ) calls = self.mock_language_model.infer.call_args_list self.assertLen(calls, 4) context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX # Extract prompts in order: doc1_chunk1, doc1_chunk2, doc2_chunk1, doc2_chunk2 doc1_chunk1_prompt = calls[0].kwargs["batch_prompts"][0] doc1_chunk2_prompt = calls[1].kwargs["batch_prompts"][0] doc2_chunk1_prompt = calls[2].kwargs["batch_prompts"][0] doc2_chunk2_prompt = calls[3].kwargs["batch_prompts"][0] # First chunks of each document should NOT have context prefix self.assertNotIn(context_prefix, doc1_chunk1_prompt) self.assertNotIn(context_prefix, doc2_chunk1_prompt) # Second chunks should have context from their own document only self.assertIn(context_prefix, doc1_chunk2_prompt) self.assertIn("Doc1", doc1_chunk2_prompt) self.assertIn(context_prefix, doc2_chunk2_prompt) self.assertIn("Doc2", doc2_chunk2_prompt) # Doc2's chunks should never contain Doc1 content self.assertNotIn("Doc1", doc2_chunk1_prompt) self.assertNotIn("Doc1", doc2_chunk2_prompt) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/chunking_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import textwrap from unittest import mock from absl.testing import absltest from absl.testing import parameterized from langextract import chunking from langextract.core import data from langextract.core import tokenizer class SentenceIterTest(absltest.TestCase): def test_basic(self): text = "This is a sentence. This is a longer sentence. Mr. Bond\nasks\nwhy?" tokenized_text = tokenizer.tokenize(text) sentence_iter = chunking.SentenceIterator(tokenized_text) sentence_interval = next(sentence_iter) self.assertEqual( tokenizer.TokenInterval(start_index=0, end_index=5), sentence_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, sentence_interval), "This is a sentence.", ) sentence_interval = next(sentence_iter) self.assertEqual( tokenizer.TokenInterval(start_index=5, end_index=11), sentence_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, sentence_interval), "This is a longer sentence.", ) sentence_interval = next(sentence_iter) self.assertEqual( tokenizer.TokenInterval(start_index=11, end_index=17), sentence_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, sentence_interval), "Mr. Bond\nasks\nwhy?", ) with self.assertRaises(StopIteration): next(sentence_iter) def test_empty(self): text = "" tokenized_text = tokenizer.tokenize(text) sentence_iter = chunking.SentenceIterator(tokenized_text) with self.assertRaises(StopIteration): next(sentence_iter) class ChunkIteratorTest(absltest.TestCase): def test_multi_sentence_chunk(self): text = "This is a sentence. This is a longer sentence. Mr. Bond\nasks\nwhy?" 
tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=50, tokenizer_impl=tokenizer.RegexTokenizer(), ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=0, end_index=11), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "This is a sentence. This is a longer sentence.", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=11, end_index=17), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "Mr. Bond\nasks\nwhy?", ) with self.assertRaises(StopIteration): next(chunk_iter) def test_sentence_with_multiple_newlines_and_right_interval(self): text = ( "This is a sentence\n\n" + "This is a longer sentence\n\n" + "Mr\n\nBond\n\nasks why?" ) tokenized_text = tokenizer.tokenize(text) chunk_interval = tokenizer.TokenInterval( start_index=0, end_index=len(tokenized_text.tokens) ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), text, ) def test_break_sentence(self): text = "This is a sentence. This is a longer sentence. Mr. Bond\nasks\nwhy?" 
tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=12, tokenizer_impl=tokenizer.RegexTokenizer(), ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=0, end_index=3), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "This is a", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=3, end_index=5), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "sentence.", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=5, end_index=8), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "This is a", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=8, end_index=9), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "longer", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=9, end_index=11), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "sentence.", ) for _ in range(2): next(chunk_iter) with self.assertRaises(StopIteration): next(chunk_iter) def test_long_token_gets_own_chunk(self): text = "This is a sentence. This is a longer sentence. Mr. Bond\nasks\nwhy?" 
tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=7, tokenizer_impl=tokenizer.RegexTokenizer(), ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=0, end_index=2), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "This is", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=2, end_index=3), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "a" ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=3, end_index=4), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "sentence", ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval(start_index=4, end_index=5), chunk_interval ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), "." ) for _ in range(9): next(chunk_iter) with self.assertRaises(StopIteration): next(chunk_iter) def test_newline_at_chunk_boundary_does_not_create_empty_interval(self): """Test that newlines at chunk boundaries don't create empty token intervals. When a newline occurs exactly at a chunk boundary, the chunking algorithm should not attempt to create an empty interval (where start_index == end_index). This was causing a ValueError in create_token_interval(). """ text = "First sentence.\nSecond sentence that is longer.\nThird sentence." 
tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=20, tokenizer_impl=tokenizer.RegexTokenizer(), ) chunks = list(chunk_iter) for chunk in chunks: self.assertLess( chunk.token_interval.start_index, chunk.token_interval.end_index, "Chunk should have non-empty interval", ) expected_intervals = [(0, 3), (3, 6), (6, 9), (9, 12)] actual_intervals = [ (chunk.token_interval.start_index, chunk.token_interval.end_index) for chunk in chunks ] self.assertEqual(actual_intervals, expected_intervals) def test_chunk_unicode_text(self): text = textwrap.dedent("""\ Chief Complaint: ‘swelling of tongue and difficulty breathing and swallowing’ History of Present Illness: 77 y o woman in NAD with a h/o CAD, DM2, asthma and HTN on altace.""") tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=200, tokenizer_impl=tokenizer.RegexTokenizer(), ) chunk_interval = next(chunk_iter).token_interval self.assertEqual( tokenizer.TokenInterval( start_index=0, end_index=len(tokenized_text.tokens) ), chunk_interval, ) self.assertEqual( chunking.get_token_interval_text(tokenized_text, chunk_interval), text ) def test_newlines_is_secondary_sentence_break(self): text = textwrap.dedent("""\ Medications: Theophyline (Uniphyl) 600 mg qhs – bronchodilator by increasing cAMP used for treating asthma Diltiazem 300 mg qhs – Ca channel blocker used to control hypertension Simvistatin (Zocor) 20 mg qhs- HMGCo Reductase inhibitor for hypercholesterolemia Ramipril (Altace) 10 mg BID – ACEI for hypertension and diabetes for renal protective effect""") tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=200, tokenizer_impl=tokenizer.RegexTokenizer(), ) first_chunk = next(chunk_iter) expected_first_chunk_text = textwrap.dedent("""\ Medications: Theophyline (Uniphyl) 600 mg qhs – bronchodilator by increasing cAMP used for treating asthma 
Diltiazem 300 mg qhs – Ca channel blocker used to control hypertension""") self.assertEqual( chunking.get_token_interval_text( tokenized_text, first_chunk.token_interval ), expected_first_chunk_text, ) self.assertGreater( first_chunk.token_interval.end_index, first_chunk.token_interval.start_index, ) second_chunk = next(chunk_iter) expected_second_chunk_text = textwrap.dedent("""\ Simvistatin (Zocor) 20 mg qhs- HMGCo Reductase inhibitor for hypercholesterolemia Ramipril (Altace) 10 mg BID – ACEI for hypertension and diabetes for renal protective effect""") self.assertEqual( chunking.get_token_interval_text( tokenized_text, second_chunk.token_interval ), expected_second_chunk_text, ) with self.assertRaises(StopIteration): next(chunk_iter) def test_tokenizer_propagation(self): """Test that tokenizer is correctly propagated to TextChunk's Document.""" text = "Some text." mock_tokenizer = mock.Mock(spec=tokenizer.Tokenizer) mock_tokens = [ tokenizer.Token( index=0, token_type=tokenizer.TokenType.WORD, char_interval=data.CharInterval(start_pos=0, end_pos=4), ), tokenizer.Token( index=1, token_type=tokenizer.TokenType.WORD, char_interval=data.CharInterval(start_pos=5, end_pos=9), ), tokenizer.Token( index=2, token_type=tokenizer.TokenType.PUNCTUATION, char_interval=data.CharInterval(start_pos=9, end_pos=10), ), ] mock_tokenized_text = tokenizer.TokenizedText(text=text, tokens=mock_tokens) mock_tokenizer.tokenize.return_value = mock_tokenized_text chunk_iter = chunking.ChunkIterator( text=text, max_char_buffer=100, tokenizer_impl=mock_tokenizer ) text_chunk = next(chunk_iter) self.assertEqual(text_chunk.document_text, mock_tokenized_text) self.assertEqual(text_chunk.chunk_text, text) class BatchingTest(parameterized.TestCase): _SAMPLE_DOCUMENT = data.Document( text=( "Sample text with numerical values such as 120/80 mmHg, 98.6°F, and" " 50mg." 
), ) @parameterized.named_parameters( ( "test_with_data", _SAMPLE_DOCUMENT.tokenized_text, 15, 10, [[ chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=0, end_index=1 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=1, end_index=3 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=4, end_index=5 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=5, end_index=7 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=7, end_index=10 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=10, end_index=14 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=14, end_index=19 ), document=_SAMPLE_DOCUMENT, ), chunking.TextChunk( token_interval=tokenizer.TokenInterval( start_index=19, end_index=22 ), document=_SAMPLE_DOCUMENT, ), ]], ), ( "test_empty_input", "", 15, 10, [], ), ) def test_make_batches_of_textchunk( self, tokenized_text: tokenizer.TokenizedText, batch_length: int, max_char_buffer: int, expected_batches: list[list[chunking.TextChunk]], ): chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer, tokenizer_impl=tokenizer.RegexTokenizer(), ) batches_iter = chunking.make_batches_of_textchunk(chunk_iter, batch_length) actual_batches = [list(batch) for batch in batches_iter] self.assertListEqual( actual_batches, expected_batches, "Batched chunks should match expected structure", ) class TextChunkTest(absltest.TestCase): def test_string_output(self): text = "Example input text." 
expected = textwrap.dedent("""\ TextChunk( interval=[start_index: 0, end_index: 1], Document ID: test_doc_123, Chunk Text: 'Example' )""") document = data.Document(text=text, document_id="test_doc_123") tokenized_text = tokenizer.tokenize(text) chunk_iter = chunking.ChunkIterator( tokenized_text, max_char_buffer=7, document=document, tokenizer_impl=tokenizer.RegexTokenizer(), ) text_chunk = next(chunk_iter) self.assertEqual(str(text_chunk), expected) class TextAdditionalContextTest(absltest.TestCase): _ADDITIONAL_CONTEXT = "Some additional context for prompt..." def test_text_chunk_additional_context(self): document = data.Document( text="Sample text.", additional_context=self._ADDITIONAL_CONTEXT ) chunk_iter = chunking.ChunkIterator( text=document.tokenized_text, max_char_buffer=100, document=document, tokenizer_impl=tokenizer.RegexTokenizer(), ) text_chunk = next(chunk_iter) self.assertEqual(text_chunk.additional_context, self._ADDITIONAL_CONTEXT) def test_chunk_iterator_without_additional_context(self): document = data.Document(text="Sample text.") chunk_iter = chunking.ChunkIterator( text=document.tokenized_text, max_char_buffer=100, document=document, tokenizer_impl=tokenizer.RegexTokenizer(), ) text_chunk = next(chunk_iter) self.assertIsNone(text_chunk.additional_context) def test_multiple_chunks_with_additional_context(self): text = "Sentence one. Sentence two. Sentence three." 
document = data.Document( text=text, additional_context=self._ADDITIONAL_CONTEXT ) chunk_iter = chunking.ChunkIterator( text=document.tokenized_text, max_char_buffer=15, document=document, tokenizer_impl=tokenizer.RegexTokenizer(), ) chunks = list(chunk_iter) self.assertGreater( len(chunks), 1, "Should create multiple chunks with small buffer" ) additional_contexts = [chunk.additional_context for chunk in chunks] expected_additional_contexts = [self._ADDITIONAL_CONTEXT] * len(chunks) self.assertListEqual(additional_contexts, expected_additional_contexts) class TextChunkPropertyTest(parameterized.TestCase): @parameterized.named_parameters( { "testcase_name": "with_document", "document": data.Document( text="Sample text.", document_id="doc123", additional_context="Additional info", ), "expected_id": "doc123", "expected_text": "Sample text.", "expected_context": "Additional info", }, { "testcase_name": "no_document", "document": None, "expected_id": None, "expected_text": None, "expected_context": None, }, { "testcase_name": "no_additional_context", "document": data.Document( text="Sample text.", document_id="doc123", ), "expected_id": "doc123", "expected_text": "Sample text.", "expected_context": None, }, ) def test_text_chunk_properties( self, document, expected_id, expected_text, expected_context ): chunk = chunking.TextChunk( token_interval=tokenizer.TokenInterval(start_index=0, end_index=1), document=document, ) self.assertEqual(chunk.document_id, expected_id) if chunk.document_text: self.assertEqual(chunk.document_text.text, expected_text) else: self.assertIsNone(chunk.document_text) self.assertEqual(chunk.additional_context, expected_context) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/data_lib_test.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json from absl.testing import absltest from absl.testing import parameterized import numpy as np from langextract import data_lib from langextract import io from langextract.core import data from langextract.core import tokenizer class DataLibToDictParameterizedTest(parameterized.TestCase): """Tests conversion of AnnotatedDocument objects to JSON dicts. Verifies that `annotated_document_to_dict` correctly serializes documents by: - Excluding private fields (e.g., token_interval). - Converting all expected extraction attributes properly. - Handling int64 values for extraction indexes. 
""" @parameterized.named_parameters( dict( testcase_name="single_extraction_no_token_interval", annotated_doc=data.AnnotatedDocument( document_id="docA", text="Just a short sentence.", extractions=[ data.Extraction( extraction_class="note", extraction_text="short sentence", extraction_index=1, group_index=0, ), ], ), expected_dict={ "document_id": "docA", "extractions": [ { "extraction_class": "note", "extraction_text": "short sentence", "char_interval": None, "alignment_status": None, "extraction_index": 1, "group_index": 0, "description": None, "attributes": None, }, ], "text": "Just a short sentence.", }, ), dict( testcase_name="multiple_extractions_with_token_interval", annotated_doc=data.AnnotatedDocument( document_id="docB", text="Patient Jane reported a headache.", extractions=[ data.Extraction( extraction_class="patient", extraction_text="Jane", extraction_index=1, group_index=0, ), data.Extraction( extraction_class="symptom", extraction_text="headache", extraction_index=2, group_index=0, char_interval=data.CharInterval(start_pos=24, end_pos=32), token_interval=tokenizer.TokenInterval( start_index=4, end_index=5 ), # should be ignored alignment_status=data.AlignmentStatus.MATCH_EXACT, ), ], ), expected_dict={ "document_id": "docB", "extractions": [ { "extraction_class": "patient", "extraction_text": "Jane", "char_interval": None, "alignment_status": None, "extraction_index": 1, "group_index": 0, "description": None, "attributes": None, }, { "extraction_class": "symptom", "extraction_text": "headache", "char_interval": {"start_pos": 24, "end_pos": 32}, "alignment_status": "match_exact", "extraction_index": 2, "group_index": 0, "description": None, "attributes": None, }, ], "text": "Patient Jane reported a headache.", }, ), dict( testcase_name="extraction_with_attributes_and_token_interval", annotated_doc=data.AnnotatedDocument( document_id="docC", text="He has mild chest pain and a cough.", extractions=[ data.Extraction( extraction_class="condition", 
extraction_text="chest pain", extraction_index=2, group_index=1, attributes={ "severity": "mild", "persistence": "persistent", }, char_interval=data.CharInterval(start_pos=12, end_pos=22), token_interval=tokenizer.TokenInterval( start_index=3, end_index=5 ), # should be ignored alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="symptom", extraction_text="cough", extraction_index=3, group_index=1, ), ], ), expected_dict={ "document_id": "docC", "extractions": [ { "extraction_class": "condition", "extraction_text": "chest pain", "char_interval": {"start_pos": 12, "end_pos": 22}, "alignment_status": "match_exact", "extraction_index": 2, "group_index": 1, "description": None, "attributes": { "severity": "mild", "persistence": "persistent", }, }, { "extraction_class": "symptom", "extraction_text": "cough", "char_interval": None, "alignment_status": None, "extraction_index": 3, "group_index": 1, "description": None, "attributes": None, }, ], "text": "He has mild chest pain and a cough.", }, ), ) def test_annotated_document_to_dict(self, annotated_doc, expected_dict): actual_dict = data_lib.annotated_document_to_dict(annotated_doc) self.assertDictEqual( actual_dict, expected_dict, "annotated_document_to_dict() output differs from expected JSON dict.", ) def test_annotated_document_to_dict_with_int64(self): doc = data.AnnotatedDocument( document_id="doc_int64", text="Sample text with int64 index", extractions=[ data.Extraction( extraction_class="demo_extraction", extraction_text="placeholder", extraction_index=np.int64(42), # pytype: disable=wrong-arg-types ), ], ) doc_dict = data_lib.annotated_document_to_dict(doc) json_str = json.dumps(doc_dict, ensure_ascii=False) self.assertIn('"extraction_index": 42', json_str) class IsUrlTest(absltest.TestCase): """Tests for io.is_url function validation.""" def test_valid_urls(self): """Test that valid URLs are recognized.""" self.assertTrue(io.is_url("http://example.com")) 
self.assertTrue(io.is_url("https://www.example.com")) self.assertTrue(io.is_url("http://localhost:8080")) self.assertTrue(io.is_url("http://192.168.1.1")) self.assertTrue(io.is_url("http://[2001:db8::1]")) # IPv6 self.assertTrue(io.is_url("http://[::1]:8080")) # IPv6 localhost with port def test_invalid_urls_with_text(self): """Test that URLs with additional text are rejected.""" # Validates fix for issue where text starting with URL was incorrectly fetched self.assertFalse(io.is_url("http://example.com is a website")) self.assertFalse(io.is_url("http://medical-journal.com published a study")) def test_invalid_urls_no_scheme(self): """Test that URLs without proper scheme are rejected.""" self.assertFalse(io.is_url("example.com")) self.assertFalse(io.is_url("www.example.com")) self.assertFalse(io.is_url("ftp://example.com")) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/extract_precedence_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for parameter precedence in extract().""" from unittest import mock from absl.testing import absltest from langextract import factory import langextract as lx from langextract.core import data from langextract.providers import openai class ExtractParameterPrecedenceTest(absltest.TestCase): """Tests ensuring correct precedence among extract() parameters.""" def setUp(self): super().setUp() self.examples = [ data.ExampleData( text="example", extractions=[ data.Extraction( extraction_class="entity", extraction_text="example", ) ], ) ] self.description = "description" @mock.patch("langextract.annotation.Annotator") @mock.patch("langextract.extraction.factory.create_model") def test_model_overrides_all_other_parameters( self, mock_create_model, mock_annotator_cls ): """Test that model parameter overrides all other model-related parameters.""" provided_model = mock.MagicMock() mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" config = factory.ModelConfig(model_id="config-id") result = lx.extract( text_or_documents="text", prompt_description=self.description, examples=self.examples, model=provided_model, config=config, model_id="ignored-model", api_key="ignored-key", language_model_type=openai.OpenAILanguageModel, use_schema_constraints=False, ) mock_create_model.assert_not_called() _, kwargs = mock_annotator_cls.call_args self.assertIs(kwargs["language_model"], provided_model) self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @mock.patch("langextract.extraction.factory.create_model") def test_config_overrides_model_id_and_language_model_type( self, mock_create_model, mock_annotator_cls ): """Test that config parameter overrides model_id and language_model_type.""" config = factory.ModelConfig( model_id="config-model", provider_kwargs={"api_key": "config-key"} ) mock_model = mock.MagicMock() mock_model.requires_fence_output = True mock_create_model.return_value = mock_model 
mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" with mock.patch( "langextract.extraction.factory.ModelConfig" ) as mock_model_config: result = lx.extract( text_or_documents="text", prompt_description=self.description, examples=self.examples, config=config, model_id="other-model", api_key="other-key", language_model_type=openai.OpenAILanguageModel, use_schema_constraints=False, ) mock_model_config.assert_not_called() mock_create_model.assert_called_once() called_config = mock_create_model.call_args[1]["config"] self.assertEqual(called_config.model_id, "config-model") self.assertEqual(called_config.provider_kwargs, {"api_key": "config-key"}) _, kwargs = mock_annotator_cls.call_args self.assertIs(kwargs["language_model"], mock_model) self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @mock.patch("langextract.extraction.factory.create_model") def test_model_id_and_base_kwargs_override_language_model_type( self, mock_create_model, mock_annotator_cls ): """Test that model_id and other kwargs are used when no model or config.""" mock_model = mock.MagicMock() mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator_cls.return_value.annotate_text.return_value = "ok" mock_config = mock.MagicMock() with mock.patch( "langextract.extraction.factory.ModelConfig", return_value=mock_config ) as mock_model_config: with self.assertWarns(FutureWarning): result = lx.extract( text_or_documents="text", prompt_description=self.description, examples=self.examples, model_id="model-123", api_key="api-key", temperature=0.9, model_url="http://model", language_model_type=openai.OpenAILanguageModel, use_schema_constraints=False, ) mock_model_config.assert_called_once() _, kwargs = mock_model_config.call_args self.assertEqual(kwargs["model_id"], "model-123") provider_kwargs = kwargs["provider_kwargs"] self.assertEqual(provider_kwargs["api_key"], "api-key") 
self.assertEqual(provider_kwargs["temperature"], 0.9) self.assertEqual(provider_kwargs["model_url"], "http://model") self.assertEqual(provider_kwargs["base_url"], "http://model") mock_create_model.assert_called_once() self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @mock.patch("langextract.extraction.factory.create_model") def test_language_model_type_only_emits_warning_and_works( self, mock_create_model, mock_annotator_cls ): """Test that language_model_type emits deprecation warning but still works.""" mock_model = mock.MagicMock() mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator_cls.return_value.annotate_text.return_value = "ok" mock_config = mock.MagicMock() with mock.patch( "langextract.extraction.factory.ModelConfig", return_value=mock_config ) as mock_model_config: with self.assertWarns(FutureWarning): result = lx.extract( text_or_documents="text", prompt_description=self.description, examples=self.examples, language_model_type=openai.OpenAILanguageModel, use_schema_constraints=False, ) mock_model_config.assert_called_once() _, kwargs = mock_model_config.call_args self.assertEqual(kwargs["model_id"], "gemini-2.5-flash") mock_create_model.assert_called_once() self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @mock.patch("langextract.extraction.factory.create_model") def test_use_schema_constraints_warns_with_config( self, mock_create_model, mock_annotator_cls ): """Test that use_schema_constraints emits warning when used with config.""" config = factory.ModelConfig( model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test-key"} ) mock_model = mock.MagicMock() mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" with self.assertWarns(UserWarning) as cm: result = lx.extract( text_or_documents="text", 
prompt_description=self.description, examples=self.examples, config=config, use_schema_constraints=True, ) self.assertIn("schema constraints", str(cm.warning)) self.assertIn("applied", str(cm.warning)) mock_create_model.assert_called_once() called_config = mock_create_model.call_args[1]["config"] self.assertEqual(called_config.model_id, "gemini-2.5-flash") self.assertEqual(result, "ok") @mock.patch("langextract.annotation.Annotator") @mock.patch("langextract.extraction.factory.create_model") def test_use_schema_constraints_warns_with_model( self, mock_create_model, mock_annotator_cls ): """Test that use_schema_constraints emits warning when used with model.""" provided_model = mock.MagicMock() mock_annotator = mock_annotator_cls.return_value mock_annotator.annotate_text.return_value = "ok" with self.assertWarns(UserWarning) as cm: result = lx.extract( text_or_documents="text", prompt_description=self.description, examples=self.examples, model=provided_model, use_schema_constraints=True, ) self.assertIn("use_schema_constraints", str(cm.warning)) self.assertIn("ignored", str(cm.warning)) mock_create_model.assert_not_called() self.assertEqual(result, "ok") if __name__ == "__main__": absltest.main() ================================================ FILE: tests/extract_schema_integration_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Integration tests for extract function with new schema system.""" from unittest import mock import warnings from absl.testing import absltest import langextract as lx from langextract.core import data class ExtractSchemaIntegrationTest(absltest.TestCase): """Tests for extract function with schema system integration.""" def setUp(self): """Set up test fixtures.""" super().setUp() self.examples = [ data.ExampleData( text="Patient has diabetes", extractions=[ data.Extraction( extraction_class="condition", extraction_text="diabetes", attributes={"severity": "moderate"}, ) ], ) ] self.test_text = "Patient has hypertension" @mock.patch.dict("os.environ", {"GEMINI_API_KEY": "test_key"}) def test_extract_with_gemini_uses_schema(self): """Test that extract with Gemini automatically uses schema.""" with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ) as mock_init: with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.infer", return_value=iter([[mock.Mock(output='{"extractions": []}')]]), ): with mock.patch( "langextract.annotation.Annotator.annotate_text", return_value=data.AnnotatedDocument( text=self.test_text, extractions=[] ), ): result = lx.extract( text_or_documents=self.test_text, prompt_description="Extract conditions", examples=self.examples, model_id="gemini-2.5-flash", use_schema_constraints=True, fence_output=None, # Let it compute ) # Should have been called with response_schema call_kwargs = mock_init.call_args[1] self.assertIn("response_schema", call_kwargs) # Result should be an AnnotatedDocument self.assertIsInstance(result, data.AnnotatedDocument) @mock.patch.dict("os.environ", {"OLLAMA_BASE_URL": "http://localhost:11434"}) def test_extract_with_ollama_uses_json_mode(self): """Test that extract with Ollama uses JSON mode.""" with mock.patch( "langextract.providers.ollama.OllamaLanguageModel.__init__", return_value=None, ) as mock_init: with mock.patch( 
"langextract.providers.ollama.OllamaLanguageModel.infer", return_value=iter([[mock.Mock(output='{"extractions": []}')]]), ): with mock.patch( "langextract.annotation.Annotator.annotate_text", return_value=data.AnnotatedDocument( text=self.test_text, extractions=[] ), ): result = lx.extract( text_or_documents=self.test_text, prompt_description="Extract conditions", examples=self.examples, model_id="gemma2:2b", use_schema_constraints=True, fence_output=None, # Let it compute ) # Should have been called with format="json" call_kwargs = mock_init.call_args[1] self.assertIn("format", call_kwargs) self.assertEqual(call_kwargs["format"], "json") # Result should be an AnnotatedDocument self.assertIsInstance(result, data.AnnotatedDocument) def test_extract_explicit_fence_respected(self): """Test that explicit fence_output is respected in extract.""" with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ): with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.infer", return_value=iter([[mock.Mock(output='{"extractions": []}')]]), ): with mock.patch( "langextract.annotation.Annotator.__init__", return_value=None ) as mock_annotator_init: with mock.patch( "langextract.annotation.Annotator.annotate_text", return_value=data.AnnotatedDocument( text=self.test_text, extractions=[] ), ): _ = lx.extract( text_or_documents=self.test_text, prompt_description="Extract conditions", examples=self.examples, model_id="gemini-2.5-flash", api_key="test_key", use_schema_constraints=True, fence_output=True, # Explicitly set ) # Annotator should be created with format_handler that has use_fences=True call_kwargs = mock_annotator_init.call_args[1] self.assertIn("format_handler", call_kwargs) self.assertTrue(call_kwargs["format_handler"].use_fences) def test_extract_gemini_schema_deprecation_warning(self): """Test that passing gemini_schema triggers deprecation warning.""" with mock.patch( 
"langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ): with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.infer", return_value=iter([[mock.Mock(output='{"extractions": []}')]]), ): with mock.patch( "langextract.annotation.Annotator.annotate_text", return_value=data.AnnotatedDocument( text=self.test_text, extractions=[] ), ): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") _ = lx.extract( text_or_documents=self.test_text, prompt_description="Extract conditions", examples=self.examples, model_id="gemini-2.5-flash", api_key="test_key", language_model_params={ "gemini_schema": "some_schema" }, # Deprecated ) # Should have triggered deprecation warning deprecation_warnings = [ warning for warning in w if issubclass(warning.category, FutureWarning) and "gemini_schema" in str(warning.message) ] self.assertGreater(len(deprecation_warnings), 0) def test_extract_no_schema_when_disabled(self): """Test that no schema is used when use_schema_constraints=False.""" # Create a mock instance with required attributes mock_model = mock.MagicMock() mock_model._schema = None mock_model._fence_output_override = None mock_model.gemini_schema = None mock_model.requires_fence_output = True mock_model.infer.return_value = iter( [[mock.Mock(output='{"extractions": []}')]] ) # Track the kwargs passed to the constructor constructor_kwargs = {} def mock_constructor(**kwargs): constructor_kwargs.update(kwargs) return mock_model with mock.patch( "langextract.providers.gemini.GeminiLanguageModel", side_effect=mock_constructor, ): with mock.patch( "langextract.annotation.Annotator.annotate_text", return_value=data.AnnotatedDocument( text=self.test_text, extractions=[] ), ): _ = lx.extract( text_or_documents=self.test_text, prompt_description="Extract conditions", examples=self.examples, model_id="gemini-2.5-flash", api_key="test_key", use_schema_constraints=False, # Disabled ) # Should NOT have response_schema when schema 
constraints are disabled self.assertNotIn("response_schema", constructor_kwargs) self.assertNotIn("gemini_schema", constructor_kwargs) @mock.patch("langextract.factory.create_model") def test_validation_triggers_warning_for_gemini(self, mock_create_model): """Test that Gemini schema validation triggers warnings.""" # Setup mock model with Gemini schema mock_model = mock.MagicMock() mock_model.requires_fence_output = True mock_model.infer.return_value = [ [mock.MagicMock(output='{"extractions": []}', score=1.0)] ] # Create a mock Gemini schema with validate_format that issues warnings mock_schema = mock.MagicMock() def mock_validate_format(format_handler, level=None): # Simulate the warning that would be issued warnings.warn( "Gemini outputs native JSON via" " response_mime_type='application/json'", UserWarning, stacklevel=3, ) mock_schema.validate_format = mock_validate_format mock_model.schema = mock_schema mock_create_model.return_value = mock_model # Run extraction with warnings captured with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = lx.extract( text_or_documents="Sample text", prompt_description="Extract entities", examples=self.examples, model_id="gemini-pro", api_key="test_key", use_schema_constraints=True, ) # Check that a warning was issued warning_messages = [str(warning.message) for warning in w] self.assertTrue( any("Gemini outputs native JSON" in msg for msg in warning_messages), f"Expected Gemini-specific warning not found in: {warning_messages}", ) # Result should still be returned self.assertIsNotNone(result) @mock.patch("langextract.factory.create_model") def test_no_validation_without_schema(self, mock_create_model): """Test that validation is skipped when no schema is present.""" mock_model = mock.MagicMock() mock_model.requires_fence_output = False mock_model.schema = None # No schema mock_model.infer.return_value = [ [mock.MagicMock(output='{"extractions": []}', score=1.0)] ] 
mock_create_model.return_value = mock_model with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = lx.extract( text_or_documents="Sample text", prompt_description="Extract", examples=self.examples, model_id="some-model", api_key="key", use_schema_constraints=False, # No schema constraints ) # No format compatibility warnings should be issued warning_messages = [str(warning.message) for warning in w] self.assertFalse( any("Format compatibility" in msg for msg in warning_messages), f"Unexpected format warning found in: {warning_messages}", ) self.assertIsNotNone(result) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/factory_schema_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for factory schema integration and fence defaulting.""" from unittest import mock from absl.testing import absltest from langextract import factory from langextract import schema from langextract.core import base_model from langextract.core import data class FactorySchemaIntegrationTest(absltest.TestCase): """Tests for create_model_with_schema factory function.""" def setUp(self): """Set up test fixtures.""" super().setUp() self.examples = [ data.ExampleData( text="Test text", extractions=[ data.Extraction( extraction_class="test_class", extraction_text="test extraction", ) ], ) ] @mock.patch.dict("os.environ", {"GEMINI_API_KEY": "test_key"}) def test_gemini_with_schema_returns_false_fence(self): """Test that Gemini with schema returns fence_output=False.""" config = factory.ModelConfig( model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} ) with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ) as mock_init: model = factory._create_model_with_schema( config=config, examples=self.examples, use_schema_constraints=True, fence_output=None, ) mock_init.assert_called_once() call_kwargs = mock_init.call_args[1] self.assertIn("response_schema", call_kwargs) self.assertFalse(model.requires_fence_output) @mock.patch.dict("os.environ", {"OLLAMA_BASE_URL": "http://localhost:11434"}) def test_ollama_with_schema_returns_false_fence(self): """Test that Ollama with JSON mode returns fence_output=False.""" config = factory.ModelConfig(model_id="gemma2:2b") with mock.patch( "langextract.providers.ollama.OllamaLanguageModel.__init__", return_value=None, ) as mock_init: model = factory._create_model_with_schema( config=config, examples=self.examples, use_schema_constraints=True, fence_output=None, ) mock_init.assert_called_once() call_kwargs = mock_init.call_args[1] self.assertIn("format", call_kwargs) self.assertEqual(call_kwargs["format"], "json") self.assertFalse(model.requires_fence_output) def 
test_explicit_fence_output_respected(self): """Test that explicit fence_output is not overridden.""" config = factory.ModelConfig( model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} ) with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ): model = factory._create_model_with_schema( config=config, examples=self.examples, use_schema_constraints=True, fence_output=True, ) self.assertTrue(model.requires_fence_output) def test_no_schema_defaults_to_true_fence(self): """Test that models without schema support default to fence_output=True.""" class NoSchemaModel(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): yield [] config = factory.ModelConfig(model_id="test-model") with mock.patch( "langextract.providers.registry.resolve", return_value=NoSchemaModel ): with mock.patch.object(NoSchemaModel, "__init__", return_value=None): model = factory._create_model_with_schema( config=config, examples=self.examples, use_schema_constraints=True, fence_output=None, ) self.assertTrue(model.requires_fence_output) def test_schema_disabled_returns_true_fence(self): """Test that disabling schema constraints returns fence_output=True.""" config = factory.ModelConfig( model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} ) with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ) as mock_init: model = factory._create_model_with_schema( config=config, examples=self.examples, use_schema_constraints=False, fence_output=None, ) call_kwargs = mock_init.call_args[1] self.assertNotIn("response_schema", call_kwargs) self.assertTrue(model.requires_fence_output) def test_caller_overrides_schema_config(self): """Test that caller's provider_kwargs override schema configuration.""" config = factory.ModelConfig( model_id="gemma2:2b", provider_kwargs={"format": "yaml"}, ) with mock.patch( 
"langextract.providers.ollama.OllamaLanguageModel.__init__", return_value=None, ) as mock_init: _ = factory._create_model_with_schema( config=config, examples=self.examples, use_schema_constraints=True, fence_output=None, ) mock_init.assert_called_once() call_kwargs = mock_init.call_args[1] self.assertIn("format", call_kwargs) self.assertEqual(call_kwargs["format"], "yaml") def test_no_examples_no_schema(self): """Test that no examples means no schema is created.""" config = factory.ModelConfig( model_id="gemini-2.5-flash", provider_kwargs={"api_key": "test_key"} ) with mock.patch( "langextract.providers.gemini.GeminiLanguageModel.__init__", return_value=None, ) as mock_init: model = factory._create_model_with_schema( config=config, examples=None, use_schema_constraints=True, fence_output=None, ) call_kwargs = mock_init.call_args[1] self.assertNotIn("response_schema", call_kwargs) self.assertTrue(model.requires_fence_output) class SchemaApplicationTest(absltest.TestCase): """Tests for apply_schema being called on models.""" def test_apply_schema_called_when_supported(self): """Test that apply_schema is called on models that support it.""" examples = [ data.ExampleData( text="Test", extractions=[ data.Extraction(extraction_class="test", extraction_text="test") ], ) ] class SchemaAwareModel(base_model.BaseLanguageModel): @classmethod def get_schema_class(cls): return schema.GeminiSchema def infer(self, batch_prompts, **kwargs): yield [] config = factory.ModelConfig(model_id="test-model") with mock.patch( "langextract.providers.registry.resolve", return_value=SchemaAwareModel ): with mock.patch.object(SchemaAwareModel, "__init__", return_value=None): with mock.patch.object(SchemaAwareModel, "apply_schema") as mock_apply: _ = factory._create_model_with_schema( config=config, examples=examples, use_schema_constraints=True, ) mock_apply.assert_called_once() schema_arg = mock_apply.call_args[0][0] self.assertIsInstance(schema_arg, schema.GeminiSchema) if __name__ == 
"__main__": absltest.main() ================================================ FILE: tests/factory_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the factory module. Note: This file tests the deprecated registry module which is now an alias for router. The no-name-in-module warning for providers.registry is expected. """ # pylint: disable=no-name-in-module import os from unittest import mock from absl.testing import absltest from langextract import exceptions from langextract import factory from langextract.core import base_model from langextract.core import types from langextract.providers import router class FakeGeminiProvider(base_model.BaseLanguageModel): """Fake Gemini provider for testing.""" def __init__(self, model_id, api_key=None, **kwargs): self.model_id = model_id self.api_key = api_key self.kwargs = kwargs super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="gemini")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) class FakeOpenAIProvider(base_model.BaseLanguageModel): """Fake OpenAI provider for testing.""" def __init__(self, model_id, api_key=None, **kwargs): if not api_key: raise ValueError("API key required") self.model_id = model_id self.api_key = api_key self.kwargs = kwargs super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, 
output="openai")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) class FactoryTest(absltest.TestCase): # pylint: disable=too-many-public-methods def setUp(self): super().setUp() router.clear() import langextract.providers as providers_module # pylint: disable=import-outside-toplevel providers_module._plugins_loaded = True # Use direct registration for test providers to avoid module path issues router.register(r"^gemini", priority=100)(FakeGeminiProvider) router.register(r"^gpt", r"^o1", priority=100)(FakeOpenAIProvider) def tearDown(self): super().tearDown() router.clear() import langextract.providers as providers_module # pylint: disable=import-outside-toplevel providers_module._plugins_loaded = False def test_create_model_basic(self): """Test basic model creation.""" config = factory.ModelConfig( model_id="gemini-pro", provider_kwargs={"api_key": "test-key"} ) model = factory.create_model(config) self.assertIsInstance(model, FakeGeminiProvider) self.assertEqual(model.model_id, "gemini-pro") self.assertEqual(model.api_key, "test-key") def test_create_model_from_id(self): """Test convenience function for creating model from ID.""" model = factory.create_model_from_id("gemini-flash", api_key="test-key") self.assertIsInstance(model, FakeGeminiProvider) self.assertEqual(model.model_id, "gemini-flash") self.assertEqual(model.api_key, "test-key") @mock.patch.dict(os.environ, {"GEMINI_API_KEY": "env-gemini-key"}) def test_uses_gemini_api_key_from_environment(self): """Factory should use GEMINI_API_KEY from environment for Gemini models.""" config = factory.ModelConfig(model_id="gemini-pro") model = factory.create_model(config) self.assertEqual(model.api_key, "env-gemini-key") @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "env-openai-key"}) def test_uses_openai_api_key_from_environment(self): """Factory should use OPENAI_API_KEY from environment for OpenAI models.""" config = factory.ModelConfig(model_id="gpt-4") model = 
factory.create_model(config) self.assertEqual(model.api_key, "env-openai-key") @mock.patch.dict( os.environ, {"LANGEXTRACT_API_KEY": "env-langextract-key"}, clear=True ) def test_falls_back_to_langextract_api_key_when_provider_key_missing(self): """Factory uses LANGEXTRACT_API_KEY when provider-specific key is missing.""" config = factory.ModelConfig(model_id="gemini-pro") model = factory.create_model(config) self.assertEqual(model.api_key, "env-langextract-key") @mock.patch.dict( os.environ, { "GEMINI_API_KEY": "gemini-key", "LANGEXTRACT_API_KEY": "langextract-key", }, ) def test_provider_specific_key_takes_priority_over_langextract_key(self): """Factory prefers provider-specific API key over LANGEXTRACT_API_KEY.""" config = factory.ModelConfig(model_id="gemini-pro") model = factory.create_model(config) self.assertEqual(model.api_key, "gemini-key") def test_explicit_kwargs_override_env(self): """Test that explicit kwargs override environment variables.""" with mock.patch.dict(os.environ, {"GEMINI_API_KEY": "env-key"}): config = factory.ModelConfig( model_id="gemini-pro", provider_kwargs={"api_key": "explicit-key"} ) model = factory.create_model(config) self.assertEqual(model.api_key, "explicit-key") @mock.patch.dict(os.environ, {}, clear=True) def test_wraps_provider_initialization_error_in_inference_config_error(self): """Factory should wrap provider errors in InferenceConfigError.""" config = factory.ModelConfig(model_id="gpt-4") with self.assertRaises(exceptions.InferenceConfigError) as cm: factory.create_model(config) self.assertIn("Failed to create provider", str(cm.exception)) self.assertIn("API key required", str(cm.exception)) def test_raises_error_when_no_provider_matches_model_id(self): """Factory should raise InferenceConfigError for unregistered model IDs.""" config = factory.ModelConfig(model_id="unknown-model") with self.assertRaises(exceptions.InferenceConfigError) as cm: factory.create_model(config) self.assertIn("No provider registered", 
str(cm.exception)) def test_additional_kwargs_passed_through(self): """Test that additional kwargs are passed to provider.""" config = factory.ModelConfig( model_id="gemini-pro", provider_kwargs={ "api_key": "test-key", "temperature": 0.5, "max_tokens": 100, "custom_param": "value", }, ) model = factory.create_model(config) self.assertEqual(model.kwargs["temperature"], 0.5) self.assertEqual(model.kwargs["max_tokens"], 100) self.assertEqual(model.kwargs["custom_param"], "value") @mock.patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://custom:11434"}) def test_ollama_uses_base_url_from_environment(self): """Factory should use OLLAMA_BASE_URL from environment for Ollama models.""" @router.register(r"^ollama") class FakeOllamaProvider(base_model.BaseLanguageModel): # pylint: disable=unused-variable def __init__(self, model_id, base_url=None, **kwargs): self.model_id = model_id self.base_url = base_url super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="ollama")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) config = factory.ModelConfig(model_id="ollama/llama2") model = factory.create_model(config) self.assertEqual(model.base_url, "http://custom:11434") def test_ollama_models_select_without_api_keys(self): """Test that Ollama models resolve without API keys or explicit type.""" @router.register(r"^llama", r"^gemma", r"^mistral", r"^qwen", priority=100) class FakeOllamaProvider(base_model.BaseLanguageModel): def __init__(self, model_id, **kwargs): self.model_id = model_id super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="test")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) test_models = ["llama3", "gemma2:2b", "mistral:7b", "qwen3:0.6b"] for model_id in test_models: with self.subTest(model_id=model_id): with mock.patch.dict(os.environ, {}, clear=True): config = factory.ModelConfig(model_id=model_id) 
model = factory.create_model(config) self.assertIsInstance(model, FakeOllamaProvider) self.assertEqual(model.model_id, model_id) def test_model_config_fields_are_immutable(self): """ModelConfig fields should not be modifiable after creation.""" config = factory.ModelConfig( model_id="gemini-pro", provider_kwargs={"api_key": "test"} ) with self.assertRaises(AttributeError): config.model_id = "different" def test_model_config_allows_dict_contents_modification(self): """ModelConfig allows modification of dict contents (not deeply frozen).""" config = factory.ModelConfig( model_id="gemini-pro", provider_kwargs={"api_key": "test"} ) config.provider_kwargs["new_key"] = "value" self.assertEqual(config.provider_kwargs["new_key"], "value") def test_uses_highest_priority_provider_when_multiple_match(self): """Factory uses highest priority provider when multiple patterns match.""" @router.register(r"^gemini", priority=90) class AnotherGeminiProvider(base_model.BaseLanguageModel): # pylint: disable=unused-variable def __init__(self, model_id=None, **kwargs): self.model_id = model_id or "default-model" self.kwargs = kwargs super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="another")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) config = factory.ModelConfig(model_id="gemini-pro") model = factory.create_model(config) self.assertIsInstance(model, FakeGeminiProvider) # Priority 100 wins def test_explicit_provider_overrides_pattern_matching(self): """Factory should use explicit provider even when pattern doesn't match.""" @router.register(r"^another", priority=90) class AnotherProvider(base_model.BaseLanguageModel): def __init__(self, model_id=None, **kwargs): self.model_id = model_id or "default-model" self.kwargs = kwargs super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="another")]] def infer_batch(self, prompts, batch_size=32): return 
self.infer(prompts) config = factory.ModelConfig( model_id="gemini-pro", provider="AnotherProvider" ) model = factory.create_model(config) self.assertIsInstance(model, AnotherProvider) self.assertEqual(model.model_id, "gemini-pro") def test_provider_without_model_id_uses_provider_default(self): """Factory should use provider's default model_id when none specified.""" @router.register(r"^default-provider$", priority=50) class DefaultProvider(base_model.BaseLanguageModel): def __init__(self, model_id="default-model", **kwargs): self.model_id = model_id self.kwargs = kwargs super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="default")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) config = factory.ModelConfig(provider="DefaultProvider") model = factory.create_model(config) self.assertIsInstance(model, DefaultProvider) self.assertEqual(model.model_id, "default-model") def test_raises_error_when_neither_model_id_nor_provider_specified(self): """Factory raises ValueError when config has neither model_id nor provider.""" config = factory.ModelConfig() with self.assertRaises(ValueError) as cm: factory.create_model(config) self.assertIn( "Either model_id or provider must be specified", str(cm.exception) ) def test_gemini_vertexai_parameters_accepted(self): """Test that Vertex AI parameters are properly passed to Gemini provider.""" original_entries = router._entries.copy() # pylint: disable=protected-access original_keys = router._entry_keys.copy() # pylint: disable=protected-access try: @router.register(r"^gemini", priority=200) class MockGeminiWithVertexAI(base_model.BaseLanguageModel): # pylint: disable=unused-variable def __init__( self, model_id="gemini-2.5-flash", api_key=None, vertexai=False, credentials=None, project=None, location=None, **kwargs, ): self.model_id = model_id self.api_key = api_key self.vertexai = vertexai self.credentials = credentials self.project = project 
self.location = location super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="vertexai-test")]] config = factory.ModelConfig( model_id="gemini-pro", provider_kwargs={ "vertexai": True, "project": "test-project", "location": "us-central1", }, ) model = factory.create_model(config) self.assertTrue(model.vertexai) self.assertEqual(model.project, "test-project") self.assertEqual(model.location, "us-central1") self.assertIsNone(model.api_key) finally: router._entries = original_entries # pylint: disable=protected-access router._entry_keys = original_keys # pylint: disable=protected-access def test_gemini_vertexai_with_credentials(self): """Test that Vertex AI credentials can be passed through.""" original_entries = router._entries.copy() # pylint: disable=protected-access original_keys = router._entry_keys.copy() # pylint: disable=protected-access try: @router.register(r"^gemini", priority=200) class MockGeminiWithCredentials(base_model.BaseLanguageModel): # pylint: disable=unused-variable def __init__( self, model_id="gemini-2.5-flash", credentials=None, **kwargs ): self.model_id = model_id self.credentials = credentials super().__init__() def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="creds-test")]] mock_credentials = {"type": "service_account"} # Simplified mock config = factory.ModelConfig( model_id="gemini-2.5-flash", provider_kwargs={"credentials": mock_credentials}, ) model = factory.create_model(config) self.assertEqual(model.credentials, mock_credentials) finally: router._entries = original_entries # pylint: disable=protected-access router._entry_keys = original_keys # pylint: disable=protected-access if __name__ == "__main__": absltest.main() ================================================ FILE: tests/format_handler_test.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for centralized format handler.""" import textwrap from absl.testing import absltest from absl.testing import parameterized from langextract import prompting from langextract import resolver from langextract.core import data from langextract.core import format_handler class FormatHandlerTest(parameterized.TestCase): """Tests for FormatHandler.""" @parameterized.named_parameters( dict( testcase_name="json_with_wrapper_and_fences", format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=True, extraction_class="person", extraction_text="Alice", attributes={"role": "engineer"}, expected_fence="```json", expected_wrapper='"extractions":', expected_extraction='"person": "Alice"', model_output=textwrap.dedent(""" Here is the result: ```json { "extractions": [ {"person": "Bob", "person_attributes": {"role": "manager"}} ] } ``` """).strip(), parsed_class="person", parsed_text="Bob", ), dict( testcase_name="json_no_wrapper_no_fences", format_type=data.FormatType.JSON, use_wrapper=False, wrapper_key=None, use_fences=False, extraction_class="item", extraction_text="book", attributes=None, expected_fence=None, expected_wrapper=None, expected_extraction='"item": "book"', model_output='[{"item": "pen", "item_attributes": {}}]', parsed_class="item", parsed_text="pen", ), dict( testcase_name="yaml_with_wrapper_and_fences", format_type=data.FormatType.YAML, use_wrapper=True, wrapper_key="extractions", 
use_fences=True, extraction_class="city", extraction_text="Paris", attributes=None, expected_fence="```yaml", expected_wrapper="extractions:", expected_extraction="city: Paris", model_output=textwrap.dedent(""" ```yaml extractions: - city: London city_attributes: {} ``` """).strip(), parsed_class="city", parsed_text="London", ), ) def test_format_and_parse( # pylint: disable=too-many-arguments self, format_type, use_wrapper, wrapper_key, use_fences, extraction_class, extraction_text, attributes, expected_fence, expected_wrapper, expected_extraction, model_output, parsed_class, parsed_text, ): """Test formatting and parsing with various configurations.""" handler = format_handler.FormatHandler( format_type=format_type, use_wrapper=use_wrapper, wrapper_key=wrapper_key, use_fences=use_fences, ) extractions = [ data.Extraction( extraction_class=extraction_class, extraction_text=extraction_text, attributes=attributes, ) ] formatted = handler.format_extraction_example(extractions) if expected_fence: self.assertIn(expected_fence, formatted) else: self.assertNotIn("```", formatted) if expected_wrapper: self.assertIn(expected_wrapper, formatted) else: if wrapper_key: self.assertNotIn(wrapper_key, formatted) self.assertIn(expected_extraction, formatted) parsed = handler.parse_output(model_output) self.assertLen(parsed, 1) self.assertEqual(parsed[0][parsed_class], parsed_text) def test_end_to_end_integration_with_prompt_and_resolver(self): """Test that FormatHandler unifies prompt generation and parsing.""" handler = format_handler.FormatHandler( format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=True, ) template = prompting.PromptTemplateStructured( description="Extract entities from text.", examples=[ data.ExampleData( text="Alice is an engineer", extractions=[ data.Extraction( extraction_class="person", extraction_text="Alice", attributes={"role": "engineer"}, ) ], ) ], ) prompt_gen = prompting.QAPromptGenerator( template=template, 
format_handler=handler, ) prompt = prompt_gen.render("Bob is a manager") self.assertIn("```json", prompt, "Prompt should contain JSON fence") self.assertIn('"extractions":', prompt, "Prompt should contain wrapper key") test_resolver = resolver.Resolver( format_handler=handler, extraction_index_suffix=None, ) model_output = textwrap.dedent(""" ```json { "extractions": [ { "person": "Bob", "person_attributes": {"role": "manager"} } ] } ``` """).strip() extractions = test_resolver.resolve(model_output) self.assertLen(extractions, 1, "Should extract exactly one entity") self.assertEqual( extractions[0].extraction_class, "person", "Extraction class should be 'person'", ) self.assertEqual( extractions[0].extraction_text, "Bob", "Extraction text should be 'Bob'" ) @parameterized.named_parameters( dict( testcase_name="yaml_no_wrapper_no_fences", format_type=data.FormatType.YAML, use_wrapper=False, use_fences=False, ), dict( testcase_name="json_with_wrapper_and_fences", format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=True, ), dict( testcase_name="yaml_with_wrapper_no_fences", format_type=data.FormatType.YAML, use_wrapper=True, wrapper_key="extractions", use_fences=False, ), ) def test_format_parse_roundtrip( self, format_type, use_wrapper, use_fences, wrapper_key=None ): """Test that what we format can be parsed back identically.""" handler = format_handler.FormatHandler( format_type=format_type, use_wrapper=use_wrapper, wrapper_key=wrapper_key, use_fences=use_fences, ) extractions = [ data.Extraction( extraction_class="test", extraction_text="value", attributes={"key": "data"}, ) ] formatted = handler.format_extraction_example(extractions) parsed = handler.parse_output(formatted) self.assertEqual(parsed[0]["test"], "value") self.assertEqual(parsed[0]["test_attributes"]["key"], "data") class NonGeminiModelParsingTest(parameterized.TestCase): """Regression tests for non-Gemini model parsing edge cases.""" def 
test_think_tags_stripped_before_parsing(self): # Reasoning models output tags before JSON handler = format_handler.FormatHandler( format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=False, ) input_with_think = ( "Let me analyze this text..." '{"extractions": [{"person": "Alice"}]}' ) parsed = handler.parse_output(input_with_think) self.assertLen(parsed, 1) self.assertEqual(parsed[0]["person"], "Alice") def test_top_level_list_accepted_as_fallback(self): # Some models return [...] instead of {"extractions": [...]} handler = format_handler.FormatHandler( format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=False, ) input_list = '[{"person": "Bob"}, {"person": "Carol"}]' parsed = handler.parse_output(input_list) self.assertLen(parsed, 2) self.assertEqual(parsed[0]["person"], "Bob") self.assertEqual(parsed[1]["person"], "Carol") def test_deepseek_r1_real_output(self): # Real output captured from DeepSeek-R1:1.5b model handler = format_handler.FormatHandler( format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=False, ) deepseek_output = textwrap.dedent("""\ Alright, so I need to extract people from the given text. I see John Smith is mentioned as an engineer. {"extractions": [{"person": "John Smith"}]}""") parsed = handler.parse_output(deepseek_output) self.assertLen(parsed, 1) self.assertEqual(parsed[0]["person"], "John Smith") if __name__ == "__main__": absltest.main() ================================================ FILE: tests/inference_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for inference module. Note: This file contains test helper classes that intentionally have few public methods and define attributes outside __init__. These pylint warnings are expected for test fixtures. """ # pylint: disable=attribute-defined-outside-init from unittest import mock from absl.testing import absltest from absl.testing import parameterized from langextract import exceptions from langextract.core import base_model from langextract.core import data from langextract.core import types from langextract.providers import gemini from langextract.providers import ollama from langextract.providers import openai class TestBaseLanguageModel(absltest.TestCase): def test_merge_kwargs_with_none(self): """Test merge_kwargs handles None runtime_kwargs.""" class TestModel(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): return iter([]) model = TestModel() model._extra_kwargs = {"a": 1, "b": 2} result = model.merge_kwargs(None) self.assertEqual( {"a": 1, "b": 2}, result, "merge_kwargs(None) should return stored kwargs unchanged", ) result = model.merge_kwargs({}) self.assertEqual( {"a": 1, "b": 2}, result, "merge_kwargs({}) should return stored kwargs unchanged", ) result = model.merge_kwargs({"b": 3, "c": 4}) self.assertEqual( {"a": 1, "b": 3, "c": 4}, result, "Runtime kwargs should override stored kwargs and add new keys", ) def test_merge_kwargs_without_extra_kwargs(self): """Test merge_kwargs when _extra_kwargs doesn't exist.""" class TestModel(base_model.BaseLanguageModel): # pylint: 
disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): return iter([]) model = TestModel() # Intentionally not setting _extra_kwargs to test fallback behavior result = model.merge_kwargs({"a": 1}) self.assertEqual( {"a": 1}, result, "merge_kwargs should work even without _extra_kwargs attribute", ) class TestOllamaLanguageModel(absltest.TestCase): @mock.patch("langextract.providers.ollama.OllamaLanguageModel._ollama_query") def test_ollama_infer(self, mock_ollama_query): # Real gemma2 response structure from Ollama API for validation gemma_response = { "model": "gemma2:latest", "created_at": "2025-01-23T22:37:08.579440841Z", "response": "{'bus' : '**autóbusz**'} \n\n\n \n", "done": True, "done_reason": "stop", "context": [ 106, 1645, 108, 1841, 603, 1986, 575, 59672, 235336, 107, 108, 106, 2516, 108, 9766, 6710, 235281, 865, 664, 688, 7958, 235360, 6710, 235306, 688, 12990, 235248, 110, 139, 108, ], "total_duration": 24038204381, "load_duration": 21551375738, "prompt_eval_count": 15, "prompt_eval_duration": 633000000, "eval_count": 17, "eval_duration": 1848000000, } mock_ollama_query.return_value = gemma_response model = ollama.OllamaLanguageModel( model_id="gemma2:latest", model_url="http://localhost:11434", structured_output_format="json", ) batch_prompts = ["What is bus in Hungarian?"] results = list(model.infer(batch_prompts)) mock_ollama_query.assert_called_once_with( prompt="What is bus in Hungarian?", model="gemma2:latest", structured_output_format="json", model_url="http://localhost:11434", ) expected_results = [[ types.ScoredOutput( score=1.0, output="{'bus' : '**autóbusz**'} \n\n\n \n" ) ]] self.assertEqual(results, expected_results) @mock.patch("requests.post") def test_ollama_extra_kwargs_passed_to_api(self, mock_post): """Verify extra kwargs like timeout and keep_alive are passed to the API.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = { "response": '{"test": "value"}', "done": True, 
} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id="test-model", timeout=300, keep_alive=600, num_threads=8, ) prompts = ["Test prompt"] list(model.infer(prompts)) mock_post.assert_called_once() call_args = mock_post.call_args json_payload = call_args.kwargs["json"] self.assertEqual(json_payload["options"]["keep_alive"], 600) self.assertEqual(json_payload["options"]["num_thread"], 8) # timeout is passed to requests.post, not in the JSON payload self.assertEqual(call_args.kwargs["timeout"], 300) @mock.patch("requests.post") def test_ollama_stop_and_top_p_passthrough(self, mock_post): """Verify stop and top_p parameters are passed to Ollama API.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = { "response": '{"test": "value"}', "done": True, } mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id="test-model", top_p=0.9, stop=["\\n\\n", "END"], ) prompts = ["Test prompt"] list(model.infer(prompts)) mock_post.assert_called_once() call_args = mock_post.call_args json_payload = call_args.kwargs["json"] # Ollama expects 'stop' at top level, not in options self.assertEqual(json_payload["stop"], ["\\n\\n", "END"]) self.assertEqual(json_payload["options"]["top_p"], 0.9) @mock.patch("requests.post") def test_ollama_defaults_when_unspecified(self, mock_post): """Verify Ollama uses correct defaults when parameters are not specified.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = { "response": '{"test": "value"}', "done": True, } mock_post.return_value = mock_response model = ollama.OllamaLanguageModel(model_id="test-model") prompts = ["Test prompt"] list(model.infer(prompts)) mock_post.assert_called_once() call_args = mock_post.call_args json_payload = call_args.kwargs["json"] self.assertEqual(json_payload["options"]["temperature"], 0.1) self.assertEqual(json_payload["options"]["keep_alive"], 300) 
self.assertEqual(json_payload["options"]["num_ctx"], 2048) self.assertEqual(call_args.kwargs["timeout"], 120) @mock.patch("requests.post") def test_ollama_runtime_kwargs_override_stored(self, mock_post): """Verify runtime kwargs override stored kwargs.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = { "response": '{"test": "value"}', "done": True, } mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id="test-model", temperature=0.5, keep_alive=300, ) prompts = ["Test prompt"] list(model.infer(prompts, temperature=0.8, keep_alive=600)) mock_post.assert_called_once() call_args = mock_post.call_args json_payload = call_args.kwargs["json"] self.assertEqual(json_payload["options"]["temperature"], 0.8) self.assertEqual(json_payload["options"]["keep_alive"], 600) @mock.patch("requests.post") def test_ollama_temperature_zero(self, mock_post): """Test that temperature=0.0 is properly passed to Ollama.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = { "response": '{"test": "value"}', "done": True, } mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id="test-model", temperature=0.0, ) list(model.infer(["test prompt"])) mock_post.assert_called_once() call_args = mock_post.call_args json_payload = call_args.kwargs["json"] self.assertEqual(json_payload["options"]["temperature"], 0.0) def test_ollama_default_timeout(self): """Test that default timeout is used when not specified.""" model = ollama.OllamaLanguageModel( model_id="test-model", model_url="http://localhost:11434", ) mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": "test output"} with mock.patch.object( model._requests, "post", return_value=mock_response ) as mock_post: model._ollama_query(prompt="test prompt") mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] 
self.assertEqual( 120, call_kwargs["timeout"], "Should use default timeout of 120 seconds", ) def test_ollama_timeout_through_infer(self): """Test that timeout flows correctly through the infer() method.""" model = ollama.OllamaLanguageModel( model_id="test-model", model_url="http://localhost:11434", timeout=60, ) mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": "test output"} with mock.patch.object( model._requests, "post", return_value=mock_response ) as mock_post: list(model.infer(["test prompt"])) mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] self.assertEqual( 60, call_kwargs["timeout"], "Timeout from constructor should flow through infer()", ) class TestGeminiLanguageModel(absltest.TestCase): @mock.patch("google.genai.Client") def test_gemini_allowlist_filtering(self, mock_client_class): """Test that only allow-listed keys are passed through.""" mock_client = mock.Mock() mock_client_class.return_value = mock_client mock_response = mock.Mock() mock_response.text = '{"result": "test"}' mock_client.models.generate_content.return_value = mock_response model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="test-key", # Allow-listed parameters tools=["tool1", "tool2"], stop_sequences=["\n\n"], system_instruction="Be helpful", # Unknown parameters to test filtering unknown_param="should_be_ignored", another_unknown="also_ignored", ) expected_extra_kwargs = { "tools": ["tool1", "tool2"], "stop_sequences": ["\n\n"], "system_instruction": "Be helpful", } self.assertEqual( expected_extra_kwargs, model._extra_kwargs, "Only allow-listed kwargs should be stored in _extra_kwargs", ) prompts = ["Test prompt"] list(model.infer(prompts)) mock_client.models.generate_content.assert_called_once() call_args = mock_client.models.generate_content.call_args config = call_args.kwargs["config"] for key in ["tools", "stop_sequences", "system_instruction"]: 
self.assertIn(key, config, f"Expected {key} to be in API config") self.assertEqual( expected_extra_kwargs[key], config[key], f"Config value for {key} should match what was provided", ) @mock.patch("google.genai.Client") def test_gemini_runtime_kwargs_filtered(self, mock_client_class): """Test that runtime kwargs are also filtered by allow-list.""" mock_client = mock.Mock() mock_client_class.return_value = mock_client mock_response = mock.Mock() mock_response.text = '{"result": "test"}' mock_client.models.generate_content.return_value = mock_response model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="test-key", ) prompts = ["Test prompt"] list( model.infer( prompts, candidate_count=2, safety_settings={"HARM_CATEGORY_DANGEROUS": "BLOCK_NONE"}, unknown_runtime_param="ignored", ) ) call_args = mock_client.models.generate_content.call_args config = call_args.kwargs["config"] self.assertEqual( 2, config.get("candidate_count"), "candidate_count should be passed through to API", ) self.assertEqual( {"HARM_CATEGORY_DANGEROUS": "BLOCK_NONE"}, config.get("safety_settings"), "safety_settings should be passed through to API", ) self.assertNotIn( "unknown_runtime_param", config, "Unknown kwargs should be filtered out" ) def test_gemini_requires_auth_config(self): """Test that Gemini requires either API key or Vertex AI config.""" with self.assertRaises(exceptions.InferenceConfigError) as cm: gemini.GeminiLanguageModel() self.assertIn("Gemini models require either", str(cm.exception)) self.assertIn("API key", str(cm.exception)) self.assertIn("Vertex AI", str(cm.exception)) def test_gemini_vertexai_requires_project_and_location(self): """Test that Vertex AI mode requires both project and location.""" with self.assertRaises(exceptions.InferenceConfigError) as cm: gemini.GeminiLanguageModel(vertexai=True) self.assertIn("requires both project and location", str(cm.exception)) @mock.patch("google.genai.Client") def test_gemini_vertexai_initialization(self, 
mock_client_class): """Test successful initialization with Vertex AI config.""" mock_client = mock.Mock() mock_client_class.return_value = mock_client model = gemini.GeminiLanguageModel( vertexai=True, project="test-project", location="us-central1" ) self.assertIsNone(model.api_key) self.assertTrue(model.vertexai) self.assertEqual(model.project, "test-project") self.assertEqual(model.location, "us-central1") mock_client_class.assert_called_once_with( api_key=None, vertexai=True, credentials=None, project="test-project", location="us-central1", http_options=None, ) @mock.patch("absl.logging.warning") @mock.patch("google.genai.Client") def test_gemini_warns_when_both_auth_provided( self, mock_client_class, mock_warning ): """Test that warning is logged when both API key and Vertex AI are provided.""" mock_client = mock.Mock() mock_client_class.return_value = mock_client gemini.GeminiLanguageModel( api_key="test-key", vertexai=True, project="test-project", location="us-central1", ) mock_warning.assert_called_once() warning_msg = mock_warning.call_args[0][0] self.assertIn("Both API key and Vertex AI", warning_msg) self.assertIn("API key will take precedence", warning_msg) @mock.patch("google.genai.Client") def test_gemini_vertexai_with_http_options(self, mock_client_class): """Test that http_options are passed to genai.Client for VPC endpoints.""" mock_client = mock.Mock() mock_client_class.return_value = mock_client http_options = {"base_url": "https://custom-vpc.p.googleapis.com"} model = gemini.GeminiLanguageModel( vertexai=True, project="test-project", location="us-central1", http_options=http_options, ) self.assertEqual(model.http_options, http_options) mock_client_class.assert_called_once_with( api_key=None, vertexai=True, credentials=None, project="test-project", location="us-central1", http_options=http_options, ) class TestOpenAILanguageModelInference(parameterized.TestCase): @parameterized.named_parameters( ("without", "test-api-key", None, "gpt-4o-mini", 
0.5), ("with", "test-api-key", "http://127.0.0.1:9001/v1", "gpt-4o-mini", 0.5), ) @mock.patch("openai.OpenAI") def test_openai_infer_with_parameters( self, api_key, base_url, model_id, temperature, mock_openai_class ): mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"name": "John", "age": 30}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id=model_id, api_key=api_key, base_url=base_url, temperature=temperature, ) batch_prompts = ["Extract name and age from: John is 30 years old"] results = list(model.infer(batch_prompts)) # JSON format adds a system message; only explicitly set params are passed mock_client.chat.completions.create.assert_called_once() call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs["model"], "gpt-4o-mini") self.assertEqual(call_args.kwargs["temperature"], temperature) self.assertEqual(call_args.kwargs["n"], 1) self.assertEqual(len(call_args.kwargs["messages"]), 2) self.assertEqual(call_args.kwargs["messages"][0]["role"], "system") self.assertEqual(call_args.kwargs["messages"][1]["role"], "user") expected_results = [ [types.ScoredOutput(score=1.0, output='{"name": "John", "age": 30}')] ] self.assertEqual(results, expected_results) class TestOpenAILanguageModel(absltest.TestCase): def test_openai_parse_output_json(self): model = openai.OpenAILanguageModel( api_key="test-key", format_type=data.FormatType.JSON ) output = '{"key": "value", "number": 42}' parsed = model.parse_output(output) self.assertEqual(parsed, {"key": "value", "number": 42}) with self.assertRaises(ValueError) as context: model.parse_output("invalid json") self.assertIn("Failed to parse output as JSON", str(context.exception)) def test_openai_parse_output_yaml(self): model = openai.OpenAILanguageModel( api_key="test-key", format_type=data.FormatType.YAML ) output 
= "key: value\nnumber: 42" parsed = model.parse_output(output) self.assertEqual(parsed, {"key": "value", "number": 42}) with self.assertRaises(ValueError) as context: model.parse_output("invalid: yaml: bad") self.assertIn("Failed to parse output as YAML", str(context.exception)) def test_openai_no_api_key_raises_error(self): with self.assertRaises(exceptions.InferenceConfigError) as context: openai.OpenAILanguageModel(api_key=None) self.assertEqual(str(context.exception), "API key not provided.") @mock.patch("openai.OpenAI") def test_openai_extra_kwargs_passed(self, mock_openai_class): """Test that extra kwargs are passed to OpenAI API.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( api_key="test-key", frequency_penalty=0.5, presence_penalty=0.7, seed=42, ) list(model.infer(["test prompt"])) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs["frequency_penalty"], 0.5) self.assertEqual(call_args.kwargs["presence_penalty"], 0.7) self.assertEqual(call_args.kwargs["seed"], 42) @mock.patch("openai.OpenAI") def test_openai_runtime_kwargs_override(self, mock_openai_class): """Test that runtime kwargs override stored kwargs.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( api_key="test-key", temperature=0.5, seed=123, ) list(model.infer(["test prompt"], temperature=0.8, seed=456)) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs["temperature"], 0.8) self.assertEqual(call_args.kwargs["seed"], 456) 
@mock.patch("openai.OpenAI") def test_openai_json_response_format(self, mock_openai_class): """Test that JSON format adds response_format parameter.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( api_key="test-key", format_type=data.FormatType.JSON ) list(model.infer(["test prompt"])) mock_client.chat.completions.create.assert_called_once() call_args = mock_client.chat.completions.create.call_args self.assertEqual( call_args.kwargs["response_format"], {"type": "json_object"} ) @mock.patch("openai.OpenAI") def test_openai_temperature_zero(self, mock_openai_class): """Verify temperature=0.0 is properly passed to the API.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel(api_key="test-key", temperature=0.0) list(model.infer(["test prompt"])) mock_client.chat.completions.create.assert_called_once() call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs["temperature"], 0.0) self.assertEqual(call_args.kwargs["model"], "gpt-4o-mini") self.assertEqual(call_args.kwargs["n"], 1) @mock.patch("openai.OpenAI") def test_openai_temperature_none_not_sent(self, mock_openai_class): """Test that temperature=None is not sent to the API.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response # Test with temperature=None in model init model = openai.OpenAILanguageModel( 
api_key="test-key", temperature=None, ) list(model.infer(["test prompt"])) call_args = mock_client.chat.completions.create.call_args self.assertNotIn("temperature", call_args.kwargs) @mock.patch("openai.OpenAI") def test_openai_none_values_filtered(self, mock_openai_class): """Test that None values are not passed to the API.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( api_key="test-key", top_p=0.9, ) list(model.infer(["test prompt"], top_p=None, seed=None)) call_args = mock_client.chat.completions.create.call_args self.assertNotIn("top_p", call_args.kwargs) self.assertNotIn("seed", call_args.kwargs) @mock.patch("openai.OpenAI") def test_openai_no_system_message_when_not_json_yaml(self, mock_openai_class): """Test that no system message is sent when format_type is not JSON/YAML.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content="test output")) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( api_key="test-key", format_type=None, ) list(model.infer(["test prompt"])) call_args = mock_client.chat.completions.create.call_args messages = call_args.kwargs["messages"] self.assertEqual(len(messages), 1) self.assertEqual(messages[0]["role"], "user") self.assertEqual(messages[0]["content"], "test prompt") @mock.patch("google.genai.Client") def test_gemini_none_values_filtered(self, mock_client_class): """Test that None values are not passed to Gemini API.""" mock_client = mock.Mock() mock_client_class.return_value = mock_client mock_response = mock.Mock() mock_response.text = '{"result": "test"}' mock_client.models.generate_content.return_value = mock_response model = 
gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="test-key", ) list(model.infer(["test prompt"], candidate_count=None)) call_args = mock_client.models.generate_content.call_args config = call_args.kwargs["config"] self.assertNotIn("candidate_count", config) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/init_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the main package functions in __init__.py.""" import textwrap from unittest import mock import warnings from absl.testing import absltest from absl.testing import parameterized from langextract import prompting import langextract as lx from langextract.core import base_model from langextract.core import data from langextract.core import format_handler as fh from langextract.core import schema from langextract.core import types from langextract.providers import schemas class InitTest(parameterized.TestCase): """Test cases for the main package functions.""" @mock.patch.object( schemas.gemini.GeminiSchema, "from_examples", autospec=True ) @mock.patch("langextract.extraction.factory.create_model") def test_lang_extract_as_lx_extract( self, mock_create_model, mock_gemini_schema ): input_text = "Patient takes Aspirin 100mg every morning." 
mock_model = mock.MagicMock() mock_model.infer.return_value = [[ types.ScoredOutput( output=textwrap.dedent("""\ ```json { "extractions": [ { "entity": "Aspirin", "entity_attributes": { "class": "medication" } }, { "entity": "100mg", "entity_attributes": { "frequency": "every morning", "class": "dosage" } } ] } ```"""), score=0.9, ) ]] mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_gemini_schema.return_value = None expected_result = data.AnnotatedDocument( document_id=None, extractions=[ data.Extraction( extraction_class="entity", extraction_text="Aspirin", char_interval=data.CharInterval(start_pos=14, end_pos=21), alignment_status=data.AlignmentStatus.MATCH_EXACT, extraction_index=1, group_index=0, description=None, attributes={"class": "medication"}, ), data.Extraction( extraction_class="entity", extraction_text="100mg", char_interval=data.CharInterval(start_pos=22, end_pos=27), alignment_status=data.AlignmentStatus.MATCH_EXACT, extraction_index=2, group_index=1, description=None, attributes={"frequency": "every morning", "class": "dosage"}, ), ], text="Patient takes Aspirin 100mg every morning.", ) mock_description = textwrap.dedent("""\ Extract medication and dosage information in order of occurrence. 
""") mock_examples = [ lx.data.ExampleData( text="Patient takes Tylenol 500mg daily.", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="Tylenol", attributes={ "type": "analgesic", "class": "medication", }, ), ], ) ] mock_prompt_template = prompting.PromptTemplateStructured( description=mock_description, examples=mock_examples ) format_handler = fh.FormatHandler( format_type=data.FormatType.JSON, use_wrapper=True, wrapper_key="extractions", use_fences=True, ) prompt_generator = prompting.QAPromptGenerator( template=mock_prompt_template, format_handler=format_handler ) actual_result = lx.extract( text_or_documents=input_text, prompt_description=mock_description, examples=mock_examples, api_key="some_api_key", fence_output=True, use_schema_constraints=False, ) mock_gemini_schema.assert_not_called() mock_create_model.assert_called_once() mock_model.infer.assert_called_once_with( batch_prompts=[prompt_generator.render(input_text)], max_workers=10, ) self.assertDataclassEqual(expected_result, actual_result) @mock.patch("langextract.extraction.resolver.Resolver.align") @mock.patch("langextract.extraction.factory.create_model") def test_extract_resolver_params_alignment_passthrough( self, mock_create_model, mock_align ): mock_model = mock.MagicMock() mock_model.infer.return_value = [ [types.ScoredOutput(output='{"extractions":[]}')] ] mock_model.requires_fence_output = False mock_create_model.return_value = mock_model mock_align.return_value = [] mock_examples = [ lx.data.ExampleData( text="Patient takes Tylenol 500mg daily.", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="Tylenol", attributes={ "type": "analgesic", "class": "medication", }, ), ], ) ] lx.extract( text_or_documents="test text", prompt_description="desc", examples=mock_examples, api_key="test_key", resolver_params={ "enable_fuzzy_alignment": False, "fuzzy_alignment_threshold": 0.8, "accept_match_lesser": False, }, ) mock_align.assert_called() _, 
kwargs = mock_align.call_args self.assertFalse(kwargs.get("enable_fuzzy_alignment")) self.assertEqual(kwargs.get("fuzzy_alignment_threshold"), 0.8) self.assertFalse(kwargs.get("accept_match_lesser")) @mock.patch("langextract.annotation.Annotator.annotate_text") @mock.patch("langextract.extraction.factory.create_model") def test_extract_resolver_params_suppress_parse_errors( self, mock_create_model, mock_annotate ): """Test that suppress_parse_errors can be passed through resolver_params.""" mock_model = mock.MagicMock() mock_model.requires_fence_output = False mock_model.schema = None mock_create_model.return_value = mock_model mock_annotate.return_value = lx.data.AnnotatedDocument( text="test", extractions=[] ) mock_examples = [ lx.data.ExampleData( text="Example text", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="example", ), ], ) ] # This should not raise a TypeError about unknown key lx.extract( text_or_documents="test text", prompt_description="desc", examples=mock_examples, api_key="test_key", resolver_params={ "suppress_parse_errors": True, "enable_fuzzy_alignment": False, }, ) mock_annotate.assert_called() _, kwargs = mock_annotate.call_args self.assertIn("suppress_parse_errors", kwargs) self.assertTrue(kwargs.get("suppress_parse_errors")) self.assertFalse(kwargs.get("enable_fuzzy_alignment")) @mock.patch("langextract.extraction.resolver.Resolver") @mock.patch("langextract.extraction.factory.create_model") def test_extract_resolver_params_none_handling( self, mock_create_model, mock_resolver_class ): mock_model = mock.MagicMock() mock_model.infer.return_value = [ [types.ScoredOutput(output='{"extractions":[]}')] ] mock_model.requires_fence_output = False mock_create_model.return_value = mock_model mock_resolver = mock.MagicMock() mock_resolver_class.return_value = mock_resolver mock_examples = [ lx.data.ExampleData( text="Test text", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="test", ), ], ) 
] with mock.patch( "langextract.annotation.Annotator.annotate_text" ) as mock_annotate: mock_annotate.return_value = lx.data.AnnotatedDocument( text="test", extractions=[] ) lx.extract( text_or_documents="test text", prompt_description="desc", examples=mock_examples, api_key="test_key", resolver_params={ "enable_fuzzy_alignment": None, "fuzzy_alignment_threshold": 0.8, }, ) _, resolver_kwargs = mock_resolver_class.call_args self.assertNotIn("enable_fuzzy_alignment", resolver_kwargs) self.assertNotIn("fuzzy_alignment_threshold", resolver_kwargs) self.assertIn("format_handler", resolver_kwargs) _, annotate_kwargs = mock_annotate.call_args self.assertNotIn("enable_fuzzy_alignment", annotate_kwargs) self.assertEqual(annotate_kwargs["fuzzy_alignment_threshold"], 0.8) @mock.patch("langextract.extraction.factory.create_model") def test_extract_resolver_params_typo_error(self, mock_create_model): mock_model = mock.MagicMock() mock_model.requires_fence_output = False mock_create_model.return_value = mock_model mock_examples = [ lx.data.ExampleData( text="Test", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="test", ), ], ) ] with self.assertRaisesRegex(TypeError, "Unknown key in resolver_params"): lx.extract( text_or_documents="test", prompt_description="desc", examples=mock_examples, api_key="test_key", resolver_params={ "fuzzy_alignment_treshold": ( # Typo: treshold instead of threshold 0.5 ), }, ) @mock.patch("langextract.annotation.Annotator.annotate_documents") @mock.patch("langextract.extraction.factory.create_model") def test_extract_resolver_params_docs_path_passthrough( self, mock_create_model, mock_annotate_docs ): mock_model = mock.MagicMock() mock_model.infer.return_value = [ [types.ScoredOutput(output='{"extractions":[]}')] ] mock_model.requires_fence_output = False mock_create_model.return_value = mock_model mock_annotate_docs.return_value = [] docs = [lx.data.Document(text="doc1")] examples = [ lx.data.ExampleData( text="Example 
text", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="example", ), ], ) ] lx.extract( text_or_documents=docs, prompt_description="desc", examples=examples, api_key="k", resolver_params={ "enable_fuzzy_alignment": False, "fuzzy_alignment_threshold": 0.9, "accept_match_lesser": False, }, ) _, kwargs = mock_annotate_docs.call_args self.assertFalse(kwargs.get("enable_fuzzy_alignment")) self.assertEqual(kwargs.get("fuzzy_alignment_threshold"), 0.9) self.assertFalse(kwargs.get("accept_match_lesser")) @mock.patch("langextract.annotation.Annotator.annotate_text") @mock.patch("langextract.extraction.resolver.Resolver") @mock.patch("langextract.extraction.factory.create_model") def test_extract_resolver_params_none_threshold( self, mock_create_model, mock_resolver_cls, mock_annotate ): mock_model = mock.MagicMock() mock_model.infer.return_value = [ [types.ScoredOutput(output='{"extractions":[]}')] ] mock_model.requires_fence_output = False mock_create_model.return_value = mock_model mock_resolver_cls.return_value = mock.MagicMock() mock_annotate.return_value = lx.data.AnnotatedDocument( text="t", extractions=[] ) lx.extract( text_or_documents="t", prompt_description="d", examples=[ lx.data.ExampleData( text="example", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="ex", ), ], ) ], api_key="k", resolver_params={"fuzzy_alignment_threshold": None}, ) _, resolver_kwargs = mock_resolver_cls.call_args self.assertNotIn("fuzzy_alignment_threshold", resolver_kwargs) _, annotate_kwargs = mock_annotate.call_args self.assertNotIn("fuzzy_alignment_threshold", annotate_kwargs) @mock.patch.object( schemas.gemini.GeminiSchema, "from_examples", autospec=True ) @mock.patch("langextract.extraction.factory.create_model") def test_extract_custom_params_reach_inference( self, mock_create_model, mock_gemini_schema ): """Sanity check that custom parameters reach the inference layer.""" input_text = "Test text" mock_model = 
mock.MagicMock() mock_model.infer.return_value = [[ types.ScoredOutput( output='```json\n{"extractions": []}\n```', score=0.9, ) ]] mock_model.requires_fence_output = True mock_create_model.return_value = mock_model mock_gemini_schema.return_value = None mock_examples = [ lx.data.ExampleData( text="Example", extractions=[ lx.data.Extraction( extraction_class="test", extraction_text="example", ), ], ) ] lx.extract( text_or_documents=input_text, prompt_description="Test extraction", examples=mock_examples, api_key="test_key", max_workers=5, fence_output=True, use_schema_constraints=False, ) mock_model.infer.assert_called_once() _, kwargs = mock_model.infer.call_args self.assertEqual(kwargs.get("max_workers"), 5) @mock.patch("langextract.extraction.factory.create_model") def test_extract_with_custom_tokenizer(self, mock_create_model): """Test that a custom tokenizer can be passed to extract().""" input_text = "Test text" mock_model = mock.MagicMock() mock_model.infer.return_value = [[ types.ScoredOutput( output='```json\n{"extractions": []}\n```', score=0.9, ) ]] mock_model.requires_fence_output = True mock_create_model.return_value = mock_model def mock_tokenize(text): if text == "\u241F": # Delimiter return lx.tokenizer.TokenizedText( text=text, tokens=[ lx.tokenizer.Token( index=0, token_type=lx.tokenizer.TokenType.PUNCTUATION, char_interval=lx.tokenizer.CharInterval(0, 1), ) ], ) # Return dummy tokens for other text to avoid "empty tokens" error in aligner return lx.tokenizer.TokenizedText( text=text, tokens=[ lx.tokenizer.Token( index=0, token_type=lx.tokenizer.TokenType.WORD, char_interval=lx.tokenizer.CharInterval(0, len(text)), ) ], ) mock_tokenizer = mock.MagicMock() mock_tokenizer.tokenize.side_effect = mock_tokenize mock_examples = [ lx.data.ExampleData( text="Example", extractions=[ lx.data.Extraction( extraction_class="test", extraction_text="example", ), ], ) ] lx.extract( text_or_documents=input_text, prompt_description="Test extraction", 
examples=mock_examples, api_key="test_key", tokenizer=mock_tokenizer, ) mock_tokenizer.tokenize.assert_called_with(input_text) def test_data_module_exports_via_compatibility_shim(self): """Verify data module exports are accessible via lx.data.""" expected_exports = [ "AlignmentStatus", "CharInterval", "Extraction", "Document", "AnnotatedDocument", "ExampleData", "FormatType", ] for name in expected_exports: with self.subTest(export=name): self.assertTrue( hasattr(lx.data, name), f"lx.data.{name} not accessible via compatibility shim", ) def test_tokenizer_module_exports_via_compatibility_shim(self): """Verify tokenizer module exports are accessible via lx.tokenizer.""" expected_exports = [ "BaseTokenizerError", "InvalidTokenIntervalError", "SentenceRangeError", "CharInterval", "TokenInterval", "TokenType", "Token", "TokenizedText", "tokenize", "tokens_text", "find_sentence_range", ] for name in expected_exports: with self.subTest(export=name): self.assertTrue( hasattr(lx.tokenizer, name), f"lx.tokenizer.{name} not accessible via compatibility shim", ) @parameterized.named_parameters( dict( testcase_name="show_progress_true_debug_false", show_progress=True, debug=False, expected_progress_disabled=False, ), dict( testcase_name="show_progress_false_debug_false", show_progress=False, debug=False, expected_progress_disabled=True, ), dict( testcase_name="show_progress_true_debug_true", show_progress=True, debug=True, expected_progress_disabled=False, ), dict( testcase_name="show_progress_false_debug_true", show_progress=False, debug=True, expected_progress_disabled=True, ), ) @mock.patch("langextract.progress.create_extraction_progress_bar") @mock.patch("langextract.extraction.factory.create_model") def test_show_progress_controls_progress_bar( self, mock_create_model, mock_progress, show_progress, debug, expected_progress_disabled, ): """Test that show_progress parameter controls progress bar visibility.""" mock_model = mock.MagicMock() mock_model.infer.return_value = [ 
[ types.ScoredOutput( output='{"extractions": []}', score=0.9, ) ] ] mock_model.requires_fence_output = False mock_create_model.return_value = mock_model mock_progress.side_effect = lambda iterable, **kwargs: iter(iterable) mock_examples = [ lx.data.ExampleData( text="Example text", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="example", ), ], ) ] lx.extract( text_or_documents="test text", prompt_description="extract entities", examples=mock_examples, api_key="test_key", show_progress=show_progress, debug=debug, ) mock_progress.assert_called() call_args = mock_progress.call_args self.assertEqual( call_args.kwargs.get("disable", False), expected_progress_disabled ) @mock.patch("langextract.factory.create_model") def test_schema_validation_warning_issued(self, mock_create_model): """Test that schema validation warnings are properly issued.""" mock_model = mock.Mock(spec=base_model.BaseLanguageModel) mock_model.requires_fence_output = True mock_model.infer.return_value = [ [types.ScoredOutput(output='{"extractions": []}', score=1.0)] ] mock_schema = mock.Mock(spec=schema.BaseSchema) def validate_format_side_effect(format_handler): warnings.warn("Test validation warning", UserWarning, stacklevel=3) mock_schema.validate_format = mock.Mock( side_effect=validate_format_side_effect ) mock_model.schema = mock_schema mock_create_model.return_value = mock_model test_examples = [ lx.data.ExampleData( text="test", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="test", ), ], ) ] with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = lx.extract( text_or_documents="Sample text", prompt_description="Extract", examples=test_examples, model_id="test-model", api_key="key", use_schema_constraints=True, ) warning_messages = [str(warning.message) for warning in w] self.assertIn( "Test validation warning", " ".join(warning_messages), "Schema validation warning should be issued", ) 
self.assertIsNotNone(result) def test_gemini_schema_deprecation_warning(self): """Test that passing gemini_schema triggers deprecation warning.""" mock_model = mock.MagicMock(spec=base_model.BaseLanguageModel) mock_model.infer.return_value = iter( [[mock.Mock(output='{"extractions": []}')]] ) mock_model.requires_fence_output = True mock_model.schema = None self.enter_context( mock.patch( "langextract.factory.create_model", return_value=mock_model, ) ) self.enter_context( mock.patch( "langextract.annotation.Annotator.annotate_text", return_value=data.AnnotatedDocument(text="test", extractions=[]), ) ) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") _ = lx.extract( text_or_documents="test", prompt_description="Extract conditions", examples=[ lx.data.ExampleData( text="test", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="test", ), ], ) ], model_id="gemini-2.5-flash", api_key="test_key", language_model_params={"gemini_schema": "deprecated"}, ) self.assertTrue( any( issubclass(warning.category, FutureWarning) and "gemini_schema" in str(warning.message) for warning in w ), "Expected deprecation warning for gemini_schema", ) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/progress_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for langextract.progress module.""" import unittest from unittest import mock import tqdm from langextract import progress class ProgressTest(unittest.TestCase): def test_download_progress_bar(self): """Test download progress bar creation.""" pbar = progress.create_download_progress_bar( 1024, "https://example.com/file.txt" ) self.assertIsInstance(pbar, tqdm.tqdm) self.assertEqual(pbar.total, 1024) self.assertIn("Downloading", pbar.desc) def test_extraction_progress_bar(self): """Test extraction progress bar creation.""" pbar = progress.create_extraction_progress_bar( range(10), "gemini-2.0-flash" ) self.assertIsInstance(pbar, tqdm.tqdm) self.assertIn("LangExtract", pbar.desc) self.assertIn("gemini-2.0-flash", pbar.desc) def test_save_load_progress_bars(self): """Test save and load progress bar creation.""" save_pbar = progress.create_save_progress_bar("/path/file.json") load_pbar = progress.create_load_progress_bar("/path/file.json") self.assertIsInstance(save_pbar, tqdm.tqdm) self.assertIsInstance(load_pbar, tqdm.tqdm) self.assertIn("Saving", save_pbar.desc) self.assertIn("Loading", load_pbar.desc) def test_model_info_extraction(self): """Test extracting model info from objects.""" mock_model = mock.MagicMock() mock_model.model_id = "gemini-1.5-pro" self.assertEqual(progress.get_model_info(mock_model), "gemini-1.5-pro") mock_model = mock.MagicMock() del mock_model.model_id del mock_model.model_url self.assertIsNone(progress.get_model_info(mock_model)) def test_formatting_functions(self): """Test message formatting functions.""" stats = progress.format_extraction_stats(1500, 5000) self.assertIn("1,500", stats) self.assertIn("5,000", stats) desc = progress.format_extraction_progress("gemini-2.0-flash") self.assertIn("LangExtract", desc) self.assertIn("gemini-2.0-flash", desc) desc_no_model = progress.format_extraction_progress(None) self.assertIn("Processing", desc_no_model) if __name__ == "__main__": unittest.main() 
================================================ FILE: tests/prompt_validation_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for prompt validation module.""" from absl.testing import absltest from absl.testing import parameterized from langextract import extraction from langextract import prompt_validation from langextract.core import data class PromptAlignmentValidationTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name="exact_alignment", text="Patient takes lisinopril.", extraction_class="Medication", extraction_text="lisinopril", expected_issues=0, expected_has_failed=False, expected_has_non_exact=False, expected_alignment_status=None, ), dict( testcase_name="fuzzy_match_lesser", text="Type 2 diabetes.", extraction_class="Diagnosis", extraction_text="type-2 diabetes", expected_issues=1, expected_has_failed=False, expected_has_non_exact=True, expected_alignment_status=data.AlignmentStatus.MATCH_LESSER, ), dict( testcase_name="extraction_not_found", text="No medications mentioned in this text.", extraction_class="Medication", extraction_text="lisinopril", expected_issues=1, expected_has_failed=True, expected_has_non_exact=False, expected_alignment_status=None, ), ) def test_alignment_detection( self, text, extraction_class, extraction_text, expected_issues, expected_has_failed, expected_has_non_exact, expected_alignment_status, ): """Test that different 
alignment types are correctly detected.""" example = data.ExampleData( text=text, extractions=[ data.Extraction( extraction_class=extraction_class, extraction_text=extraction_text, attributes={}, ) ], ) report = prompt_validation.validate_prompt_alignment([example]) self.assertLen(report.issues, expected_issues) self.assertEqual(report.has_failed, expected_has_failed) self.assertEqual(report.has_non_exact, expected_has_non_exact) if expected_issues > 0: issue = report.issues[0] self.assertEqual(issue.alignment_status, expected_alignment_status) self.assertEqual(issue.extraction_class, extraction_class) if expected_has_failed: self.assertIsNone(issue.alignment_status) elif expected_has_non_exact: self.assertIsNotNone(issue.alignment_status) @parameterized.named_parameters( dict( testcase_name="one_fails", text="Patient takes lisinopril and has diabetes mellitus.", extractions=[ ("Medication", "lisinopril"), # PASSES - found exactly ("Diagnosis", "diabetes"), # PASSES - found exactly ("Medication", "metformin"), # FAILS - not in text ], expected_issues=1, expected_has_failed=True, expected_has_non_exact=False, expected_failed_text="metformin", ), dict( testcase_name="all_pass", text="Patient takes lisinopril and aspirin for diabetes management.", extractions=[ ("Medication", "lisinopril"), ("Medication", "aspirin"), ("Diagnosis", "diabetes"), ], expected_issues=0, expected_has_failed=False, expected_has_non_exact=False, expected_failed_text=None, ), ) def test_multiple_extractions_per_example( self, text, extractions, expected_issues, expected_has_failed, expected_has_non_exact, expected_failed_text, ): """Test validation with multiple extractions in a single example.""" example = data.ExampleData( text=text, extractions=[ data.Extraction( extraction_class=extraction_class, extraction_text=extraction_text, attributes={}, ) for extraction_class, extraction_text in extractions ], ) report = prompt_validation.validate_prompt_alignment([example]) 
self.assertLen(report.issues, expected_issues) self.assertEqual(report.has_failed, expected_has_failed) self.assertEqual(report.has_non_exact, expected_has_non_exact) if expected_failed_text: issue = report.issues[0] self.assertIsNone(issue.alignment_status) self.assertEqual(issue.extraction_text_preview, expected_failed_text) @parameterized.named_parameters( dict( testcase_name="warning_mode_with_failed", text="Patient has no known allergies.", extraction_text="penicillin", validation_level=prompt_validation.PromptValidationLevel.WARNING, strict_non_exact=False, ), dict( testcase_name="off_mode_with_failed", text="Patient history incomplete.", extraction_text="aspirin", validation_level=prompt_validation.PromptValidationLevel.OFF, strict_non_exact=False, ), ) def test_validation_levels_that_dont_raise( self, text, extraction_text, validation_level, strict_non_exact ): """Test that WARNING and OFF modes don't raise exceptions.""" example = data.ExampleData( text=text, extractions=[ data.Extraction( extraction_class="Medication", extraction_text=extraction_text, attributes={}, ) ], ) report = prompt_validation.validate_prompt_alignment([example]) # This should not raise an exception in WARNING or OFF modes prompt_validation.handle_alignment_report( report, validation_level, strict_non_exact=strict_non_exact ) @parameterized.named_parameters( dict( testcase_name="error_mode_failed_alignment", text="Patient has no known allergies.", extraction_class="Medication", extraction_text="penicillin", strict_non_exact=False, error_pattern=r"1 extraction\(s\).*could not be aligned", ), dict( testcase_name="error_mode_strict_fuzzy_match", text="Type 2 diabetes.", extraction_class="Diagnosis", extraction_text="type-2 diabetes", strict_non_exact=True, error_pattern=r"strict mode.*1 non-exact", ), ) def test_error_mode_raises_appropriately( self, text, extraction_class, extraction_text, strict_non_exact, error_pattern, ): """Test that ERROR mode raises with appropriate messages.""" 
example = data.ExampleData( text=text, extractions=[ data.Extraction( extraction_class=extraction_class, extraction_text=extraction_text, attributes={}, ) ], ) report = prompt_validation.validate_prompt_alignment([example]) with self.assertRaisesRegex( prompt_validation.PromptAlignmentError, error_pattern ): prompt_validation.handle_alignment_report( report, prompt_validation.PromptValidationLevel.ERROR, strict_non_exact=strict_non_exact, ) def test_empty_examples_produces_empty_report(self): report = prompt_validation.validate_prompt_alignment([]) self.assertEmpty(report.issues) self.assertFalse(report.has_failed) self.assertFalse(report.has_non_exact) def test_multiple_examples_preserve_indices(self): examples = [ data.ExampleData( # Example 0: FAILS - "metformin" not in text text="First patient record.", extractions=[ data.Extraction( extraction_class="Medication", extraction_text="metformin", attributes={}, ) ], ), data.ExampleData( # Example 1: PASSES - "aspirin" found exactly text="Patient takes aspirin daily.", extractions=[ data.Extraction( extraction_class="Medication", extraction_text="aspirin", attributes={}, ) ], ), data.ExampleData( # Example 2: NON-EXACT - "type-2" fuzzy matches "Type 2" text="Type 2 diabetes mellitus.", extractions=[ data.Extraction( extraction_class="Diagnosis", extraction_text="type-2 diabetes", attributes={}, ) ], ), ] report = prompt_validation.validate_prompt_alignment(examples) # Expect 2 issues: example 0 (failed) and example 2 (non-exact) self.assertLen(report.issues, 2) self.assertTrue(report.has_failed) self.assertTrue(report.has_non_exact) issue_by_index = {issue.example_index: issue for issue in report.issues} # Example 0: Failed alignment (metformin not found) self.assertIn(0, issue_by_index) self.assertIsNone(issue_by_index[0].alignment_status) # Example 1: No issue (aspirin found exactly) self.assertNotIn(1, issue_by_index) # Example 2: Non-exact match (type-2 vs Type 2) self.assertIn(2, issue_by_index) 
self.assertIsNotNone(issue_by_index[2].alignment_status) def test_validation_does_not_mutate_input(self): example = data.ExampleData( text="Patient takes lisinopril 10mg daily.", extractions=[ data.Extraction( extraction_class="Medication", extraction_text="lisinopril", attributes={}, ) ], ) original_extraction = example.extractions[0] self.assertIsNone(getattr(original_extraction, "token_interval", None)) self.assertIsNone(getattr(original_extraction, "char_interval", None)) self.assertIsNone(getattr(original_extraction, "alignment_status", None)) _ = prompt_validation.validate_prompt_alignment([example]) self.assertIsNone(getattr(original_extraction, "token_interval", None)) self.assertIsNone(getattr(original_extraction, "char_interval", None)) self.assertIsNone(getattr(original_extraction, "alignment_status", None)) @parameterized.named_parameters( dict( testcase_name="fuzzy_disabled_rejects_non_exact", text="Patient has type 2 diabetes.", extraction_class="Diagnosis", extraction_text="Type-2 Diabetes", enable_fuzzy=False, accept_lesser=False, fuzzy_threshold=0.75, expected_has_failed=True, expected_has_non_exact=False, ), dict( testcase_name="fuzzy_enabled_accepts_close_match", text="Patient has type 2 diabetes.", extraction_class="Diagnosis", extraction_text="Type-2 Diabetes", enable_fuzzy=True, accept_lesser=False, fuzzy_threshold=0.75, expected_has_failed=False, expected_has_non_exact=True, ), ) def test_alignment_policies( self, text, extraction_class, extraction_text, enable_fuzzy, accept_lesser, fuzzy_threshold, expected_has_failed, expected_has_non_exact, ): """Test different alignment policy configurations.""" example = data.ExampleData( text=text, extractions=[ data.Extraction( extraction_class=extraction_class, extraction_text=extraction_text, attributes={}, ) ], ) if not enable_fuzzy: default_report = prompt_validation.validate_prompt_alignment([example]) self.assertFalse(default_report.has_failed) self.assertTrue(default_report.has_non_exact) policy 
= prompt_validation.AlignmentPolicy( enable_fuzzy_alignment=enable_fuzzy, accept_match_lesser=accept_lesser, fuzzy_alignment_threshold=fuzzy_threshold, ) report = prompt_validation.validate_prompt_alignment( [example], policy=policy ) self.assertEqual(report.has_failed, expected_has_failed) self.assertEqual(report.has_non_exact, expected_has_non_exact) class ExtractIntegrationTest(absltest.TestCase): """Minimal integration test for extract() entry point validation.""" def test_extract_validates_in_error_mode(self): """Verify extract() runs validation when configured.""" examples = [ data.ExampleData( text="Patient takes aspirin.", extractions=[ data.Extraction( extraction_class="Medication", extraction_text="ibuprofen", attributes={}, ) ], ) ] with self.assertRaisesRegex( prompt_validation.PromptAlignmentError, r"1 extraction\(s\).*could not be aligned", ): extraction.extract( text_or_documents="Test document", prompt_description="Extract medications", examples=examples, prompt_validation_level=prompt_validation.PromptValidationLevel.ERROR, model_id="fake-model", ) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/prompting_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import textwrap from absl.testing import absltest from absl.testing import parameterized from langextract import prompting from langextract.core import data from langextract.core import format_handler as fh class QAPromptGeneratorTest(parameterized.TestCase): """Tests QAPromptGenerator full-prompt rendering and per-example formatting.""" def test_generate_prompt(self): """Renders a YAML-format prompt from a structured template and compares it to the exact expected text.""" prompt_template_structured = prompting.PromptTemplateStructured( description=( "You are an assistant specialized in extracting key extractions" " from text.\nIdentify and extract important extractions such as" " people, places,\norganizations, dates, and medical conditions" " mentioned in the text.\n**Please ensure that the extractions are" " extracted in the same order as they\nappear in the source" " text.**\nProvide the extracted extractions in a structured YAML" " format." ), examples=[ data.ExampleData( text=( "The patient was diagnosed with hypertension and diabetes." ), extractions=[ data.Extraction( extraction_text="hypertension", extraction_class="medical_condition", attributes={ "chronicity": "chronic", "system": "cardiovascular", }, ), data.Extraction( extraction_text="diabetes", extraction_class="medical_condition", attributes={ "chronicity": "chronic", "system": "endocrine", }, ), ], ) ], ) format_handler = fh.FormatHandler( format_type=data.FormatType.YAML, use_wrapper=True, wrapper_key="extractions", use_fences=True, ) prompt_generator = prompting.QAPromptGenerator( template=prompt_template_structured, format_handler=format_handler, examples_heading="", question_prefix="", answer_prefix="", ) actual_prompt_text = prompt_generator.render( "The patient reports chest pain and shortness of breath." ) expected_prompt_text = textwrap.dedent(f"""\ You are an assistant specialized in extracting key extractions from text. 
Identify and extract important extractions such as people, places, organizations, dates, and medical conditions mentioned in the text.
**Please ensure that the extractions are extracted in the same order as they appear in the source text.** Provide the extracted extractions in a structured YAML format. The patient was diagnosed with hypertension and diabetes. ```yaml {data.EXTRACTIONS_KEY}: - medical_condition: hypertension medical_condition_attributes: chronicity: chronic system: cardiovascular - medical_condition: diabetes medical_condition_attributes: chronicity: chronic system: endocrine ``` The patient reports chest pain and shortness of breath. """) self.assertEqual(expected_prompt_text, actual_prompt_text) @parameterized.named_parameters( dict( testcase_name="json_basic_format", format_type=data.FormatType.JSON, example_text="Patient has diabetes and is prescribed insulin.", example_extractions=[ data.Extraction( extraction_text="diabetes", extraction_class="medical_condition", attributes={"chronicity": "chronic"}, ), data.Extraction( extraction_text="insulin", extraction_class="medication", attributes={"prescribed": "prescribed"}, ), ], expected_formatted_example=textwrap.dedent(f"""\ Patient has diabetes and is prescribed insulin. ```json {{ "{data.EXTRACTIONS_KEY}": [ {{ "medical_condition": "diabetes", "medical_condition_attributes": {{ "chronicity": "chronic" }} }}, {{ "medication": "insulin", "medication_attributes": {{ "prescribed": "prescribed" }} }} ] }} ``` """), ), dict( testcase_name="yaml_basic_format", format_type=data.FormatType.YAML, example_text="Patient has diabetes and is prescribed insulin.", example_extractions=[ data.Extraction( extraction_text="diabetes", extraction_class="medical_condition", attributes={"chronicity": "chronic"}, ), data.Extraction( extraction_text="insulin", extraction_class="medication", attributes={"prescribed": "prescribed"}, ), ], expected_formatted_example=textwrap.dedent(f"""\ Patient has diabetes and is prescribed insulin.
```yaml {data.EXTRACTIONS_KEY}: - medical_condition: diabetes medical_condition_attributes: chronicity: chronic - medication: insulin medication_attributes: prescribed: prescribed ``` """), ), dict( testcase_name="custom_attribute_suffix", format_type=data.FormatType.YAML, example_text="Patient has a fever.", example_extractions=[ data.Extraction( extraction_text="fever", extraction_class="symptom", attributes={"severity": "mild"}, ), ], attribute_suffix="_props", expected_formatted_example=textwrap.dedent(f"""\ Patient has a fever. ```yaml {data.EXTRACTIONS_KEY}: - symptom: fever symptom_props: severity: mild ``` """), ), dict( testcase_name="yaml_empty_extractions", format_type=data.FormatType.YAML, example_text="Text with no extractions.", example_extractions=[], expected_formatted_example=textwrap.dedent(f"""\ Text with no extractions. ```yaml {data.EXTRACTIONS_KEY}: [] ``` """), ), dict( testcase_name="json_empty_extractions", format_type=data.FormatType.JSON, example_text="Text with no extractions.", example_extractions=[], expected_formatted_example=textwrap.dedent(f"""\ Text with no extractions. ```json {{ "{data.EXTRACTIONS_KEY}": [] }} ``` """), ), dict( testcase_name="yaml_empty_attributes", format_type=data.FormatType.YAML, example_text="Patient is resting comfortably.", example_extractions=[ data.Extraction( extraction_text="Patient", extraction_class="person", attributes={}, ), ], expected_formatted_example=textwrap.dedent(f"""\ Patient is resting comfortably. ```yaml {data.EXTRACTIONS_KEY}: - person: Patient person_attributes: {{}} ``` """), ), dict( testcase_name="json_empty_attributes", format_type=data.FormatType.JSON, example_text="Patient is resting comfortably.", example_extractions=[ data.Extraction( extraction_text="Patient", extraction_class="person", attributes={}, ), ], expected_formatted_example=textwrap.dedent(f"""\ Patient is resting comfortably.
```json {{ "{data.EXTRACTIONS_KEY}": [ {{ "person": "Patient", "person_attributes": {{}} }} ] }} ``` """), ), dict( testcase_name="yaml_same_extraction_class_multiple_times", format_type=data.FormatType.YAML, example_text=( "Patient has multiple medications: aspirin and lisinopril." ), example_extractions=[ data.Extraction( extraction_text="aspirin", extraction_class="medication", attributes={"dosage": "81mg"}, ), data.Extraction( extraction_text="lisinopril", extraction_class="medication", attributes={"dosage": "10mg"}, ), ], expected_formatted_example=textwrap.dedent(f"""\ Patient has multiple medications: aspirin and lisinopril. ```yaml {data.EXTRACTIONS_KEY}: - medication: aspirin medication_attributes: dosage: 81mg - medication: lisinopril medication_attributes: dosage: 10mg ``` """), ), dict( testcase_name="json_simplified_no_extractions_key", format_type=data.FormatType.JSON, example_text="Patient has diabetes and is prescribed insulin.", example_extractions=[ data.Extraction( extraction_text="diabetes", extraction_class="medical_condition", attributes={"chronicity": "chronic"}, ), data.Extraction( extraction_text="insulin", extraction_class="medication", attributes={"prescribed": "prescribed"}, ), ], require_extractions_key=False, expected_formatted_example=textwrap.dedent("""\ Patient has diabetes and is prescribed insulin. ```json [ { "medical_condition": "diabetes", "medical_condition_attributes": { "chronicity": "chronic" } }, { "medication": "insulin", "medication_attributes": { "prescribed": "prescribed" } } ] ``` """), ), dict( testcase_name="yaml_simplified_no_extractions_key", format_type=data.FormatType.YAML, example_text="Patient has a fever.", example_extractions=[ data.Extraction( extraction_text="fever", extraction_class="symptom", attributes={"severity": "mild"}, ), ], require_extractions_key=False, expected_formatted_example=textwrap.dedent("""\ Patient has a fever.
```yaml - symptom: fever symptom_attributes: severity: mild ``` """), ), ) def test_format_example( self, format_type, example_text, example_extractions, expected_formatted_example, attribute_suffix="_attributes", require_extractions_key=True, ): """Tests formatting of examples in different formats and scenarios.""" example_data = data.ExampleData( text=example_text, extractions=example_extractions, ) structured_template = prompting.PromptTemplateStructured( description="Extract information from the text.", examples=[example_data], ) format_handler = fh.FormatHandler( format_type=format_type, use_wrapper=require_extractions_key, wrapper_key="extractions" if require_extractions_key else None, use_fences=True, attribute_suffix=attribute_suffix, ) prompt_generator = prompting.QAPromptGenerator( template=structured_template, format_handler=format_handler, question_prefix="", answer_prefix="", ) actual_formatted_example = prompt_generator.format_example_as_text( example_data ) self.assertEqual(expected_formatted_example, actual_formatted_example) class PromptBuilderTest(absltest.TestCase): """Tests for PromptBuilder base class.""" def _create_generator(self): """Creates a simple QAPromptGenerator for testing.""" template = prompting.PromptTemplateStructured( description="Extract entities.", examples=[ data.ExampleData( text="Sample text.", extractions=[ data.Extraction( extraction_text="Sample", extraction_class="entity", ) ], ) ], ) format_handler = fh.FormatHandler( format_type=data.FormatType.YAML, use_wrapper=True, wrapper_key="extractions", use_fences=True, ) return prompting.QAPromptGenerator( template=template, format_handler=format_handler, ) def test_build_prompt_renders_chunk_text(self): """Verifies build_prompt includes chunk text in the rendered prompt.""" generator = self._create_generator() builder = prompting.PromptBuilder(generator) prompt = builder.build_prompt( chunk_text="Test input text.", document_id="doc1", ) self.assertIn("Test input text.",
prompt) self.assertIn("Extract entities.", prompt) def test_build_prompt_includes_additional_context(self): """Verifies build_prompt passes additional_context to renderer.""" generator = self._create_generator() builder = prompting.PromptBuilder(generator) prompt = builder.build_prompt( chunk_text="Test input.", document_id="doc1", additional_context="Important context here.", ) self.assertIn("Important context here.", prompt) class ContextAwarePromptBuilderTest(absltest.TestCase): """Tests for ContextAwarePromptBuilder.""" def _create_generator(self): """Creates a simple QAPromptGenerator for testing.""" template = prompting.PromptTemplateStructured( description="Extract entities.", examples=[ data.ExampleData( text="Sample text.", extractions=[ data.Extraction( extraction_text="Sample", extraction_class="entity", ) ], ) ], ) format_handler = fh.FormatHandler( format_type=data.FormatType.YAML, use_wrapper=True, wrapper_key="extractions", use_fences=True, ) return prompting.QAPromptGenerator( template=template, format_handler=format_handler, ) def test_context_window_chars_property(self): """Verifies the context_window_chars property returns configured value.""" generator = self._create_generator() builder_none = prompting.ContextAwarePromptBuilder(generator) self.assertIsNone(builder_none.context_window_chars) builder_with_value = prompting.ContextAwarePromptBuilder( generator, context_window_chars=100 ) self.assertEqual(100, builder_with_value.context_window_chars) def test_first_chunk_has_no_previous_context(self): """Verifies the first chunk does not include previous context.""" generator = self._create_generator() builder = prompting.ContextAwarePromptBuilder( generator, context_window_chars=50 ) context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX prompt = builder.build_prompt( chunk_text="First chunk text.", document_id="doc1", ) self.assertNotIn(context_prefix, prompt) self.assertIn("First chunk text.", prompt) def
test_second_chunk_includes_previous_context(self): """Verifies the second chunk includes text from the first chunk.""" generator = self._create_generator() builder = prompting.ContextAwarePromptBuilder( generator, context_window_chars=20 ) context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX builder.build_prompt(chunk_text="First chunk ending.", document_id="doc1") second_prompt = builder.build_prompt( chunk_text="Second chunk text.", document_id="doc1", ) self.assertIn(context_prefix, second_prompt) self.assertIn("chunk ending.", second_prompt) def test_context_disabled_when_none(self): """Verifies no context is added when context_window_chars is None.""" generator = self._create_generator() builder = prompting.ContextAwarePromptBuilder( generator, context_window_chars=None ) context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX builder.build_prompt(chunk_text="First chunk.", document_id="doc1") second_prompt = builder.build_prompt( chunk_text="Second chunk.", document_id="doc1", ) self.assertNotIn(context_prefix, second_prompt) def test_context_isolated_per_document(self): """Verifies context tracking is isolated per document_id.""" generator = self._create_generator() builder = prompting.ContextAwarePromptBuilder( generator, context_window_chars=50 ) builder.build_prompt(chunk_text="Doc A chunk one.", document_id="docA") builder.build_prompt(chunk_text="Doc B chunk one.", document_id="docB") prompt_a2 = builder.build_prompt( chunk_text="Doc A chunk two.", document_id="docA", ) prompt_b2 = builder.build_prompt( chunk_text="Doc B chunk two.", document_id="docB", ) self.assertIn("Doc A chunk one", prompt_a2) self.assertNotIn("Doc B", prompt_a2) self.assertIn("Doc B chunk one", prompt_b2) self.assertNotIn("Doc A", prompt_b2) def test_combines_previous_context_with_additional_context(self): """Verifies both previous chunk context and additional_context are included.""" generator = self._create_generator() builder =
prompting.ContextAwarePromptBuilder( generator, context_window_chars=30 ) context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX builder.build_prompt(chunk_text="Previous chunk text.", document_id="doc1") prompt = builder.build_prompt( chunk_text="Current chunk.", document_id="doc1", additional_context="Extra info here.", ) self.assertIn(context_prefix, prompt) self.assertIn("Previous chunk text.", prompt) self.assertIn("Extra info here.", prompt) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/provider_plugin_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for provider plugin system. Note: This file contains test helper classes that intentionally have few public methods. The too-few-public-methods warnings are expected. """ from importlib import metadata import os from pathlib import Path import subprocess import sys import tempfile import textwrap import types as builtin_types from unittest import mock import uuid from absl.testing import absltest import pytest import langextract as lx from langextract.core import base_model from langextract.core import types def _create_mock_entry_points(entry_points_list): """Create a mock EntryPoints object for testing. Args: entry_points_list: List of entry points to return for langextract.providers. Returns: A mock object that behaves like importlib.metadata.EntryPoints.
""" class MockEntryPoints: # pylint: disable=too-few-public-methods """Mock EntryPoints that implements select() method.""" def select(self, group=None): """Returns entry_points_list for group langextract.providers, else an empty list.""" if group == "langextract.providers": return entry_points_list return [] return MockEntryPoints() class PluginSmokeTest(absltest.TestCase): """Basic smoke tests for plugin loading functionality.""" def setUp(self): """Starts from a cleared registry and reset provider flags; cleanups repeat both.""" super().setUp() lx.providers.registry.clear() # Always reset both flags to ensure clean state lx.providers._reset_for_testing() # Register cleanup self.addCleanup(lx.providers.registry.clear) self.addCleanup(lx.providers._reset_for_testing) def test_plugin_discovery_and_usage(self): """Test plugin discovery via entry points. Entry points can return a class or module. Registration happens via the @register decorator in both cases. """ def _ep_load(): """Stands in for EntryPoint.load(): registers PluginProvider and returns it.""" @lx.providers.registry.register(r"^plugin-model") class PluginProvider(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def __init__(self, model_id=None, **kwargs): super().__init__() self.model_id = model_id def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="ok")]] return PluginProvider ep = builtin_types.SimpleNamespace( name="plugin_provider", group="langextract.providers", value="my_pkg:PluginProvider", load=_ep_load, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([ep]) ): lx.providers.load_plugins_once() resolved_cls = lx.providers.registry.resolve("plugin-model-123") self.assertEqual( resolved_cls.__name__, "PluginProvider", "Provider should be resolvable after plugin load", ) cfg = lx.factory.ModelConfig(model_id="plugin-model-123") model = lx.factory.create_model(cfg) out = model.infer(["hi"])[0][0].output self.assertEqual(out, "ok", "Provider 
should return expected output") def test_plugin_disabled_by_env_var(self): """Test that LANGEXTRACT_DISABLE_PLUGINS=1 prevents plugin loading.""" with mock.patch.dict("os.environ",
{"LANGEXTRACT_DISABLE_PLUGINS": "1"}): with mock.patch.object(metadata, "entry_points") as mock_ep: lx.providers.load_plugins_once() mock_ep.assert_not_called() def test_handles_import_errors_gracefully(self): """Test that import errors during plugin loading don't crash.""" def _bad_load(): """Simulates an entry point whose import fails.""" raise ImportError("Plugin not found") bad_ep = builtin_types.SimpleNamespace( name="bad_plugin", group="langextract.providers", value="bad_pkg:BadProvider", load=_bad_load, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([bad_ep]), ): lx.providers.load_plugins_once() providers = lx.providers.registry.list_providers() self.assertIsInstance( providers, list, "Registry should remain functional after import error", ) # Built-in providers should still be loaded even if plugin fails self.assertGreater( len(providers), 0, "Built-in providers should still be available after plugin failure", ) def test_load_plugins_once_is_idempotent(self): """Test that load_plugins_once only discovers once.""" def _ep_load(): """Registers and returns a minimal provider class, as EntryPoint.load() would.""" @lx.providers.registry.register(r"^plugin-model") class Plugin(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def infer(self, *a, **k): return [[types.ScoredOutput(score=1.0, output="ok")]] return Plugin ep = builtin_types.SimpleNamespace( name="plugin_provider", group="langextract.providers", value="pkg:Plugin", load=_ep_load, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([ep]) ) as m: lx.providers.load_plugins_once() lx.providers.load_plugins_once() # should be a no-op self.assertEqual(m.call_count, 1, "Discovery should happen only once") def test_non_subclass_entry_point_does_not_crash(self): """Test that non-BaseLanguageModel classes don't crash the system.""" class NotAProvider: # pylint: disable=too-few-public-methods """Dummy class to test 
non-provider handling.""" bad_ep = builtin_types.SimpleNamespace( name="bad", group="langextract.providers",
value="bad:NotAProvider", load=lambda: NotAProvider, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([bad_ep]), ): lx.providers.load_plugins_once() # The system should remain functional even if a bad provider is loaded # Trying to use it would fail, but discovery shouldn't crash providers = lx.providers.registry.list_providers() self.assertIsInstance( providers, list, "Registry should remain functional with bad provider", ) with self.assertRaisesRegex( lx.exceptions.InferenceConfigError, "No provider registered" ): lx.providers.registry.resolve("bad") def test_plugin_priority_override_core_provider(self): """Plugin with higher priority should override core provider on conflicts.""" lx.providers.registry.clear() lx.providers._plugins_loaded = False def _ep_load(): @lx.providers.registry.register(r"^gemini", priority=50) class OverrideGemini(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="override")]] return OverrideGemini ep = builtin_types.SimpleNamespace( name="override_gemini", group="langextract.providers", value="pkg:OverrideGemini", load=_ep_load, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([ep]) ): lx.providers.load_plugins_once() # Core gemini registers with priority 10 in providers.gemini # Our plugin registered with priority 50; it should win.
resolved = lx.providers.registry.resolve("gemini-2.5-flash") self.assertEqual(resolved.__name__, "OverrideGemini") def test_resolve_provider_for_plugin(self): """resolve_provider should find plugin by class name and name-insensitive.""" lx.providers.registry.clear() lx.providers._plugins_loaded = False def _ep_load(): @lx.providers.registry.register(r"^plugin-resolve") class ResolveMePlease(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="ok")]] return ResolveMePlease ep = builtin_types.SimpleNamespace( name="resolver_plugin", group="langextract.providers", value="pkg:ResolveMePlease", load=_ep_load, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([ep]) ): lx.providers.load_plugins_once() cls_by_exact = lx.providers.registry.resolve_provider("ResolveMePlease") self.assertEqual(cls_by_exact.__name__, "ResolveMePlease") cls_by_partial = lx.providers.registry.resolve_provider("resolveme") self.assertEqual(cls_by_partial.__name__, "ResolveMePlease") def test_plugin_with_custom_schema(self): """Test that a plugin can provide its own schema implementation.""" class TestPluginSchema(lx.schema.BaseSchema): """Test schema implementation.""" def __init__(self, config): self._config = config @classmethod def from_examples(cls, examples_data, attribute_suffix="_attributes"): return cls({"generated": True, "count": len(examples_data)}) def to_provider_config(self): return {"custom_schema": self._config} @property def requires_raw_output(self): return True def _ep_load(): @lx.providers.registry.register(r"^custom-schema-test") class SchemaTestProvider(base_model.BaseLanguageModel): def __init__(self, model_id=None, **kwargs): super().__init__() self.model_id = model_id self.schema_config = kwargs.get("custom_schema") @classmethod def get_schema_class(cls): return TestPluginSchema def infer(self, batch_prompts, **kwargs): 
output = ( f"Schema={self.schema_config}" if self.schema_config else "No schema" ) return [[types.ScoredOutput(score=1.0, output=output)]] return SchemaTestProvider ep = builtin_types.SimpleNamespace( name="schema_test", group="langextract.providers", value="test:SchemaTestProvider", load=_ep_load, ) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([ep]) ): lx.providers.load_plugins_once() provider_cls = lx.providers.registry.resolve("custom-schema-test-v1") self.assertEqual( provider_cls.get_schema_class().__name__, "TestPluginSchema", "Plugin should provide custom schema class", ) examples = [ lx.data.ExampleData( text="Test", extractions=[ lx.data.Extraction( extraction_class="test", extraction_text="test text", ) ], ) ] config = lx.factory.ModelConfig(model_id="custom-schema-test-v1") model = lx.factory._create_model_with_schema( config=config, examples=examples, use_schema_constraints=True, fence_output=None, ) self.assertIsNotNone( model.schema_config, "Model should have schema config applied", ) self.assertTrue( model.schema_config["generated"], "Schema should be generated from examples", ) self.assertFalse( model.requires_fence_output, "Schema outputs raw JSON, no fences needed", ) class PluginE2ETest(absltest.TestCase): """End-to-end test with actual pip installation. This test is expensive and only runs when explicitly requested via tox -e plugin-e2e or in CI when provider files change. 
""" def test_plugin_with_schema_e2e(self): """Test that a plugin with custom schema works end-to-end with extract().""" class TestPluginSchema(lx.schema.BaseSchema): """Test schema implementation.""" def __init__(self, config): self._config = config @classmethod def from_examples(cls, examples_data, attribute_suffix="_attributes"): return cls({"generated": True, "count": len(examples_data)}) def to_provider_config(self): return {"custom_schema": self._config} @property def requires_raw_output(self): return True def _ep_load(): @lx.providers.registry.register(r"^e2e-schema-test") class SchemaE2EProvider(base_model.BaseLanguageModel): def __init__(self, model_id=None, **kwargs): super().__init__() self.model_id = model_id self.schema_config = kwargs.get("custom_schema") @classmethod def get_schema_class(cls): return TestPluginSchema def infer(self, batch_prompts, **kwargs): # Return a mock extraction that includes schema info if self.schema_config: output = ( '{"extractions": [{"entity": "test", ' '"entity_attributes": {"schema": "applied"}}]}' ) else: output = '{"extractions": []}' return [[types.ScoredOutput(score=1.0, output=output)]] return SchemaE2EProvider ep = builtin_types.SimpleNamespace( name="schema_e2e", group="langextract.providers", value="test:SchemaE2EProvider", load=_ep_load, ) # Clear and set up registry lx.providers.registry.clear() lx.providers._plugins_loaded = False self.addCleanup(lx.providers.registry.clear) self.addCleanup(setattr, lx.providers, "_plugins_loaded", False) with mock.patch.object( metadata, "entry_points", return_value=_create_mock_entry_points([ep]) ): lx.providers.load_plugins_once() # Test with extract() using schema constraints examples = [ lx.data.ExampleData( text="Find entities", extractions=[ lx.data.Extraction( extraction_class="entity", extraction_text="example", attributes={"type": "test"}, ) ], ) ] result = lx.extract( text_or_documents="Test text for extraction", prompt_description="Extract entities", 
examples=examples, model_id="e2e-schema-test-v1", use_schema_constraints=True, fence_output=False, # Schema supports strict mode ) # Verify we got results self.assertIsInstance(result, lx.data.AnnotatedDocument) self.assertIsNotNone(result.extractions) self.assertGreater(len(result.extractions), 0) # Verify the schema was applied by checking the extraction extraction = result.extractions[0] self.assertEqual(extraction.extraction_class, "entity") self.assertIn("schema", extraction.attributes) self.assertEqual(extraction.attributes["schema"], "applied") @pytest.mark.requires_pip @pytest.mark.integration def test_pip_install_discovery_and_cleanup(self): """Test complete plugin lifecycle: install, discovery, usage, uninstall. This test: 1. Creates a Python package with a provider plugin 2. Installs it via pip 3. Verifies the plugin is discovered and usable 4. Uninstalls and verifies cleanup """ # Skip in Bazel environment where pip operations don't work if os.environ.get("TEST_TMPDIR") or os.environ.get( "BUILD_WORKING_DIRECTORY" ): self.skipTest("pip install tests don't work in Bazel sandbox") # Also skip if pip is not available try: subprocess.run( [sys.executable, "-m", "pip", "--version"], capture_output=True, check=True, ) except (subprocess.CalledProcessError, FileNotFoundError): self.skipTest("pip not available in test environment") with tempfile.TemporaryDirectory() as tmpdir: pkg_name = f"test_langextract_plugin_{uuid.uuid4().hex[:8]}" pkg_dir = Path(tmpdir) / pkg_name pkg_dir.mkdir() (pkg_dir / pkg_name).mkdir() (pkg_dir / pkg_name / "__init__.py").write_text("") (pkg_dir / pkg_name / "provider.py").write_text(textwrap.dedent(""" import langextract as lx from langextract.core import base_model from langextract.core import types USED_BY_EXTRACT = False class TestPipSchema(lx.schema.BaseSchema): '''Test schema for pip provider.''' def __init__(self, config): self._config = config @classmethod def from_examples(cls, examples_data, 
attribute_suffix="_attributes"): return cls({"pip_schema": True, "examples": len(examples_data)}) def to_provider_config(self): return {"schema_config": self._config} @property def requires_raw_output(self): return True @lx.providers.registry.register(r'^test-pip-model', priority=50) class TestPipProvider(base_model.BaseLanguageModel): def __init__(self, model_id, **kwargs): super().__init__() self.model_id = model_id self.schema_config = kwargs.get("schema_config", {}) @classmethod def get_schema_class(cls): return TestPipSchema def infer(self, batch_prompts, **kwargs): global USED_BY_EXTRACT USED_BY_EXTRACT = True schema_info = "with_schema" if self.schema_config else "no_schema" return [[types.ScoredOutput(score=1.0, output=f"pip test response: {schema_info}")]] """)) (pkg_dir / "pyproject.toml").write_text(textwrap.dedent(f""" [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "{pkg_name}" version = "0.0.1" description = "Test plugin for langextract" [project.entry-points."langextract.providers"] test_provider = "{pkg_name}.provider:TestPipProvider" """)) pip_env = { **os.environ, "PIP_NO_INPUT": "1", "PIP_DISABLE_PIP_VERSION_CHECK": "1", } result = subprocess.run( [ sys.executable, "-m", "pip", "install", "-e", str(pkg_dir), "--no-deps", "-q", ], check=True, capture_output=True, text=True, env=pip_env, ) self.assertEqual(result.returncode, 0, "pip install failed") self.assertNotIn( "ERROR", result.stderr.upper(), f"pip install had errors: {result.stderr}", ) try: test_script = Path(tmpdir) / "test_plugin.py" test_script.write_text(textwrap.dedent(f""" import langextract as lx import sys lx.providers.load_plugins_once() # Test 1: Basic usage without schema cfg = lx.factory.ModelConfig(model_id="test-pip-model-123") model = lx.factory.create_model(cfg) result = model.infer(["test prompt"]) assert "no_schema" in result[0][0].output, f"Got: {{result[0][0].output}}" # Test 2: With schema constraints examples = [ 
lx.data.ExampleData( text="test", extractions=[ lx.data.Extraction( extraction_class="test", extraction_text="test", ) ], ) ] cfg2 = lx.factory.ModelConfig(model_id="test-pip-model-456") model2 = lx.factory._create_model_with_schema( config=cfg2, examples=examples, use_schema_constraints=True, fence_output=None, ) result2 = model2.infer(["test prompt"]) assert "with_schema" in result2[0][0].output, f"Got: {{result2[0][0].output}}" assert model2.requires_fence_output == False, "Schema outputs raw JSON, should not need fences" # Test 3: Verify schema class is available provider_cls = lx.providers.registry.resolve("test-pip-model-xyz") assert provider_cls.__name__ == "TestPipProvider", "Plugin should be resolvable" schema_cls = provider_cls.get_schema_class() assert schema_cls.__name__ == "TestPipSchema", f"Schema class should be TestPipSchema, got {{schema_cls.__name__}}" from {pkg_name}.provider import USED_BY_EXTRACT assert USED_BY_EXTRACT, "Provider infer() was not called" print("SUCCESS: Plugin test with schema passed") """)) result = subprocess.run( [sys.executable, str(test_script)], capture_output=True, text=True, check=False, ) self.assertIn( "SUCCESS", result.stdout, f"Test failed. stdout: {result.stdout}, stderr: {result.stderr}", ) finally: subprocess.run( [sys.executable, "-m", "pip", "uninstall", "-y", pkg_name], check=False, capture_output=True, env=pip_env, ) lx.providers.registry.clear() lx.providers._plugins_loaded = False lx.providers.load_plugins_once() with self.assertRaisesRegex( lx.exceptions.InferenceConfigError, "No provider registered for model_id='test-pip-model", ): lx.providers.registry.resolve("test-pip-model-789") if __name__ == "__main__": absltest.main() ================================================ FILE: tests/provider_schema_test.py ================================================ # Copyright 2025 Google LLC. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for provider schema discovery and implementations."""

from unittest import mock

from absl.testing import absltest

from langextract import exceptions
from langextract import factory
from langextract import schema
import langextract as lx
from langextract.core import data
from langextract.providers import gemini
from langextract.providers import ollama
from langextract.providers import openai
from langextract.providers import schemas


class ProviderSchemaDiscoveryTest(absltest.TestCase):
  """Checks which schema class each built-in provider advertises."""

  def test_gemini_returns_gemini_schema(self):
    """GeminiLanguageModel advertises GeminiSchema."""
    expected_cls = schemas.gemini.GeminiSchema
    advertised_cls = gemini.GeminiLanguageModel.get_schema_class()
    self.assertEqual(
        advertised_cls,
        expected_cls,
        msg="GeminiLanguageModel should return GeminiSchema class",
    )

  def test_ollama_returns_format_mode_schema(self):
    """OllamaLanguageModel advertises FormatModeSchema."""
    advertised_cls = ollama.OllamaLanguageModel.get_schema_class()
    self.assertEqual(
        advertised_cls,
        schema.FormatModeSchema,
        msg="OllamaLanguageModel should return FormatModeSchema class",
    )

  def test_openai_returns_none(self):
    """OpenAILanguageModel advertises no schema class yet."""
    # Called on the class itself, so none of the dependencies OpenAI imports
    # in __init__ are needed here.
    advertised_cls = openai.OpenAILanguageModel.get_schema_class()
    self.assertIsNone(
        advertised_cls,
        msg="OpenAILanguageModel should return None (no schema support)",
    )


class FormatModeSchemaTest(absltest.TestCase):
  """Behaviour of FormatModeSchema, whose output ignores example content."""

  def test_from_examples_ignores_examples(self):
    """Examples do not influence the format; JSON is the default."""
    only_extraction = data.Extraction(
        extraction_class="test_class",
        extraction_text="test extraction",
        attributes={"key": "value"},
    )
    examples_data = [
        data.ExampleData(text="Test text", extractions=[only_extraction])
    ]

    format_schema = schema.FormatModeSchema.from_examples(examples_data)

    self.assertEqual(
        format_schema._format,
        "json",
        msg="FormatModeSchema should default to JSON format",
    )

  def test_to_provider_config_returns_format(self):
    """to_provider_config() exposes the format parameter."""
    format_schema = schema.FormatModeSchema.from_examples([])
    self.assertEqual(
        format_schema.to_provider_config(),
        {"format": "json"},
        msg="Provider config should contain format: json",
    )

  def test_requires_raw_output_returns_true(self):
    """JSON format mode asks for raw (unfenced) output."""
    format_schema = schema.FormatModeSchema.from_examples([])
    self.assertTrue(
        format_schema.requires_raw_output,
        msg="FormatModeSchema with JSON should require raw output",
    )

  def test_different_examples_same_output(self):
    """Two unrelated example sets yield the same Ollama provider config."""
    first_examples = [
        data.ExampleData(
            text="Text 1",
            extractions=[
                data.Extraction(
                    extraction_class="class1", extraction_text="text1"
                )
            ],
        )
    ]
    second_examples = [
        data.ExampleData(
            text="Text 2",
            extractions=[
                data.Extraction(
                    extraction_class="class2",
                    extraction_text="text2",
                    attributes={"attr": "value"},
                )
            ],
        )
    ]

    first_config, second_config = (
        schema.FormatModeSchema.from_examples(examples).to_provider_config()
        for examples in (first_examples, second_examples)
    )

    # FormatModeSchema disregards the examples it is built from.
    self.assertEqual(
        first_config,
        second_config,
        msg="Different examples should produce same config for Ollama",
    )
class OllamaFormatParameterTest(absltest.TestCase): """Tests for Ollama format parameter handling.""" def test_ollama_json_format_in_request_payload(self): """Test that JSON format is passed to Ollama API by default.""" with mock.patch("requests.post", autospec=True) as mock_post: mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": '{"test": "value"}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id="test-model", format_type=data.FormatType.JSON, ) list(model.infer(["Test prompt"])) mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] payload = call_kwargs["json"] self.assertEqual(payload["format"], "json", msg="Format should be json") self.assertEqual( payload["model"], "test-model", msg="Model ID should match" ) self.assertEqual( payload["prompt"], "Test prompt", msg="Prompt should match" ) self.assertFalse(payload["stream"], msg="Stream should be False") def test_ollama_default_format_is_json(self): """Test that JSON is the default format when not specified.""" with mock.patch("requests.post", autospec=True) as mock_post: mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": '{"test": "value"}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel(model_id="test-model") list(model.infer(["Test prompt"])) mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] payload = call_kwargs["json"] self.assertEqual( payload["format"], "json", msg="Default format should be json" ) def test_extract_with_ollama_passes_json_format(self): """Test that lx.extract() correctly passes JSON format to Ollama API.""" with mock.patch("requests.post", autospec=True) as mock_post: mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = { "response": ( '{"extractions": 
[{"extraction_class": "test", "extraction_text":' ' "example"}]}' ) } mock_post.return_value = mock_response # Mock the registry to return OllamaLanguageModel with mock.patch("langextract.providers.registry.resolve") as mock_resolve: mock_resolve.return_value = ollama.OllamaLanguageModel examples = [ data.ExampleData( text="Sample text", extractions=[ data.Extraction( extraction_class="test", extraction_text="sample", ) ], ) ] result = lx.extract( text_or_documents="Test document", prompt_description="Extract test information", examples=examples, model_id="gemma2:2b", model_url="http://localhost:11434", format_type=data.FormatType.JSON, use_schema_constraints=True, ) mock_post.assert_called() last_call = mock_post.call_args_list[-1] payload = last_call[1]["json"] self.assertEqual( payload["format"], "json", msg="Format should be json in extract() call", ) self.assertEqual( payload["model"], "gemma2:2b", msg="Model ID should match" ) self.assertIsNotNone(result) self.assertIsInstance(result, data.AnnotatedDocument) class OllamaYAMLOverrideTest(absltest.TestCase): """Tests for Ollama YAML format override behavior.""" def test_ollama_yaml_format_in_request_payload(self): """Test that YAML format override appears in Ollama request payload.""" with mock.patch("requests.post", autospec=True) as mock_post: mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": '{"extractions": []}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel(model_id="gemma2:2b", format="yaml") list(model.infer(["Test prompt"])) mock_post.assert_called_once() call_kwargs = mock_post.call_args[1] self.assertIn( "json", call_kwargs, msg="Request should use json parameter" ) payload = call_kwargs["json"] self.assertIn("format", payload, msg="Payload should contain format key") self.assertEqual(payload["format"], "yaml", msg="Format should be yaml") def 
test_yaml_override_sets_fence_output_true(self): """Test that overriding to YAML format sets fence_output to True.""" examples_data = [ data.ExampleData( text="Test text", extractions=[ data.Extraction( extraction_class="test_class", extraction_text="test extraction", ) ], ) ] with mock.patch("requests.post", autospec=True) as mock_post: mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": '{"extractions": []}'} mock_post.return_value = mock_response with mock.patch("langextract.providers.registry.resolve") as mock_resolve: mock_resolve.return_value = ollama.OllamaLanguageModel config = factory.ModelConfig( model_id="gemma2:2b", provider_kwargs={"format": "yaml"}, ) model = factory.create_model( config=config, examples=examples_data, use_schema_constraints=True, fence_output=None, # Let it be computed ) self.assertTrue( model.requires_fence_output, msg="YAML format should require fences" ) def test_json_format_keeps_fence_output_false(self): """Test that JSON format keeps fence_output False.""" examples_data = [ data.ExampleData( text="Test text", extractions=[ data.Extraction( extraction_class="test_class", extraction_text="test extraction", ) ], ) ] with mock.patch("requests.post", autospec=True) as mock_post: mock_response = mock.Mock(spec=["status_code", "json"]) mock_response.status_code = 200 mock_response.json.return_value = {"response": '{"extractions": []}'} mock_post.return_value = mock_response with mock.patch("langextract.providers.registry.resolve") as mock_resolve: mock_resolve.return_value = ollama.OllamaLanguageModel config = factory.ModelConfig( model_id="gemma2:2b", provider_kwargs={"format": "json"}, ) model = factory.create_model( config=config, examples=examples_data, use_schema_constraints=True, fence_output=None, # Let it be computed ) self.assertFalse( model.requires_fence_output, msg="JSON format should not require fences", ) class 
GeminiSchemaProviderIntegrationTest(absltest.TestCase): """Tests for GeminiSchema provider integration.""" def test_gemini_schema_to_provider_config(self): """Test that GeminiSchema.to_provider_config includes response_schema.""" examples_data = [ data.ExampleData( text="Patient has diabetes", extractions=[ data.Extraction( extraction_class="condition", extraction_text="diabetes", attributes={"severity": "moderate"}, ) ], ) ] gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data) provider_config = gemini_schema.to_provider_config() self.assertIn( "response_schema", provider_config, msg="GeminiSchema config should contain response_schema", ) self.assertIsInstance( provider_config["response_schema"], dict, msg="response_schema should be a dictionary", ) self.assertIn( "properties", provider_config["response_schema"], msg="response_schema should contain properties field", ) self.assertIn( "response_mime_type", provider_config, msg="GeminiSchema config should contain response_mime_type", ) self.assertEqual( provider_config["response_mime_type"], "application/json", msg="response_mime_type should be application/json", ) def test_gemini_requires_raw_output(self): """Test that GeminiSchema requires raw output.""" examples_data = [] gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data) self.assertTrue( gemini_schema.requires_raw_output, msg="GeminiSchema should require raw output", ) def test_gemini_rejects_yaml_with_schema(self): """Test that Gemini raises error when YAML format is used with schema.""" examples_data = [ data.ExampleData( text="Test", extractions=[ data.Extraction( extraction_class="test", extraction_text="test text", ) ], ) ] test_schema = schemas.gemini.GeminiSchema.from_examples(examples_data) with mock.patch("google.genai.Client", autospec=True): model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="test_key", format_type=data.FormatType.YAML, ) model.apply_schema(test_schema) prompt = "Test 
prompt" config = {"temperature": 0.5} with self.assertRaises(exceptions.InferenceRuntimeError) as cm: _ = model._process_single_prompt(prompt, config) self.assertIn( "only supports JSON format", str(cm.exception), msg="Error should mention JSON-only constraint", ) def test_gemini_forwards_schema_to_genai_client(self): """Test that GeminiLanguageModel forwards schema config to genai client.""" examples_data = [ data.ExampleData( text="Test", extractions=[ data.Extraction( extraction_class="test", extraction_text="test text", ) ], ) ] test_schema = schemas.gemini.GeminiSchema.from_examples(examples_data) with mock.patch("google.genai.Client", autospec=True) as mock_client: mock_model_instance = mock.Mock(spec=["return_value"]) mock_client.return_value.models.generate_content = mock_model_instance mock_model_instance.return_value.text = '{"extractions": []}' model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="test_key", response_schema=test_schema.schema_dict, response_mime_type="application/json", ) prompt = "Test prompt" config = {"temperature": 0.5} _ = model._process_single_prompt(prompt, config) mock_model_instance.assert_called_once() call_kwargs = mock_model_instance.call_args[1] self.assertIn( "config", call_kwargs, msg="genai.generate_content should receive config parameter", ) self.assertIn( "response_schema", call_kwargs["config"], msg="Config should contain response_schema from GeminiSchema", ) self.assertIn( "response_mime_type", call_kwargs["config"], msg="Config should contain response_mime_type", ) self.assertEqual( call_kwargs["config"]["response_mime_type"], "application/json", msg="response_mime_type should be application/json", ) def test_gemini_doesnt_forward_non_api_kwargs(self): """Test that GeminiLanguageModel doesn't forward non-API kwargs to genai.""" with mock.patch("google.genai.Client", autospec=True) as mock_client: mock_model_instance = mock.Mock(spec=["return_value"]) 
mock_client.return_value.models.generate_content = mock_model_instance mock_model_instance.return_value.text = '{"extractions": []}' model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", api_key="test_key", max_workers=5, response_schema={"test": "schema"}, # API parameter ) prompt = "Test prompt" config = {"temperature": 0.5} _ = model._process_single_prompt(prompt, config) mock_model_instance.assert_called_once() call_kwargs = mock_model_instance.call_args[1] self.assertNotIn( "max_workers", call_kwargs["config"], msg="max_workers should not be forwarded to genai API config", ) self.assertIn( "response_schema", call_kwargs["config"], msg="response_schema should be forwarded to genai API config", ) class SchemaShimTest(absltest.TestCase): """Tests for backward compatibility shims in schema module.""" def test_constraint_types_import(self): """Test that Constraint and ConstraintType can be imported.""" from langextract import schema as lx_schema # pylint: disable=reimported,import-outside-toplevel constraint = lx_schema.Constraint() self.assertEqual( constraint.constraint_type, lx_schema.ConstraintType.NONE, msg="Default Constraint should have type NONE", ) self.assertEqual( lx_schema.ConstraintType.NONE.value, "none", msg="ConstraintType.NONE should have value 'none'", ) def test_provider_schema_imports(self): """Test that provider schemas can be imported from schema module.""" from langextract import schema as lx_schema # pylint: disable=reimported,import-outside-toplevel # Backward compatibility: re-exported from providers.schemas.gemini self.assertTrue( hasattr(lx_schema, "GeminiSchema"), msg=( "GeminiSchema should be importable from schema module for backward" " compatibility" ), ) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/registry_test.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the provider registry module. Note: This file tests the deprecated registry module which is now an alias for router. The no-name-in-module warning for providers.registry is expected. Test helper classes also intentionally have few public methods. """ # pylint: disable=no-name-in-module import re from absl.testing import absltest from langextract import exceptions from langextract.core import base_model from langextract.core import types from langextract.providers import router class FakeProvider(base_model.BaseLanguageModel): """Fake provider for testing.""" def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="test")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) class AnotherFakeProvider(base_model.BaseLanguageModel): """Another fake provider for testing.""" def infer(self, batch_prompts, **kwargs): return [[types.ScoredOutput(score=1.0, output="another")]] def infer_batch(self, prompts, batch_size=32): return self.infer(prompts) class RegistryTest(absltest.TestCase): def setUp(self): super().setUp() router.clear() def tearDown(self): super().tearDown() router.clear() def test_register_decorator(self): """Test registering a provider using the decorator.""" @router.register(r"^test-model") class TestProvider(FakeProvider): pass resolved = router.resolve("test-model-v1") self.assertEqual(resolved, TestProvider) def test_register_lazy(self): """Test lazy 
registration with string target.""" # Use direct registration for test provider to avoid module path issues router.register(r"^fake-model")(FakeProvider) resolved = router.resolve("fake-model-v2") self.assertEqual(resolved, FakeProvider) def test_multiple_patterns(self): """Test registering multiple patterns for one provider.""" # Use direct registration to avoid module path issues in Bazel router.register(r"^gemini", r"^palm")(FakeProvider) self.assertEqual(router.resolve("gemini-pro"), FakeProvider) self.assertEqual(router.resolve("palm-2"), FakeProvider) def test_priority_resolution(self): """Test that higher priority wins on conflicts.""" # Use direct registration to avoid module path issues in Bazel router.register(r"^model", priority=0)(FakeProvider) router.register(r"^model", priority=10)(AnotherFakeProvider) resolved = router.resolve("model-v1") self.assertEqual(resolved, AnotherFakeProvider) def test_no_provider_registered(self): """Test error when no provider matches.""" with self.assertRaisesRegex( exceptions.InferenceConfigError, "No provider registered for model_id='unknown-model'", ): router.resolve("unknown-model") def test_caching(self): """Test that resolve results are cached.""" # Use direct registration for test provider to avoid module path issues router.register(r"^cached")(FakeProvider) # First call result1 = router.resolve("cached-model") # Second call should return cached result result2 = router.resolve("cached-model") self.assertIs(result1, result2) def test_clear_registry(self): """Test clearing the router.""" # Use direct registration for test provider to avoid module path issues router.register(r"^temp")(FakeProvider) # Should resolve before clear resolved = router.resolve("temp-model") self.assertEqual(resolved, FakeProvider) # Clear registry router.clear() # Should fail after clear with self.assertRaises(exceptions.InferenceConfigError): router.resolve("temp-model") def test_list_entries(self): """Test listing registered entries.""" 
router.register_lazy(r"^test1", target="fake:Target1", priority=5) router.register_lazy( r"^test2", r"^test3", target="fake:Target2", priority=10 ) entries = router.list_entries() self.assertEqual(len(entries), 2) patterns1, priority1 = entries[0] self.assertEqual(patterns1, ["^test1"]) self.assertEqual(priority1, 5) patterns2, priority2 = entries[1] self.assertEqual(set(patterns2), {"^test2", "^test3"}) self.assertEqual(priority2, 10) def test_lazy_loading_defers_import(self): """Test that lazy registration doesn't import until resolve.""" # Register with a module that would fail if imported router.register_lazy(r"^lazy", target="non.existent.module:Provider") # Registration should succeed without importing entries = router.list_entries() self.assertTrue(any("^lazy" in patterns for patterns, _ in entries)) # Only on resolve should it try to import and fail with self.assertRaises(ModuleNotFoundError): router.resolve("lazy-model") def test_regex_pattern_objects(self): """Test using pre-compiled regex patterns.""" pattern = re.compile(r"^custom-\d+") @router.register(pattern) class CustomProvider(FakeProvider): pass self.assertEqual(router.resolve("custom-123"), CustomProvider) # Should not match without digits with self.assertRaises(exceptions.InferenceConfigError): router.resolve("custom-abc") def test_resolve_provider_by_name(self): """Test resolving provider by exact name.""" @router.register(r"^test-model", r"^TestProvider$") class TestProvider(FakeProvider): pass # Resolve by exact class name pattern provider = router.resolve_provider("TestProvider") self.assertEqual(provider, TestProvider) # Resolve by partial name match provider = router.resolve_provider("test") self.assertEqual(provider, TestProvider) def test_resolve_provider_not_found(self): """Test resolve_provider raises for unknown provider.""" with self.assertRaises(exceptions.InferenceConfigError) as cm: router.resolve_provider("UnknownProvider") self.assertIn("No provider found matching", 
str(cm.exception)) def test_hf_style_model_id_patterns(self): """Test that Hugging Face style model ID patterns work. This addresses issue #129 where HF-style model IDs like 'meta-llama/Llama-3.2-1B-Instruct' weren't being recognized. """ @router.register( r"^meta-llama/[Ll]lama", r"^google/gemma", r"^mistralai/[Mm]istral", r"^microsoft/phi", r"^Qwen/", r"^TinyLlama/", priority=100, ) class TestHFProvider(base_model.BaseLanguageModel): # pylint: disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): return [] hf_model_ids = [ "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/llama-2-7b", "google/gemma-2b", "mistralai/Mistral-7B-v0.1", "microsoft/phi-3-mini", "Qwen/Qwen2.5-7B", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ] for model_id in hf_model_ids: with self.subTest(model_id=model_id): provider_class = router.resolve(model_id) self.assertEqual(provider_class, TestHFProvider) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/resolver_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import textwrap from typing import Sequence from absl.testing import absltest from absl.testing import parameterized from langextract import chunking from langextract import resolver as resolver_lib from langextract.core import data from langextract.core import tokenizer def assert_char_interval_match_source( test_case: absltest.TestCase, source_text: str, extractions: Sequence[data.Extraction], ): """Asserts that the char_interval of matched extractions matches the source text. Args: test_case: The TestCase instance. source_text: The original source text. extractions: A sequence of extractions to check. """ for extraction in extractions: if extraction.alignment_status == data.AlignmentStatus.MATCH_EXACT: assert ( extraction.char_interval is not None ), "char_interval should not be None for AlignmentStatus.MATCH_EXACT" char_int = extraction.char_interval start = char_int.start_pos end = char_int.end_pos test_case.assertIsNotNone(start, "start_pos should not be None") test_case.assertIsNotNone(end, "end_pos should not be None") extracted = source_text[start:end] test_case.assertEqual( extracted.lower(), extraction.extraction_text.lower(), f"Extraction '{extraction.extraction_text}' does not match extracted" f" '{extracted}' using char_interval {char_int}", ) class ParserTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name="json_invalid_input", resolver=resolver_lib.Resolver( format_type=data.FormatType.JSON, fence_output=True, strict_fences=True, ), input_text="invalid input", expected_exception=resolver_lib.ResolverParsingError, expected_regex=".*fence markers.*", ), dict( testcase_name="json_missing_markers", resolver=resolver_lib.Resolver( format_type=data.FormatType.JSON, fence_output=True, strict_fences=True, ), input_text='[{"key": "value"}]', expected_exception=resolver_lib.ResolverParsingError, expected_regex=".*fence markers.*", ), dict( testcase_name="json_empty_string", resolver=resolver_lib.Resolver( 
format_type=data.FormatType.JSON, fence_output=True, ), input_text="", expected_exception=ValueError, expected_regex=".*must be a non-empty string.*", ), dict( testcase_name="json_partial_markers", resolver=resolver_lib.Resolver( format_type=data.FormatType.JSON, fence_output=True, strict_fences=True, ), input_text='```json\n{"key": "value"', expected_exception=resolver_lib.ResolverParsingError, expected_regex=".*fence markers.*", ), dict( testcase_name="yaml_invalid_input", resolver=resolver_lib.Resolver( format_type=data.FormatType.YAML, fence_output=True, strict_fences=True, ), input_text="invalid input", expected_exception=resolver_lib.ResolverParsingError, expected_regex=".*fence markers.*", ), dict( testcase_name="yaml_missing_markers", resolver=resolver_lib.Resolver( format_type=data.FormatType.YAML, fence_output=True, strict_fences=True, ), input_text='[{"key": "value"}]', expected_exception=resolver_lib.ResolverParsingError, expected_regex=".*fence markers.*", ), dict( testcase_name="yaml_empty_content", resolver=resolver_lib.Resolver( format_type=data.FormatType.YAML, fence_output=True, ), input_text="```yaml\n```", expected_exception=resolver_lib.ResolverParsingError, expected_regex=( ".*Content must be a mapping with an" f" '{data.EXTRACTIONS_KEY}' key.*" ), ), ) def test_parser_error_cases( self, resolver, input_text, expected_exception, expected_regex ): with self.assertRaisesRegex(expected_exception, expected_regex): resolver.string_to_extraction_data(input_text) class ExtractOrderedEntitiesTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name="valid_input", test_input=[ { "medication": "Naprosyn", "medication_index": 4, "frequency": "as needed", "frequency_index": 5, "reason": "pain", "reason_index": 8, }, { "medication": "prednisone", "medication_index": 5, "frequency": "daily", "frequency_index": 1, }, ], expected_output=[ data.Extraction( extraction_class="frequency", extraction_text="daily", extraction_index=1, 
group_index=1, ), data.Extraction( extraction_class="medication", extraction_text="Naprosyn", extraction_index=4, group_index=0, ), data.Extraction( extraction_class="frequency", extraction_text="as needed", extraction_index=5, group_index=0, ), data.Extraction( extraction_class="medication", extraction_text="prednisone", extraction_index=5, group_index=1, ), data.Extraction( extraction_class="reason", extraction_text="pain", extraction_index=8, group_index=0, ), ], ), dict( testcase_name="empty_input", test_input=[], expected_output=[], ), dict( testcase_name="mixed_index_order", test_input=[ { "medication": "Ibuprofen", "medication_index": 2, "dosage": "400mg", "dosage_index": 1, }, { "medication": "Acetaminophen", "medication_index": 1, "duration": "7 days", "duration_index": 2, }, ], expected_output=[ data.Extraction( extraction_class="dosage", extraction_text="400mg", extraction_index=1, group_index=0, ), data.Extraction( extraction_class="medication", extraction_text="Acetaminophen", extraction_index=1, group_index=1, ), data.Extraction( extraction_class="medication", extraction_text="Ibuprofen", extraction_index=2, group_index=0, ), data.Extraction( extraction_class="duration", extraction_text="7 days", extraction_index=2, group_index=1, ), ], ), dict( testcase_name="missing_index_key", test_input=[{ "medication": "Aspirin", "dosage": "325mg", "dosage_index": 1, }], expected_output=[ data.Extraction( extraction_class="dosage", extraction_text="325mg", extraction_index=1, group_index=0, ), ], ), dict( testcase_name="all_indices_missing", test_input=[ {"medication": "Aspirin", "dosage": "325mg"}, {"medication": "Ibuprofen", "dosage": "400mg"}, ], expected_output=[], ), dict( testcase_name="single_element_dictionaries", test_input=[ {"medication": "Aspirin", "medication_index": 1}, {"medication": "Ibuprofen", "medication_index": 2}, ], expected_output=[ data.Extraction( extraction_class="medication", extraction_text="Aspirin", extraction_index=1, group_index=0, 
), data.Extraction( extraction_class="medication", extraction_text="Ibuprofen", extraction_index=2, group_index=1, ), ], ), dict( testcase_name="duplicate_indices_unchanged", test_input=[{ "medication": "Aspirin", "medication_index": 1, "dosage": "325mg", "dosage_index": 1, "form": "tablet", "form_index": 1, }], expected_output=[ data.Extraction( extraction_class="medication", extraction_text="Aspirin", extraction_index=1, group_index=0, ), data.Extraction( extraction_class="dosage", extraction_text="325mg", extraction_index=1, group_index=0, ), data.Extraction( extraction_class="form", extraction_text="tablet", extraction_index=1, group_index=0, ), ], ), dict( testcase_name="negative_indices", test_input=[{ "medication": "Aspirin", "medication_index": -1, "dosage": "325mg", "dosage_index": -2, }], expected_output=[ data.Extraction( extraction_class="dosage", extraction_text="325mg", extraction_index=-2, group_index=0, ), data.Extraction( extraction_class="medication", extraction_text="Aspirin", extraction_index=-1, group_index=0, ), ], ), dict( testcase_name="index_without_data_key_ignored", test_input=[{ "medication_index": 1, "dosage": "325mg", "dosage_index": 2, }], expected_output=[ data.Extraction( extraction_class="dosage", extraction_text="325mg", extraction_index=2, group_index=0, ), ], ), dict( testcase_name="no_index_suffix", resolver=resolver_lib.Resolver( extraction_index_suffix=None, format_type=data.FormatType.JSON, ), test_input=[ {"medication": "Aspirin"}, {"medication": "Ibuprofen"}, {"dosage": "325mg"}, {"dosage": "400mg"}, ], expected_output=[ data.Extraction( extraction_class="medication", extraction_text="Aspirin", extraction_index=1, group_index=0, ), data.Extraction( extraction_class="medication", extraction_text="Ibuprofen", extraction_index=2, group_index=1, ), data.Extraction( extraction_class="dosage", extraction_text="325mg", extraction_index=3, group_index=2, ), data.Extraction( extraction_class="dosage", extraction_text="400mg", 
extraction_index=4, group_index=3, ), ], ), dict( testcase_name="attributes_suffix", resolver=resolver_lib.Resolver( extraction_index_suffix=None, format_type=data.FormatType.JSON, ), test_input=[ { "patient": "Jane Doe", "patient_attributes": { "PERSON": "True", "IDENTIFIABLE": "True", }, }, { "medication": "Lisinopril", "medication_attributes": { "THERAPEUTIC": "True", "CLINICAL": "True", }, }, ], expected_output=[ data.Extraction( extraction_class="patient", extraction_text="Jane Doe", extraction_index=1, group_index=0, attributes={ "PERSON": "True", "IDENTIFIABLE": "True", }, ), data.Extraction( extraction_class="medication", extraction_text="Lisinopril", extraction_index=2, group_index=1, attributes={ "THERAPEUTIC": "True", "CLINICAL": "True", }, ), ], ), dict( testcase_name="indices_and_attributes", test_input=[ { "patient": "John Doe", "patient_index": 2, "patient_attributes": { "IDENTIFIABLE": "True", }, "condition": "hypertension", "condition_index": 1, "condition_attributes": { "CHRONIC_CONDITION": "True", "REQUIRES_MANAGEMENT": "True", }, }, { "medication": "Lisinopril", "medication_index": 3, "medication_attributes": { "ANTIHYPERTENSIVE_MEDICATION": "True", "DAILY_USE": "True", }, "dosage": "10mg", "dosage_index": 4, "dosage_attributes": { "STANDARD_DAILY_DOSE": "True", }, }, ], expected_output=[ data.Extraction( extraction_class="condition", extraction_text="hypertension", extraction_index=1, group_index=0, attributes={ "CHRONIC_CONDITION": "True", "REQUIRES_MANAGEMENT": "True", }, ), data.Extraction( extraction_class="patient", extraction_text="John Doe", extraction_index=2, group_index=0, attributes={ "IDENTIFIABLE": "True", }, ), data.Extraction( extraction_class="medication", extraction_text="Lisinopril", extraction_index=3, group_index=1, attributes={ "ANTIHYPERTENSIVE_MEDICATION": "True", "DAILY_USE": "True", }, ), data.Extraction( extraction_class="dosage", extraction_text="10mg", extraction_index=4, group_index=1, attributes={ 
"STANDARD_DAILY_DOSE": "True", }, ), ], ), ) def test_extract_ordered_extractions_success( self, test_input, resolver=None, expected_output=None, ): if resolver is None: resolver = resolver_lib.Resolver( extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX ) actual_output = resolver.extract_ordered_extractions(test_input) self.assertEqual(actual_output, expected_output) @parameterized.named_parameters( dict( testcase_name="non_integer_indices", resolver=resolver_lib.Resolver( format_type=data.FormatType.JSON, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), test_input=[{ "medication": "Aspirin", "medication_index": "first", "dosage": "325mg", "dosage_index": "second", }], expected_exception=ValueError, expected_regex=".*must be an integer.*", ), dict( testcase_name="float_indices", resolver=resolver_lib.Resolver( format_type=data.FormatType.JSON, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), test_input=[{"medication": "Aspirin", "medication_index": 1.0}], expected_exception=ValueError, expected_regex=".*must be an integer.*", ), ) def test_extract_ordered_extractions_exceptions( self, resolver, test_input, expected_exception, expected_regex ): with self.assertRaisesRegex(expected_exception, expected_regex): resolver.extract_ordered_extractions(test_input) class AlignEntitiesTest(parameterized.TestCase): _SOURCE_TEXT_TWO_MEDS = ( "Patient is prescribed Naprosyn and prednisone for treatment." ) _SOURCE_TEXT_THREE_CONDITIONS_AND_MEDS = ( "Patient with arthritis, fever, and inflammation is prescribed" " Naprosyn, prednisone, and ibuprofen." ) _SOURCE_TEXT_MULTI_WORD_EXTRACTIONS = ( "Pt was prescribed Naprosyn as needed for pain and prednisone for" " one month." 
) def setUp(self): super().setUp() self.aligner = resolver_lib.WordAligner() self.maxDiff = 10000 @parameterized.named_parameters( ( "basic_alignment", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", ) ], ], _SOURCE_TEXT_TWO_MEDS, [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=22, end_pos=30), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", token_interval=tokenizer.TokenInterval( start_index=5, end_index=6 ), char_interval=data.CharInterval(start_pos=35, end_pos=45), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "shuffled_order_of_last_two_extractions", [ [ data.Extraction( extraction_class="condition", extraction_text="arthritis" ) ], [ data.Extraction( extraction_class="condition", extraction_text="fever" ) ], [ data.Extraction( extraction_class="condition", extraction_text="inflammation", ) ], [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="medication", extraction_text="ibuprofen" ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", ) ], ], _SOURCE_TEXT_THREE_CONDITIONS_AND_MEDS, # Indexes Aligned with Tokens # -------------------------------------------------------------------- # Index | 0 1 2 3 4 5 6 # Token | Patient with arthritis , fever , and # -------------------------------------------------------------------- # Index | 7 8 9 # Token | inflammation is prescribed # -------------------------------------------------------------------- # Index | 10 11 12 13 14 15 # Token | Naprosyn , prednisone , and ibuprofen # -------------------------------------------------------------------- # Index | 16 # 
Token | . [ [ data.Extraction( extraction_class="condition", extraction_text="arthritis", token_interval=tokenizer.TokenInterval( start_index=2, end_index=3 ), char_interval=data.CharInterval(start_pos=13, end_pos=22), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="condition", extraction_text="fever", token_interval=tokenizer.TokenInterval( start_index=4, end_index=5 ), char_interval=data.CharInterval(start_pos=24, end_pos=29), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="condition", extraction_text="inflammation", token_interval=tokenizer.TokenInterval( start_index=7, end_index=8 ), char_interval=data.CharInterval(start_pos=35, end_pos=47), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=10, end_index=11 ), char_interval=data.CharInterval(start_pos=62, end_pos=70), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="ibuprofen", token_interval=None, char_interval=None, alignment_status=None, ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", token_interval=tokenizer.TokenInterval( start_index=12, end_index=13 ), char_interval=data.CharInterval(start_pos=72, end_pos=82), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "extraction_not_found", [[ data.Extraction( extraction_class="medication", extraction_text="aspirin" ) ]], _SOURCE_TEXT_TWO_MEDS, [[ data.Extraction( extraction_class="medication", extraction_text="aspirin", char_interval=None, ) ]], ), ( "multiple_word_extraction_partially_matched", [[ data.Extraction( extraction_class="condition", extraction_text="high blood pressure", ) ]], "Patient is prescribed high glucose.", [[ data.Extraction( extraction_class="condition", extraction_text="high blood pressure", 
token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), alignment_status=data.AlignmentStatus.MATCH_LESSER, char_interval=data.CharInterval(start_pos=22, end_pos=26), ) ]], ), ( "optimize_multiword_extractions_at_back", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn and prednisone", ) ], ], _SOURCE_TEXT_TWO_MEDS, [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=None, char_interval=None, alignment_status=None, ) ], [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn and prednisone", token_interval=tokenizer.TokenInterval( start_index=3, end_index=6 ), char_interval=data.CharInterval(start_pos=22, end_pos=45), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "optimize_multiword_extractions_at_front", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn and prednisone", ) ], [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], ], _SOURCE_TEXT_TWO_MEDS, [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn and prednisone", token_interval=tokenizer.TokenInterval( start_index=3, end_index=6 ), char_interval=data.CharInterval(start_pos=22, end_pos=45), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", char_interval=None, ) ], ], ), ( "test_en_dash_unicode_handling", [ [ data.Extraction( extraction_class="word", extraction_text="Separated" ) ], [data.Extraction(extraction_class="word", extraction_text="by")], [ data.Extraction( extraction_class="word", extraction_text="en–dashes" ) ], ], "Separated–by–en–dashes.", [ [ data.Extraction( extraction_class="word", extraction_text="Separated", token_interval=tokenizer.TokenInterval( start_index=0, end_index=1 ), char_interval=data.CharInterval(start_pos=0, 
end_pos=9), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="word", extraction_text="by", token_interval=tokenizer.TokenInterval( start_index=2, end_index=3 ), char_interval=data.CharInterval(start_pos=10, end_pos=12), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="word", extraction_text="en–dashes", token_interval=tokenizer.TokenInterval( start_index=4, end_index=7 ), char_interval=data.CharInterval(start_pos=13, end_pos=22), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "empty_source_text", [[ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ]], "", ValueError, ), ( "special_characters_in_extractions", [[ data.Extraction( extraction_class="medication", extraction_text="Napro-syn" ) ]], "Patient is prescribed Napro-syn.", [ [ data.Extraction( extraction_class="medication", extraction_text="Napro-syn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=6 ), char_interval=data.CharInterval(start_pos=22, end_pos=31), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_extraction_with_substring_of_another_not_matched", [[ data.Extraction( extraction_class="medication", extraction_text="Napro" ) ]], _SOURCE_TEXT_TWO_MEDS, [[ data.Extraction( extraction_class="medication", extraction_text="Napro", char_interval=None, ) ]], ), ( "test_empty_extractions_list", [], _SOURCE_TEXT_TWO_MEDS, [], ), ( "test_extractions_with_similar_words", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="medication", extraction_text="Napro" ) ], ], _SOURCE_TEXT_TWO_MEDS, [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=22, end_pos=30), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( 
extraction_class="medication", extraction_text="Napro", char_interval=None, ) ], ], ), ( "test_source_text_with_repeated_extractions", [[ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ]], "Patient is prescribed Naprosyn and Naprosyn.", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=22, end_pos=30), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_interleaved_extractions", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="condition", extraction_text="arthritis" ) ], ], "Patient with arthritis is prescribed Naprosyn.", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", char_interval=None, ) ], [ data.Extraction( extraction_class="condition", extraction_text="arthritis", token_interval=tokenizer.TokenInterval( start_index=2, end_index=3 ), char_interval=data.CharInterval(start_pos=13, end_pos=22), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "overlapping_extractions_different_types", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="condition", extraction_text="Naprosyn allergy", ) ], ], _SOURCE_TEXT_TWO_MEDS, [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=22, end_pos=30), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="condition", extraction_text="Naprosyn allergy", char_interval=None, ) ], ], ), ( "test_overlapping_text_extractions_with_overlapping_source", [ [ data.Extraction( extraction_class="condition", extraction_text="high blood" ) ], [ data.Extraction( extraction_class="condition", 
extraction_text="blood pressure", ) ], ], "Patient has high blood pressure.", [ [ data.Extraction( extraction_class="condition", extraction_text="high blood", token_interval=tokenizer.TokenInterval( start_index=2, end_index=4 ), char_interval=data.CharInterval(start_pos=12, end_pos=22), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="condition", extraction_text="blood pressure", char_interval=None, ) ], ], ), ( "test_multiple_instances_same_extraction", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", ) ], ], "Naprosyn, prednisone, and again Naprosyn are prescribed.", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=0, end_index=1 ), char_interval=data.CharInterval(start_pos=0, end_pos=8), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", token_interval=tokenizer.TokenInterval( start_index=2, end_index=3 ), char_interval=data.CharInterval(start_pos=10, end_pos=20), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_longer_extraction_spanning_multiple_words", [[ data.Extraction( extraction_class="condition", extraction_text="rheumatoid arthritis", ) ]], "Patient diagnosed with rheumatoid arthritis.", [ [ data.Extraction( extraction_class="condition", extraction_text="rheumatoid arthritis", token_interval=tokenizer.TokenInterval( start_index=3, end_index=5 ), char_interval=data.CharInterval(start_pos=23, end_pos=43), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_case_insensitivity", [ [ data.Extraction( extraction_class="medication", extraction_text="naprosyn" ) ], [ data.Extraction( extraction_class="medication", extraction_text="PREDNISONE", ) ], ], _SOURCE_TEXT_TWO_MEDS.lower(), [ [ data.Extraction( 
extraction_class="medication", extraction_text="naprosyn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=22, end_pos=30), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="PREDNISONE", token_interval=tokenizer.TokenInterval( start_index=5, end_index=6 ), char_interval=data.CharInterval(start_pos=35, end_pos=45), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "numerical_extractions", [[ data.Extraction( extraction_class="medication", extraction_text="Ibuprofen 600mg", ) ]], "Patient was given Ibuprofen 600mg twice daily.", [ [ data.Extraction( extraction_class="medication", extraction_text="Ibuprofen 600mg", token_interval=tokenizer.TokenInterval( start_index=3, end_index=6 ), char_interval=data.CharInterval(start_pos=18, end_pos=33), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_extractions_spanning_across_sentence_boundaries", [ [ data.Extraction( extraction_class="medication", extraction_text="Ibuprofen" ) ], [ data.Extraction( extraction_class="instruction", extraction_text="take with food", ) ], ], "Take Ibuprofen. 
Always take with food.", [ [ data.Extraction( extraction_class="medication", extraction_text="Ibuprofen", token_interval=tokenizer.TokenInterval( start_index=1, end_index=2 ), char_interval=data.CharInterval(start_pos=5, end_pos=14), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="instruction", extraction_text="take with food", token_interval=tokenizer.TokenInterval( start_index=4, end_index=7 ), char_interval=data.CharInterval(start_pos=23, end_pos=37), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_multiple_multiword_extractions_multi_group", [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn" ) ], [ data.Extraction( extraction_class="frequency", extraction_text="as needed" ) ], [ data.Extraction( extraction_class="reason", extraction_text="pain" ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", ) ], [ data.Extraction( extraction_class="duration", extraction_text="for one month", ) ], ], _SOURCE_TEXT_MULTI_WORD_EXTRACTIONS, [ [ data.Extraction( extraction_class="medication", extraction_text="Naprosyn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=4 ), char_interval=data.CharInterval(start_pos=18, end_pos=26), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="frequency", extraction_text="as needed", token_interval=tokenizer.TokenInterval( start_index=4, end_index=6 ), char_interval=data.CharInterval(start_pos=27, end_pos=36), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="reason", extraction_text="pain", token_interval=tokenizer.TokenInterval( start_index=7, end_index=8 ), char_interval=data.CharInterval(start_pos=41, end_pos=45), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", token_interval=tokenizer.TokenInterval( start_index=9, 
end_index=10 ), char_interval=data.CharInterval(start_pos=50, end_pos=60), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="duration", extraction_text="for one month", token_interval=tokenizer.TokenInterval( start_index=10, end_index=13 ), char_interval=data.CharInterval(start_pos=61, end_pos=74), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "extraction_with_tokenizing_pipe_delimiter", [ [ data.Extraction( extraction_class="medication", extraction_text="Napro | syn", ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", ) ], ], "Patient is prescribed Napro | syn and prednisone.", [ [ data.Extraction( extraction_class="medication", extraction_text="Napro | syn", token_interval=tokenizer.TokenInterval( start_index=3, end_index=6 ), char_interval=data.CharInterval(start_pos=22, end_pos=33), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], [ data.Extraction( extraction_class="medication", extraction_text="prednisone", token_interval=tokenizer.TokenInterval( start_index=7, end_index=8 ), char_interval=data.CharInterval(start_pos=38, end_pos=48), alignment_status=data.AlignmentStatus.MATCH_EXACT, ) ], ], ), ( "test_only_matching_end_does_not_align", [ [ data.Extraction( extraction_class="some_class", extraction_text="only matched end", ) ], ], "end", [[ data.Extraction( extraction_class="some_class", extraction_text="only matched end", char_interval=None, alignment_status=None, ) ]], ), dict( testcase_name="fuzzy_alignment_success", # Tests fuzzy alignment alongside exact matching. Shows different alignment statuses: # "heart problems" gets fuzzy match, "severe heart problems complications" gets lesser match. # Demonstrates both fuzzy and lesser matching working with 75% threshold. 
extractions=[ [ data.Extraction( extraction_class="condition", extraction_text="heart problems", ) ], [ data.Extraction( extraction_class="condition", extraction_text="severe heart problems complications", ) ], ], source_text="Patient has severe heart problems today.", expected_output=[ [ data.Extraction( extraction_class="condition", extraction_text="heart problems", token_interval=tokenizer.TokenInterval( start_index=3, end_index=5 ), char_interval=data.CharInterval(start_pos=19, end_pos=33), alignment_status=data.AlignmentStatus.MATCH_FUZZY, ) ], [ data.Extraction( extraction_class="condition", extraction_text="severe heart problems complications", token_interval=tokenizer.TokenInterval( start_index=2, end_index=5 ), char_interval=data.CharInterval(start_pos=12, end_pos=33), alignment_status=data.AlignmentStatus.MATCH_LESSER, ) ], ], enable_fuzzy_alignment=True, ), dict( testcase_name="fuzzy_alignment_below_threshold", # Tests fuzzy alignment failure when overlap ratio < _FUZZY_ALIGNMENT_MIN_THRESHOLD (75%). # No tokens overlap between "completely different medicine" and "Patient takes aspirin daily." extractions=[ [ data.Extraction( extraction_class="medication", extraction_text="completely different medicine", ) ], ], source_text="Patient takes aspirin daily.", expected_output=[[ data.Extraction( extraction_class="medication", extraction_text="completely different medicine", char_interval=None, alignment_status=None, ) ]], enable_fuzzy_alignment=True, ), dict( testcase_name="accept_match_lesser_disabled", # Tests accept_match_lesser=False with fuzzy fallback. 
extractions=[ [ data.Extraction( extraction_class="condition", extraction_text="patient heart problems today", ) ], ], source_text="Patient has heart problems today.", expected_output=[[ data.Extraction( extraction_class="condition", extraction_text="patient heart problems today", token_interval=tokenizer.TokenInterval( start_index=0, end_index=5 ), char_interval=data.CharInterval(start_pos=0, end_pos=32), alignment_status=data.AlignmentStatus.MATCH_FUZZY, ) ]], enable_fuzzy_alignment=True, accept_match_lesser=False, ), dict( testcase_name="fuzzy_alignment_subset_window", # Extraction is a subset of a longer source clause; ensures extra tokens do not penalise score. extractions=[[ data.Extraction( extraction_class="tendon", extraction_text="The iliopsoas tendon is intact", ) ]], source_text=( "The iliopsoas and proximal hamstring tendons are intact." ), expected_output=[[ data.Extraction( extraction_class="tendon", extraction_text="The iliopsoas tendon is intact", token_interval=tokenizer.TokenInterval( start_index=0, end_index=8 ), char_interval=data.CharInterval(start_pos=0, end_pos=55), alignment_status=data.AlignmentStatus.MATCH_FUZZY, ) ]], enable_fuzzy_alignment=True, accept_match_lesser=False, ), dict( testcase_name="fuzzy_alignment_with_reordered_words", # Tests fuzzy alignment's ability to handle reordered words in the extraction. 
extractions=[[ data.Extraction( extraction_class="condition", extraction_text="problems heart", # Reordered words char_interval=data.CharInterval(start_pos=12, end_pos=33), alignment_status=data.AlignmentStatus.MATCH_FUZZY, ) ]], source_text="Patient has severe heart problems today.", expected_output=[[ data.Extraction( extraction_class="condition", extraction_text="problems heart", # The best matching window in the source is "severe heart problems" token_interval=tokenizer.TokenInterval( start_index=2, end_index=5 ), char_interval=data.CharInterval(start_pos=12, end_pos=33), alignment_status=data.AlignmentStatus.MATCH_FUZZY, ) ]], enable_fuzzy_alignment=True, ), dict( testcase_name="fuzzy_alignment_fails_low_ratio", # An extraction that partially overlaps but is below the fuzzy threshold should not be aligned. extractions=[[ data.Extraction( extraction_class="symptom", extraction_text="headache and fever", ) ]], source_text="Patient reports back pain and a fever.", expected_output=[[ data.Extraction( extraction_class="symptom", extraction_text="headache and fever", char_interval=None, alignment_status=None, ) ]], enable_fuzzy_alignment=True, ), dict( testcase_name="fuzzy_alignment_partial_overlap_success", # An extraction where the number of matched tokens divided by total extraction tokens # is >= the threshold (3/4 = 0.75). extractions=[[ data.Extraction( extraction_class="finding", extraction_text="mild degenerative disc disease", ) ]], source_text=( "Findings consistent with degenerative disc disease at L5-S1." 
), expected_output=[[ data.Extraction( extraction_class="finding", extraction_text="mild degenerative disc disease", # The best window found is "degenerative disc disease" token_interval=tokenizer.TokenInterval( start_index=3, end_index=6 ), char_interval=data.CharInterval(start_pos=20, end_pos=50), alignment_status=data.AlignmentStatus.MATCH_FUZZY, ) ]], enable_fuzzy_alignment=True, ), ) def test_extraction_alignment( self, extractions: Sequence[Sequence[data.Extraction]], source_text: str, expected_output: Sequence[Sequence[data.Extraction]] | ValueError, enable_fuzzy_alignment: bool = False, accept_match_lesser: bool = True, ): if expected_output is ValueError: with self.assertRaises(ValueError): self.aligner.align_extractions( extractions, source_text, enable_fuzzy_alignment=False ) else: aligned_extraction_groups = self.aligner.align_extractions( extractions, source_text, enable_fuzzy_alignment=enable_fuzzy_alignment, accept_match_lesser=accept_match_lesser, ) flattened_extractions = [] for group in aligned_extraction_groups: flattened_extractions.extend(group) assert_char_interval_match_source( self, source_text, flattened_extractions ) self.assertEqual(aligned_extraction_groups, expected_output) class ResolverTest(parameterized.TestCase): _TWO_MEDICATIONS_JSON_UNDELIMITED = textwrap.dedent(f"""\ {{ "{data.EXTRACTIONS_KEY}": [ {{ "medication": "Naprosyn", "medication_index": 4, "frequency": "as needed", "frequency_index": 5, "reason": "pain", "reason_index": 8 }}, {{ "medication": "prednisone", "medication_index": 9, "duration": "for one month", "duration_index": 10 }} ] }}""") _TWO_MEDICATIONS_YAML_UNDELIMITED = textwrap.dedent(f"""\ {data.EXTRACTIONS_KEY}: - medication: "Naprosyn" medication_index: 4 frequency: "as needed" frequency_index: 5 reason: "pain" reason_index: 8 - medication: "prednisone" medication_index: 9 duration: "for one month" duration_index: 10 """) _EXPECTED_TWO_MEDICATIONS_ANNOTATED = [ data.Extraction( extraction_class="medication", 
extraction_text="Naprosyn", extraction_index=4, group_index=0, ), data.Extraction( extraction_class="frequency", extraction_text="as needed", extraction_index=5, group_index=0, ), data.Extraction( extraction_class="reason", extraction_text="pain", extraction_index=8, group_index=0, ), data.Extraction( extraction_class="medication", extraction_text="prednisone", extraction_index=9, group_index=1, ), data.Extraction( extraction_class="duration", extraction_text="for one month", extraction_index=10, group_index=1, ), ] def setUp(self): super().setUp() self.default_resolver = resolver_lib.Resolver( format_type=data.FormatType.JSON, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ) @parameterized.named_parameters( dict( testcase_name="json_with_fence", resolver=resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.JSON, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), input_text=textwrap.dedent(f"""\ ```json {{ "{data.EXTRACTIONS_KEY}": [ {{ "medication": "Naprosyn", "medication_index": 4, "frequency": "as needed", "frequency_index": 5, "reason": "pain", "reason_index": 8 }}, {{ "medication": "prednisone", "medication_index": 9, "duration": "for one month", "duration_index": 10 }} ] }} ```"""), expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED, ), dict( testcase_name="yaml_with_fence", resolver=resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), input_text=textwrap.dedent(f"""\ ```yaml {data.EXTRACTIONS_KEY}: - medication: "Naprosyn" medication_index: 4 frequency: "as needed" frequency_index: 5 reason: "pain" reason_index: 8 - medication: "prednisone" medication_index: 9 duration: "for one month" duration_index: 10 ```"""), expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED, ), dict( testcase_name="json_no_fence", resolver=resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.JSON, 
extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), input_text=_TWO_MEDICATIONS_JSON_UNDELIMITED, expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED, ), dict( testcase_name="yaml_no_fence", resolver=resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.YAML, extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, ), input_text=_TWO_MEDICATIONS_YAML_UNDELIMITED, expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED, ), ) def test_resolve_valid_inputs(self, resolver, input_text, expected_output): actual_extractions = resolver.resolve(input_text) self.assertCountEqual(expected_output, actual_extractions) assert_char_interval_match_source(self, input_text, actual_extractions) def test_handle_integer_extraction(self): test_input = textwrap.dedent(f"""\ ```json {{ "{data.EXTRACTIONS_KEY}": [ {{ "year": 2006, "year_index": 6 }} ] }} ```""") expected_extractions = [ data.Extraction( extraction_class="year", extraction_text="2006", extraction_index=6, group_index=0, ) ] actual_extractions = self.default_resolver.resolve(test_input) self.assertEqual(expected_extractions, list(actual_extractions)) def test_resolve_empty_yaml(self): test_input = "```json\n```" actual = self.default_resolver.resolve( test_input, suppress_parse_errors=True ) self.assertEmpty(actual) def test_resolve_empty_yaml_without_suppress_parse_errors(self): test_input = "```json\n```" with self.assertRaises(resolver_lib.ResolverParsingError): self.default_resolver.resolve(test_input, suppress_parse_errors=False) def test_align_with_valid_chunk(self): text = "This is a sample text with some extractions." 
tokenized_text = tokenizer.tokenize(text) chunk = tokenizer.TokenInterval(start_index=0, end_index=8) annotated_extractions = [ data.Extraction( extraction_class="medication", extraction_text="sample" ), data.Extraction( extraction_class="condition", extraction_text="extractions" ), ] expected_extractions = [ data.Extraction( extraction_class="medication", extraction_text="sample", token_interval=tokenizer.TokenInterval(start_index=3, end_index=4), char_interval=data.CharInterval(start_pos=10, end_pos=16), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="condition", extraction_text="extractions", token_interval=tokenizer.TokenInterval(start_index=7, end_index=8), char_interval=data.CharInterval(start_pos=32, end_pos=43), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), ] chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index aligned_extractions = list( self.default_resolver.align( extractions=annotated_extractions, source_text=chunk_text, token_offset=token_offset, char_offset=0, enable_fuzzy_alignment=False, ) ) self.assertEqual(len(aligned_extractions), len(expected_extractions)) for expected, actual in zip(expected_extractions, aligned_extractions): self.assertDataclassEqual(expected, actual) assert_char_interval_match_source(self, text, aligned_extractions) def test_align_with_chunk_starting_in_middle(self): text = "This is a sample text with some extractions." 
tokenized_text = tokenizer.tokenize(text) chunk = tokenizer.TokenInterval(start_index=3, end_index=8) annotated_extractions = [ data.Extraction( extraction_class="medication", extraction_text="sample" ), data.Extraction( extraction_class="condition", extraction_text="extractions" ), ] expected_extractions = [ data.Extraction( extraction_class="medication", extraction_text="sample", token_interval=tokenizer.TokenInterval(start_index=3, end_index=4), char_interval=data.CharInterval(start_pos=10, end_pos=16), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), data.Extraction( extraction_class="condition", extraction_text="extractions", token_interval=tokenizer.TokenInterval(start_index=7, end_index=8), char_interval=data.CharInterval(start_pos=32, end_pos=43), alignment_status=data.AlignmentStatus.MATCH_EXACT, ), ] chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index # Compute global char offset from the token at chunk.start_index. char_offset = tokenized_text.tokens[ chunk.start_index ].char_interval.start_pos aligned_extractions = list( self.default_resolver.align( extractions=annotated_extractions, source_text=chunk_text, token_offset=token_offset, char_offset=char_offset, enable_fuzzy_alignment=False, ) ) self.assertEqual(len(aligned_extractions), len(expected_extractions)) for expected, actual in zip(expected_extractions, aligned_extractions): self.assertDataclassEqual(expected, actual) assert_char_interval_match_source(self, text, aligned_extractions) def test_align_with_no_extractions_in_chunk(self): tokenized_text = tokenizer.tokenize("No extractions here.") # Define a chunk that includes the entire text. 
chunk = tokenizer.TokenInterval() chunk.start_index = 0 chunk.end_index = 3 annotated_extractions = [] chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index aligned_extractions = list( self.default_resolver.align( extractions=annotated_extractions, source_text=chunk_text, token_offset=token_offset, char_offset=0, enable_fuzzy_alignment=False, ) ) self.assertEmpty(aligned_extractions) def test_align_successful(self): tokenized_text = tokenizer.TokenizedText( text="zero one two", tokens=[ tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=0, end_pos=4), index=0, ), tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=5, end_pos=8), index=1, ), tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=9, end_pos=12), index=2, ), ], ) # Define a chunk that includes the entire text. chunk = tokenizer.TokenInterval(start_index=0, end_index=3) annotated_extractions = [ data.Extraction(extraction_class="foo", extraction_text="zero"), data.Extraction(extraction_class="foo", extraction_text="one"), ] chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index aligned_extractions = list( self.default_resolver.align( extractions=annotated_extractions, source_text=chunk_text, token_offset=token_offset, char_offset=0, enable_fuzzy_alignment=False, ) ) self.assertLen(aligned_extractions, 2) assert_char_interval_match_source( self, tokenized_text.text, aligned_extractions ) def test_align_with_discontinuous_tokenized_text(self): tokenized_text = tokenizer.TokenizedText( text="zero one five", tokens=[ tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=0, end_pos=4), index=0, ), tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=5, end_pos=8), index=1, ), 
tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=9, end_pos=14), index=5, ), ], ) # Define a chunk that includes too many tokens. chunk = tokenizer.TokenInterval(start_index=0, end_index=6) annotated_extractions = [ data.Extraction(extraction_class="foo", extraction_text="zero"), data.Extraction(extraction_class="foo", extraction_text="one"), ] with self.assertRaises(tokenizer.InvalidTokenIntervalError): chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index list( self.default_resolver.align( annotated_extractions, chunk_text, token_offset, enable_fuzzy_alignment=False, ) ) def test_align_with_discontinuous_tokenized_text_but_right_chunk(self): tokenized_text = tokenizer.TokenizedText( text="zero one five", tokens=[ tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=0, end_pos=4), index=0, ), tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=5, end_pos=8), index=1, ), tokenizer.Token( token_type=tokenizer.TokenType.WORD, char_interval=tokenizer.CharInterval(start_pos=9, end_pos=14), index=5, ), ], ) # Define a correct chunk. 
chunk = tokenizer.TokenInterval(start_index=0, end_index=3) annotated_extractions = [ data.Extraction(extraction_class="foo", extraction_text="zero"), data.Extraction(extraction_class="foo", extraction_text="one"), ] chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index aligned_extractions = list( self.default_resolver.align( extractions=annotated_extractions, source_text=chunk_text, token_offset=token_offset, char_offset=0, enable_fuzzy_alignment=False, ) ) self.assertLen(aligned_extractions, 2) assert_char_interval_match_source( self, tokenized_text.text, aligned_extractions ) def test_align_with_empty_annotated_extractions(self): """Test align method with empty annotated_extractions sequence.""" tokenized_text = tokenizer.tokenize("No extractions here.") # Define a chunk that includes the entire text. chunk = tokenizer.TokenInterval() chunk.start_index = 0 chunk.end_index = 3 annotated_extractions = [] # Empty sequence representing no extractions chunk_text = chunking.get_token_interval_text(tokenized_text, chunk) token_offset = chunk.start_index aligned_extractions = list( self.default_resolver.align( extractions=annotated_extractions, source_text=chunk_text, token_offset=token_offset, char_offset=0, enable_fuzzy_alignment=False, ) ) self.assertEmpty(aligned_extractions) class FenceFallbackTest(parameterized.TestCase): """Tests for fence marker fallback behavior.""" @parameterized.named_parameters( dict( testcase_name="with_valid_fences", test_input=textwrap.dedent("""\ ```json { "extractions": [ {"person": "Marie Curie", "person_attributes": {"field": "physics"}} ] } ```"""), fence_output=True, strict_fences=False, expected_key="person", expected_value="Marie Curie", ), dict( testcase_name="fallback_no_fences", test_input=textwrap.dedent("""\ { "extractions": [ {"person": "Albert Einstein", "person_attributes": {"field": "physics"}} ] }"""), fence_output=True, strict_fences=False, expected_key="person", 
expected_value="Albert Einstein", ), dict( testcase_name="no_fence_expectation", test_input=textwrap.dedent("""\ { "extractions": [ {"drug": "Aspirin", "drug_attributes": {"dosage": "100mg"}} ] }"""), fence_output=False, strict_fences=False, expected_key="drug", expected_value="Aspirin", ), ) def test_parsing_scenarios( self, test_input, fence_output, strict_fences, expected_key, expected_value, ): resolver = resolver_lib.Resolver( fence_output=fence_output, format_type=data.FormatType.JSON, strict_fences=strict_fences, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 1) self.assertIn(expected_key, result[0]) self.assertEqual(result[0][expected_key], expected_value) def test_fallback_preserves_content_integrity(self): test_input = textwrap.dedent("""\ { "extractions": [ { "medication": "Ibuprofen", "medication_attributes": { "dosage": "200mg", "frequency": "twice daily" } }, { "condition": "headache", "condition_attributes": { "severity": "mild" } } ] }""") resolver = resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.JSON, strict_fences=False, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 2, "Should preserve all extractions during fallback") self.assertEqual( result[0]["medication"], "Ibuprofen", "First extraction should have correct medication", ) self.assertEqual( result[0]["medication_attributes"]["dosage"], "200mg", "Should preserve nested attributes in fallback", ) self.assertEqual( result[1]["condition"], "headache", "Second extraction should have correct condition", ) self.assertEqual( result[1]["condition_attributes"]["severity"], "mild", "Should preserve all nested attributes", ) def test_malformed_json_still_raises_error(self): test_input = textwrap.dedent("""\ { "extractions": [ {"person": "Missing closing brace" ]""") resolver = resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.JSON, strict_fences=False, ) with 
self.assertRaises(resolver_lib.ResolverParsingError): resolver.string_to_extraction_data(test_input) def test_strict_fences_raises_on_missing_markers(self): strict_resolver = resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.JSON, strict_fences=True, ) test_input = textwrap.dedent("""\ {"extractions": [{"person": "Test"}]}""") with self.assertRaisesRegex( resolver_lib.ResolverParsingError, ".*fence markers.*" ): strict_resolver.string_to_extraction_data(test_input) def test_default_allows_fallback(self): default_resolver = resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.JSON, ) test_input = textwrap.dedent("""\ {"extractions": [{"person": "Default Test"}]}""") result = default_resolver.string_to_extraction_data(test_input) self.assertLen(result, 1) self.assertEqual(result[0]["person"], "Default Test") def test_rejects_multiple_fenced_blocks(self): test_input = textwrap.dedent("""\ preamble ```json {"extractions": [{"item": "first"}]} ``` Some explanation text ```json {"extractions": [{"item": "second"}]} ```""") resolver = resolver_lib.Resolver( fence_output=True, format_type=data.FormatType.JSON, strict_fences=False, ) with self.assertRaisesRegex( resolver_lib.ResolverParsingError, "Multiple fenced blocks found" ): resolver.string_to_extraction_data(test_input) class FlexibleSchemaTest(parameterized.TestCase): """Tests for flexible schema formats without extractions key.""" def test_direct_list_format(self): test_input = textwrap.dedent("""\ [ {"person": "Marie Curie", "field": "physics"}, {"person": "Albert Einstein", "field": "relativity"} ]""") resolver = resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.JSON, require_extractions_key=False, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 2) self.assertEqual(result[0]["person"], "Marie Curie") self.assertEqual(result[1]["person"], "Albert Einstein") def test_single_dict_as_extraction(self): test_input = '{"person": 
"Isaac Newton", "field": "gravity"}' resolver = resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.JSON, require_extractions_key=False, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 1) self.assertEqual(result[0]["person"], "Isaac Newton") self.assertEqual(result[0]["field"], "gravity") def test_traditional_format_still_works(self): test_input = textwrap.dedent("""\ { "extractions": [ {"person": "Charles Darwin", "field": "evolution"} ] }""") resolver = resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.JSON, require_extractions_key=False, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 1) self.assertEqual(result[0]["person"], "Charles Darwin") def test_lenient_mode_accepts_list(self): # Some models return [...] instead of {"extractions": [...]} test_input = '[{"person": "Test"}]' resolver = resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.JSON, require_extractions_key=True, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 1) self.assertEqual(result[0]["person"], "Test") def test_flexible_with_attributes(self): test_input = textwrap.dedent("""\ [ { "medication": "Aspirin", "medication_attributes": {"dosage": "100mg", "frequency": "daily"} }, { "medication": "Ibuprofen", "medication_attributes": {"dosage": "200mg"} } ]""") resolver = resolver_lib.Resolver( fence_output=False, format_type=data.FormatType.JSON, require_extractions_key=False, ) result = resolver.string_to_extraction_data(test_input) self.assertLen(result, 2) self.assertEqual(result[0]["medication"], "Aspirin") self.assertEqual(result[0]["medication_attributes"]["dosage"], "100mg") self.assertEqual(result[1]["medication"], "Ibuprofen") if __name__ == "__main__": absltest.main() ================================================ FILE: tests/schema_test.py ================================================ # Copyright 2025 Google LLC. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the schema module. Note: This file contains test helper classes that intentionally have few public methods. The too-few-public-methods warnings are expected. """ from unittest import mock import warnings from absl.testing import absltest from absl.testing import parameterized from langextract.core import base_model from langextract.core import data from langextract.core import format_handler as fh from langextract.core import schema from langextract.providers import schemas class BaseSchemaTest(absltest.TestCase): """Tests for BaseSchema abstract class.""" def test_abstract_methods_required(self): """Test that BaseSchema cannot be instantiated directly.""" with self.assertRaises(TypeError): schema.BaseSchema() # pylint: disable=abstract-class-instantiated def test_subclass_must_implement_all_methods(self): """Test that subclasses must implement all abstract methods.""" class IncompleteSchema(schema.BaseSchema): # pylint: disable=too-few-public-methods @classmethod def from_examples(cls, examples_data, attribute_suffix="_attributes"): return cls() with self.assertRaises(TypeError): IncompleteSchema() # pylint: disable=abstract-class-instantiated class BaseLanguageModelSchemaTest(absltest.TestCase): """Tests for BaseLanguageModel schema methods.""" def test_get_schema_class_returns_none_by_default(self): """Test that get_schema_class returns None by default.""" class TestModel(base_model.BaseLanguageModel): # pylint: 
disable=too-few-public-methods def infer(self, batch_prompts, **kwargs): yield [] self.assertIsNone(TestModel.get_schema_class())

  def test_apply_schema_stores_instance(self):
    """Test that apply_schema stores the schema instance."""

    # Minimal concrete subclass: infer is a stub generator yielding one
    # empty list, just enough to instantiate the model.
    class TestModel(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods

      def infer(self, batch_prompts, **kwargs):
        yield []

    model = TestModel()
    mock_schema = mock.Mock(spec=schema.BaseSchema)

    model.apply_schema(mock_schema)
    self.assertEqual(model._schema, mock_schema)

    # Applying None clears the previously stored schema.
    model.apply_schema(None)
    self.assertIsNone(model._schema)


class GeminiSchemaTest(parameterized.TestCase):
  """Tests for building a GeminiSchema from examples and exporting it."""

  @parameterized.named_parameters(
      # No examples: the per-item object schema has an empty properties map.
      dict(
          testcase_name="empty_extractions",
          examples_data=[],
          expected_schema={
              "type": "object",
              "properties": {
                  data.EXTRACTIONS_KEY: {
                      "type": "array",
                      "items": {
                          "type": "object",
                          "properties": {},
                      },
                  },
              },
              "required": [data.EXTRACTIONS_KEY],
          },
      ),
      # A class with no attributes still gets a nullable "<class>_attributes"
      # object containing a placeholder "_unused" string property.
      # NOTE(review): presumably because an empty object schema is not
      # accepted by the Gemini API — confirm against the schema builder.
      dict(
          testcase_name="single_extraction_no_attributes",
          examples_data=[
              data.ExampleData(
                  text="Patient has diabetes.",
                  extractions=[
                      data.Extraction(
                          extraction_text="diabetes",
                          extraction_class="condition",
                      )
                  ],
              )
          ],
          expected_schema={
              "type": "object",
              "properties": {
                  data.EXTRACTIONS_KEY: {
                      "type": "array",
                      "items": {
                          "type": "object",
                          "properties": {
                              "condition": {"type": "string"},
                              "condition_attributes": {
                                  "type": "object",
                                  "properties": {
                                      "_unused": {"type": "string"},
                                  },
                                  "nullable": True,
                              },
                          },
                      },
                  },
              },
              "required": [data.EXTRACTIONS_KEY],
          },
      ),
      # Each attribute key seen in the examples becomes a string property.
      dict(
          testcase_name="single_extraction",
          examples_data=[
              data.ExampleData(
                  text="Patient has diabetes.",
                  extractions=[
                      data.Extraction(
                          extraction_text="diabetes",
                          extraction_class="condition",
                          attributes={"chronicity": "chronic"},
                      )
                  ],
              )
          ],
          expected_schema={
              "type": "object",
              "properties": {
                  data.EXTRACTIONS_KEY: {
                      "type": "array",
                      "items": {
                          "type": "object",
                          "properties": {
                              "condition": {"type": "string"},
                              "condition_attributes": {
                                  "type": "object",
                                  "properties": {
                                      "chronicity": {"type": "string"},
                                  },
                                  "nullable": True,
                              },
                          },
                      },
                  },
              },
              "required": [data.EXTRACTIONS_KEY],
          },
      ),
      # Classes from different examples are merged into one item schema.
      dict(
          testcase_name="multiple_extraction_classes",
          examples_data=[
              data.ExampleData(
                  text="Patient has diabetes.",
                  extractions=[
                      data.Extraction(
                          extraction_text="diabetes",
                          extraction_class="condition",
                          attributes={"chronicity": "chronic"},
                      )
                  ],
              ),
              data.ExampleData(
                  text="Patient is John Doe",
                  extractions=[
                      data.Extraction(
                          extraction_text="John Doe",
                          extraction_class="patient",
                          attributes={"id": "12345"},
                      )
                  ],
              ),
          ],
          expected_schema={
              "type": "object",
              "properties": {
                  data.EXTRACTIONS_KEY: {
                      "type": "array",
                      "items": {
                          "type": "object",
                          "properties": {
                              "condition": {"type": "string"},
                              "condition_attributes": {
                                  "type": "object",
                                  "properties": {
                                      "chronicity": {"type": "string"}
                                  },
                                  "nullable": True,
                              },
                              "patient": {"type": "string"},
                              "patient_attributes": {
                                  "type": "object",
                                  "properties": {
                                      "id": {"type": "string"},
                                  },
                                  "nullable": True,
                              },
                          },
                      },
                  },
              },
              "required": [data.EXTRACTIONS_KEY],
          },
      ),
  )
  def test_from_examples_constructs_expected_schema(
      self, examples_data, expected_schema
  ):
    """Test that from_examples produces exactly the expected schema_dict."""
    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)
    actual_schema = gemini_schema.schema_dict
    self.assertEqual(actual_schema, expected_schema)

  def test_to_provider_config_returns_response_schema(self):
    """Test that to_provider_config returns the correct provider kwargs."""
    examples_data = [
        data.ExampleData(
            text="Test text",
            extractions=[
                data.Extraction(
                    extraction_class="test_class",
                    extraction_text="test extraction",
                )
            ],
        )
    ]

    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)
    provider_config = gemini_schema.to_provider_config()

    # The provider config exposes the same dict as schema_dict.
    self.assertIn("response_schema", provider_config)
    self.assertEqual(
        provider_config["response_schema"], gemini_schema.schema_dict
    )

  def test_requires_raw_output_returns_true(self):
    """Test that GeminiSchema requires raw output."""
    examples_data = [
        data.ExampleData(
            text="Test text",
            extractions=[
                data.Extraction(
                    extraction_class="test_class",
                    extraction_text="test extraction",
                )
            ],
        )
    ]

    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)
    self.assertTrue(gemini_schema.requires_raw_output)


class SchemaValidationTest(parameterized.TestCase):
  """Tests for schema format validation."""

  def _create_test_schema(self):
    """Helper to create a test schema."""
    examples = [
        data.ExampleData(
            text="Test",
            extractions=[
                data.Extraction(
                    extraction_class="entity",
                    extraction_text="test",
                )
            ],
        )
    ]
    return schemas.gemini.GeminiSchema.from_examples(examples)

  @parameterized.named_parameters(
      dict(
          testcase_name="warns_about_fences",
          use_fences=True,
          use_wrapper=True,
          wrapper_key=data.EXTRACTIONS_KEY,
          expected_warning="fence_output=True may cause parsing issues",
      ),
      dict(
          testcase_name="warns_about_wrong_wrapper_key",
          use_fences=False,
          use_wrapper=True,
          wrapper_key="wrong_key",
          expected_warning="response_schema expects wrapper_key='extractions'",
      ),
      dict(
          testcase_name="no_warning_with_correct_settings",
          use_fences=False,
          use_wrapper=True,
          wrapper_key=data.EXTRACTIONS_KEY,
          expected_warning=None,
      ),
  )
  def test_gemini_validation(
      self, use_fences, use_wrapper, wrapper_key, expected_warning
  ):
    """Test GeminiSchema validation with various settings."""
    schema_obj = self._create_test_schema()
    format_handler = fh.FormatHandler(
        format_type=data.FormatType.JSON,
        use_fences=use_fences,
        use_wrapper=use_wrapper,
        wrapper_key=wrapper_key,
    )

    # simplefilter("always") disables de-duplication so every warning emitted
    # by validate_format is recorded and the exact-count assertion holds.
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      schema_obj.validate_format(format_handler)

    if expected_warning:
      self.assertLen(
          w,
          1,
          f"Expected exactly one warning containing '{expected_warning}'",
      )
      self.assertIn(
          expected_warning,
          str(w[0].message),
          f"Warning message should contain '{expected_warning}'",
      )
    else:
      self.assertEmpty(w, "No warnings should be issued for correct settings")

  def test_base_schema_no_validation(self):
    """Test that base schema has no validation by default."""
    schema_obj = schema.FormatModeSchema()
    # use_fences=True would warn for GeminiSchema; FormatModeSchema must not.
    format_handler = fh.FormatHandler(
        format_type=data.FormatType.JSON,
        use_fences=True,
    )

    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      schema_obj.validate_format(format_handler)

    self.assertEmpty(
        w, "FormatModeSchema should not issue validation warnings"
    )


if __name__ == "__main__":
  absltest.main()

================================================
FILE: tests/test_gemini_batch_api.py
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Gemini Batch API functionality.""" import io import json from unittest import mock from absl.testing import absltest from absl.testing import parameterized from google import genai from google.api_core import exceptions from langextract.providers import gemini from langextract.providers import gemini_batch as gb from langextract.providers import schemas def create_mock_batch_job( state=genai.types.JobState.JOB_STATE_SUCCEEDED, gcs_uri=f"gs://bucket/output/file{gb._EXT_JSONL}", ): """Create a mock BatchJob for testing.""" job = mock.create_autospec(genai.types.BatchJob, instance=True) job.name = "batches/123" job.state = state job.dest = mock.create_autospec( genai.types.BatchJobDestination, instance=True ) job.dest.gcs_uri = gcs_uri return job def _create_batch_response(idx, text_content): """Helper to create a batch output line with response.""" if not isinstance(text_content, str): text_content = json.dumps(text_content, separators=(",", ":")) return json.dumps({ "key": f"{gb._KEY_IDX}{idx}", "response": { "candidates": [{"content": {"parts": [{"text": text_content}]}}] }, }) def _create_batch_error(idx, code, message): """Helper to create a batch output line with error.""" return json.dumps({ "key": f"{gb._KEY_IDX}{idx}", "error": {"code": code, "message": message}, }) class TestGeminiBatchAPI(absltest.TestCase): """Test Gemini Batch API routing and functionality.""" def setUp(self): super().setUp() self.mock_storage_cls = self.enter_context( mock.patch.object(gb.storage, "Client", autospec=True) ) self.mock_storage_client = self.mock_storage_cls.return_value self.mock_bucket = self.mock_storage_client.bucket.return_value self.mock_blob = self.mock_bucket.blob.return_value @mock.patch.object(genai, "Client", autospec=True) def test_batch_routing_vertex(self, mock_client_cls): """Test that batch API is used when enabled and threshold is met (Vertex).""" mock_client = mock_client_cls.return_value mock_client.vertexai = True 
self.mock_storage_client.create_bucket.return_value = self.mock_bucket output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = "output.jsonl" # Mock blob.open context manager output_blob.open.return_value.__enter__.return_value = io.StringIO( "\n".join([ _create_batch_response(0, {"ok": 1}), _create_batch_response(1, {"ok": 2}), ]) ) self.mock_bucket.list_blobs.return_value = [output_blob] mock_client.batches.create.return_value = create_mock_batch_job() mock_client.batches.get.return_value = create_mock_batch_job() model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="test-project", location=gb._DEFAULT_LOCATION, batch={ "enabled": True, "threshold": 2, "poll_interval": 1, "enable_caching": False, "retention_days": None, }, ) prompts = ["p1", "p2"] outs = list(model.infer(prompts)) self.assertLen(outs, 2) self.assertEqual(outs[0][0].output, '{"ok":1}') self.assertEqual(outs[1][0].output, '{"ok":2}') self.mock_blob.upload_from_filename.assert_called() mock_client.batches.create.assert_called() @mock.patch.object(genai, "Client", autospec=True) def test_realtime_when_disabled(self, mock_client_cls): """Test that real-time API is used when batch is disabled.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True mock_response = mock.create_autospec( genai.types.GenerateContentResponse, instance=True ) mock_response.text = '{"ok":1}' mock_client.models.generate_content.return_value = mock_response model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", batch={"enabled": False}, ) outs = list(model.infer(["hello"])) self.assertLen(outs, 1) self.assertEqual(outs[0][0].output, '{"ok":1}') mock_client.models.generate_content.assert_called() mock_client.batches.create.assert_not_called() @mock.patch.object(genai, "Client", autospec=True) def test_realtime_when_below_threshold(self, mock_client_cls): """Test that real-time API is used when 
prompt count is below threshold.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True mock_response = mock.create_autospec( genai.types.GenerateContentResponse, instance=True ) mock_response.text = '{"ok":1}' mock_client.models.generate_content.return_value = mock_response model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", batch={ "enabled": True, "threshold": 10, "enable_caching": False, "retention_days": None, }, ) outs = list(model.infer(["hello"])) self.assertLen(outs, 1) self.assertEqual(outs[0][0].output, '{"ok":1}') mock_client.models.generate_content.assert_called() mock_client.batches.create.assert_not_called() @mock.patch.object(genai, "Client", autospec=True) def test_batch_with_schema(self, mock_client_cls): """Test that batch API properly includes schema when configured.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = f"output{gb._EXT_JSONL}" output_blob.open.return_value.__enter__.return_value = io.StringIO( _create_batch_response(0, {"name": "test"}) ) self.mock_bucket.list_blobs.return_value = [output_blob] mock_client.batches.create.return_value = create_mock_batch_job() mock_client.batches.get.return_value = create_mock_batch_job() mock_schema = mock.create_autospec( schemas.gemini.GeminiSchema, instance=True ) mock_schema.schema_dict = { "type": "object", "properties": {"name": {"type": "string"}}, } model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", gemini_schema=mock_schema, batch={ "enabled": True, "threshold": 1, "enable_caching": False, "retention_days": None, }, ) # Mock _submit_file to verify the request payload contains the schema. 
with mock.patch.object(gb, "_submit_file", autospec=True) as mock_submit: mock_submit.return_value = create_mock_batch_job() outs = list(model.infer(["test prompt"])) self.assertLen(outs, 1) self.assertEqual(outs[0][0].output, '{"name":"test"}') # Verify _submit_file was called with project and location parameters. mock_submit.assert_called_with( mock_client, "gemini-2.5-flash", [{ "contents": [ {"role": "user", "parts": [{"text": "test prompt"}]} ], "generationConfig": { "responseMimeType": "application/json", "responseSchema": mock_schema.schema_dict, "temperature": 0.0, }, }], mock.ANY, # Display name contains timestamp/random. None, # retention_days "p", # project "l", # location ) self.assertEqual(model.gemini_schema.schema_dict, mock_schema.schema_dict) @mock.patch.object(genai, "Client", autospec=True) def test_batch_error_handling(self, mock_client_cls): """Test that batch errors are properly handled and raised.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True mock_client.batches.create.side_effect = Exception("Batch API error") model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", batch={ "enabled": True, "threshold": 1, "enable_caching": False, "retention_days": None, }, ) with self.assertRaisesRegex(Exception, "Gemini Batch API error"): list(model.infer(["test prompt"])) @mock.patch.object(genai, "Client", autospec=True) def test_file_based_ordering(self, mock_client_cls): """Test that file-based results are returned in correct order.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True # Define inputs and expected outputs prompts = ["prompt 0", "prompt 1", "prompt 2"] # Simulate shuffled response in the file output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = f"output{gb._EXT_JSONL}" output_blob.open.return_value.__enter__.return_value = io.StringIO( "\n".join([ _create_batch_response(2, "response 2"), 
_create_batch_response(0, "response 0"), _create_batch_response(1, "response 1"), ]) ) self.mock_bucket.list_blobs.return_value = [output_blob] job = create_mock_batch_job() mock_client.batches.create.return_value = job mock_client.batches.get.return_value = job model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", batch={ "enabled": True, "threshold": 1, "enable_caching": False, "retention_days": None, }, ) results = list(model.infer(prompts)) # Verify results are in original order despite shuffled response self.assertListEqual( [r[0].output for r in results], ["response 0", "response 1", "response 2"], ) @mock.patch.object(genai, "Client", autospec=True) def test_max_prompts_per_job(self, mock_client_cls): """Test that requests are split into multiple batch jobs when they exceed max_prompts_per_job. This verifies that: 1. Large requests are chunked correctly based on the limit. 2. Multiple batch jobs are submitted. 3. Results are aggregated and returned in the correct order. 
""" mock_client = mock_client_cls.return_value mock_client.vertexai = True # Define inputs and expected behavior prompts = ["p1", "p2", "p3", "p4", "p5"] max_prompts_per_job = 2 # Expected chunks: ["p1", "p2"], ["p3", "p4"], ["p5"] # Setup mock storage and blobs for 3 separate jobs blob0 = mock.create_autospec(gb.storage.Blob, instance=True) blob0.name = f"out0{gb._EXT_JSONL}" blob0.open.return_value.__enter__.return_value = io.StringIO( "\n".join([ _create_batch_response(0, "r1"), _create_batch_response(1, "r2"), ]) ) blob1 = mock.create_autospec(gb.storage.Blob, instance=True) blob1.name = f"out1{gb._EXT_JSONL}" blob1.open.return_value.__enter__.return_value = io.StringIO( "\n".join([ _create_batch_response(0, "r3"), _create_batch_response(1, "r4"), ]) ) blob2 = mock.create_autospec(gb.storage.Blob, instance=True) blob2.name = f"out2{gb._EXT_JSONL}" blob2.open.return_value.__enter__.return_value = io.StringIO( _create_batch_response(0, "r5") ) def list_blobs_side_effect(prefix=None): if "part-0" in prefix: return [blob0] if "part-1" in prefix: return [blob1] if "part-2" in prefix: return [blob2] return [] self.mock_bucket.list_blobs.side_effect = list_blobs_side_effect # Setup mock jobs job0 = create_mock_batch_job(gcs_uri="gs://b/batch-input/part-0/out") job1 = create_mock_batch_job(gcs_uri="gs://b/batch-input/part-1/out") job2 = create_mock_batch_job(gcs_uri="gs://b/batch-input/part-2/out") mock_client.batches.create.side_effect = [job0, job1, job2] mock_client.batches.get.side_effect = [job0, job1, job2] model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", batch={ "enabled": True, "threshold": 1, "max_prompts_per_job": max_prompts_per_job, "enable_caching": False, "retention_days": None, }, ) results = list(model.infer(prompts)) self.assertEqual(mock_client.batches.create.call_count, 3) self.assertListEqual( [r[0].output for r in results], ["r1", "r2", "r3", "r4", "r5"] ) @mock.patch.object(genai, "Client", 
autospec=True) def test_batch_item_error(self, mock_client_cls): """Test that batch item errors raise exception.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = f"output{gb._EXT_JSONL}" output_blob.open.return_value.__enter__.return_value = io.StringIO( _create_batch_error(0, 13, "Internal error") ) self.mock_bucket.list_blobs.return_value = [output_blob] job = create_mock_batch_job() mock_client.batches.create.return_value = job mock_client.batches.get.return_value = job model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project="p", location="l", batch={ "enabled": True, "threshold": 1, "enable_caching": False, "retention_days": None, }, ) with self.assertRaisesRegex(Exception, "Batch item error"): list(model.infer(["test"])) class BatchConfigValidationTest(parameterized.TestCase): """Test BatchConfig validation logic.""" @parameterized.named_parameters( dict(testcase_name="threshold_lt_1", threshold=0), dict(testcase_name="poll_interval_le_0", poll_interval=0), dict(testcase_name="timeout_le_0", timeout=0), dict(testcase_name="max_prompts_per_job_le_0", max_prompts_per_job=0), ) def test_validation_errors(self, **overrides): """Verify validation errors for invalid config values.""" with self.assertRaises(ValueError): gb.BatchConfig(**overrides) class EmptyAndPaddingTest(absltest.TestCase): """Test empty prompt handling and result padding/trimming.""" @mock.patch.object(genai, "Client", autospec=True) def test_empty_prompts_fast_path(self, mock_client_cls): """Verify empty prompts return immediately without API calls.""" outs = gb.infer_batch( client=mock_client_cls.return_value, model_id="m", prompts=[], schema_dict=None, gen_config={}, cfg=gb.BatchConfig( enabled=True, poll_interval=1, enable_caching=False, retention_days=None, ), ) self.assertEqual(outs, []) @mock.patch.object(genai, "Client", autospec=True) def 
test_file_pad_to_expected_count(self, mock_client_cls): """Verify padding to maintain 1:1 alignment with input prompts.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True with mock.patch.object(gb.storage, "Client", autospec=True) as mock_storage: mock_bucket = mock_storage.return_value.bucket.return_value output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = f"output{gb._EXT_JSONL}" output_blob.open.return_value.__enter__.return_value = io.StringIO( _create_batch_response(0, "only_one") ) mock_bucket.list_blobs.return_value = [output_blob] job = create_mock_batch_job() mock_client.batches.create.return_value = job mock_client.batches.get.return_value = job cfg = gb.BatchConfig( enabled=True, threshold=1, poll_interval=1, enable_caching=False, retention_days=None, ) outs = gb.infer_batch( client=mock_client, model_id="m", prompts=["p1", "p2"], schema_dict=None, gen_config={}, cfg=cfg, ) self.assertEqual(outs, ["only_one", ""]) # padded class GCSBatchCachingTest(absltest.TestCase): """Test GCS batch caching functionality.""" def setUp(self): super().setUp() self.mock_storage_cls = self.enter_context( mock.patch.object(gb.storage, "Client", autospec=True) ) self.mock_storage_client = self.mock_storage_cls.return_value self.mock_bucket = self.mock_storage_client.bucket.return_value self.mock_blob = self.mock_bucket.blob.return_value @mock.patch.object(genai, "Client", autospec=True) def test_cache_hit_skips_inference(self, mock_client_cls): """Test that fully cached prompts skip inference.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True mock_client.project = "p" mock_client.location = "l" self.mock_blob.download_as_text.return_value = '{"text": "cached_response"}' cfg = gb.BatchConfig( enabled=True, threshold=1, enable_caching=True, retention_days=None, ) outs = gb.infer_batch( client=mock_client, model_id="m", prompts=["p1"], schema_dict=None, gen_config={}, cfg=cfg, ) 
self.assertListEqual(outs, ["cached_response"]) mock_client.batches.create.assert_not_called() self.mock_bucket.blob.assert_called() @mock.patch.object(genai, "Client", autospec=True) def test_partial_cache_hit(self, mock_client_cls): """Test that partial cache hits only submit missing prompts.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True mock_client.project = "p" mock_client.location = "l" # Mock GCS cache: hit for "cached_prompt", miss for "new_prompt" # We mock _compute_hash to avoid dealing with complex hashing in test with mock.patch.object(gb.GCSBatchCache, "_compute_hash") as mock_hash: mock_hash.side_effect = lambda k: f"hash_{k['prompt']}" # Pre-configure blobs blob_hit = mock.create_autospec(gb.storage.Blob, instance=True) blob_hit.download_as_text.return_value = '{"text": "cached_response"}' blob_miss = mock.create_autospec(gb.storage.Blob, instance=True) blob_miss.download_as_text.side_effect = exceptions.NotFound("Not found") def get_blob(name): if "hash_cached_prompt" in name: return blob_hit return blob_miss self.mock_bucket.blob.side_effect = get_blob # Mock list_blobs to return the batch output file for the new prompt output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = f"output{gb._EXT_JSONL}" output_blob.open.return_value.__enter__.return_value = io.StringIO( _create_batch_response(0, "new_response") ) self.mock_bucket.list_blobs.return_value = [output_blob] job = create_mock_batch_job() mock_client.batches.create.return_value = job mock_client.batches.get.return_value = job cfg = gb.BatchConfig( enabled=True, threshold=1, enable_caching=True, retention_days=None, ) outs = gb.infer_batch( client=mock_client, model_id="m", prompts=["cached_prompt", "new_prompt"], schema_dict=None, gen_config={}, cfg=cfg, ) self.assertListEqual(outs, ["cached_response", "new_response"]) mock_client.batches.create.assert_called_once() # Verify "new_response" was uploaded to cache (using the miss blob) # 
The blob used for upload is blob_miss because it was returned for the miss key upload_calls = [ call for call in blob_miss.upload_from_string.mock_calls if "new_response" in str(call) ] self.assertTrue( upload_calls, "Should have uploaded new_response to cache" ) @mock.patch.object(genai, "Client", autospec=True) @mock.patch.dict("os.environ", {}, clear=True) def test_project_passed_to_storage_client(self, mock_client_cls): """Test that project parameter is passed to storage.Client constructor.""" mock_client = mock_client_cls.return_value mock_client.vertexai = True if hasattr(mock_client, "project"): del mock_client.project self.mock_storage_client.create_bucket.return_value = self.mock_bucket output_blob = mock.create_autospec(gb.storage.Blob, instance=True) output_blob.name = f"output{gb._EXT_JSONL}" output_blob.open.return_value.__enter__.return_value = io.StringIO( _create_batch_response(0, {"result": "ok"}) ) self.mock_bucket.list_blobs.return_value = [output_blob] mock_client.batches.create.return_value = create_mock_batch_job() mock_client.batches.get.return_value = create_mock_batch_job() # Create model with specific project and location test_project = "test-project-123" test_location = "us-central1" model = gemini.GeminiLanguageModel( model_id="gemini-2.5-flash", vertexai=True, project=test_project, location=test_location, batch={ "enabled": True, "threshold": 1, "poll_interval": 0.1, "enable_caching": False, "retention_days": None, }, ) list(model.infer(["test prompt"])) # Verify storage.Client was called with the correct project parameter. 
storage_calls = self.mock_storage_cls.call_args_list project_calls = [ call for call in storage_calls if call.kwargs.get("project") == test_project ] self.assertGreaterEqual( len(project_calls), 1, f"storage.Client should be called with project={test_project}, " f"but was called with: {[call.kwargs for call in storage_calls]}", ) def test_cache_hashing_stability(self): """Test that hash is stable for same inputs.""" cache = gb.GCSBatchCache("b") data1 = {"a": 1, "b": 2} data2 = {"b": 2, "a": 1} self.assertEqual(cache._compute_hash(data1), cache._compute_hash(data2)) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/test_kwargs_passthrough.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for enhanced kwargs pass-through in providers.""" import unittest from unittest import mock import warnings from absl.testing import parameterized from langextract.providers import ollama from langextract.providers import openai class TestOpenAIKwargsPassthrough(unittest.TestCase): """Test OpenAI provider's enhanced kwargs handling.""" @mock.patch('openai.OpenAI') def test_reasoning_effort_alias_normalization(self, mock_openai_class): """Reasoning_effort parameter should be normalized to {reasoning: {effort: ...}}.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-4o-mini', api_key='test-key', reasoning_effort='minimal', ) list(model.infer(['test prompt'])) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs.get('reasoning'), {'effort': 'minimal'}) @mock.patch('openai.OpenAI') def test_reasoning_parameter_normalized(self, mock_openai_class): """Runtime reasoning_effort should normalize even without constructor param.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-5-nano', api_key='test-key', ) list(model.infer(['test prompt'], reasoning_effort='maximal')) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs.get('reasoning'), {'effort': 'maximal'}) @mock.patch('openai.OpenAI') def test_runtime_kwargs_override_stored(self, mock_openai_class): """Runtime parameters should override constructor parameters.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client 
mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-4o-mini', api_key='test-key', temperature=0.7, top_p=0.9, ) list(model.infer(['test prompt'], temperature=0.3, seed=42)) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs.get('temperature'), 0.3) self.assertEqual(call_args.kwargs.get('top_p'), 0.9) self.assertEqual(call_args.kwargs.get('seed'), 42) @mock.patch('openai.OpenAI') def test_falsy_values_preserved(self, mock_openai_class): """Falsy values like 0 should be preserved, not filtered as None.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-4o', api_key='test-key', temperature=0, top_logprobs=0, ) list(model.infer(['test prompt'])) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs.get('temperature'), 0) self.assertEqual(call_args.kwargs.get('top_logprobs'), 0) @mock.patch('openai.OpenAI') def test_both_reasoning_forms_merge(self, mock_openai_class): """Both reasoning and reasoning_effort should merge without clobbering.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-5', api_key='test-key', reasoning={'other_field': 'value'}, reasoning_effort='maximal', ) list(model.infer(['test prompt'])) call_args = mock_client.chat.completions.create.call_args self.assertEqual( call_args.kwargs.get('reasoning'), 
{'other_field': 'value', 'effort': 'maximal'}, ) @mock.patch('openai.OpenAI') def test_custom_response_format(self, mock_openai_class): """Custom response_format should override default JSON format.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-4o', api_key='test-key', format_type=openai.data.FormatType.JSON, ) list( model.infer( ['test prompt'], response_format={'type': 'text', 'schema': 'custom'}, ) ) call_args = mock_client.chat.completions.create.call_args self.assertEqual( call_args.kwargs.get('response_format'), {'type': 'text', 'schema': 'custom'}, ) @mock.patch('openai.OpenAI') def test_direct_reasoning_parameter(self, mock_openai_class): """Direct reasoning parameter should pass through without modification.""" mock_client = mock.Mock() mock_openai_class.return_value = mock_client mock_response = mock.Mock() mock_response.choices = [ mock.Mock(message=mock.Mock(content='{"result": "test"}')) ] mock_client.chat.completions.create.return_value = mock_response model = openai.OpenAILanguageModel( model_id='gpt-5', api_key='test-key', ) list(model.infer(['test prompt'], reasoning={'effort': 'minimal'})) call_args = mock_client.chat.completions.create.call_args self.assertEqual(call_args.kwargs.get('reasoning'), {'effort': 'minimal'}) class TestOllamaAuthSupport(parameterized.TestCase): """Test Ollama provider's authentication support for proxied instances.""" @mock.patch('requests.post') def test_api_key_in_authorization_header(self, mock_post): """API key should be sent in Authorization header with Bearer scheme.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response model = 
ollama.OllamaLanguageModel( model_id='gemma2:2b', model_url='https://proxy.example.com', api_key='sk-test-key-123', ) list(model.infer(['test prompt'])) mock_post.assert_called_once() call_args = mock_post.call_args headers = call_args.kwargs.get('headers', {}) self.assertEqual(headers.get('Authorization'), 'Bearer sk-test-key-123') self.assertEqual(headers.get('Content-Type'), 'application/json') @mock.patch('requests.post') def test_custom_auth_header_name(self, mock_post): """Custom auth header name (e.g. X-API-Key) should be supported.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id='gemma2:2b', model_url='https://api.example.com', api_key='abc123', auth_header='X-API-Key', auth_scheme='', ) list(model.infer(['test prompt'])) headers = mock_post.call_args.kwargs.get('headers', {}) self.assertEqual(headers.get('X-API-Key'), 'abc123') self.assertNotIn('Authorization', headers) @mock.patch('requests.post') def test_pass_through_kwargs(self, mock_post): """Parameters without dedicated handling (e.g. top_k, mirostat) should be forwarded in the options dict.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id='mistral:7b', temperature=0.5, top_k=40, repeat_penalty=1.1, mirostat=2, ) list(model.infer(['test prompt'])) call_args = mock_post.call_args payload = call_args.kwargs['json'] options = payload['options'] self.assertEqual(options.get('temperature'), 0.5) self.assertEqual(options.get('top_k'), 40) self.assertEqual(options.get('repeat_penalty'), 1.1) self.assertEqual(options.get('mirostat'), 2) def test_api_key_redacted_in_repr(self): """API key should be redacted in string representation for security.""" model = ollama.OllamaLanguageModel( model_id='gemma2:2b',
api_key='super-secret-key', ) repr_str = repr(model) self.assertIn('[REDACTED]', repr_str, 'API key should be redacted') self.assertNotIn( 'super-secret-key', repr_str, 'Actual API key should not appear' ) @mock.patch('requests.post') def test_localhost_auth_warning_but_still_works(self, mock_post): """Should warn about localhost auth but still send the auth header.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') model = ollama.OllamaLanguageModel( model_id='gemma2:2b', model_url='http://localhost:11434', api_key='unnecessary-key', ) self.assertTrue( any('localhost' in str(warning.message) for warning in w), 'Expected warning about localhost auth', ) # Verify auth header is still sent despite warning list(model.infer(['test prompt'])) headers = mock_post.call_args.kwargs.get('headers', {}) self.assertEqual(headers.get('Authorization'), 'Bearer unnecessary-key') @mock.patch('requests.post') def test_runtime_kwargs_override(self, mock_post): """Kwargs passed to infer() should override constructor kwargs, including the request timeout.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id='gemma2:2b', temperature=0.7, timeout=60, ) list(model.infer(['test prompt'], temperature=0.3, timeout=120)) call_args = mock_post.call_args payload = call_args.kwargs['json'] options = payload['options'] self.assertEqual(options.get('temperature'), 0.3) self.assertEqual(call_args.kwargs.get('timeout'), 120) @parameterized.named_parameters( ('https_localhost', 'https://localhost:11434', True), ('ipv6_localhost', 'http://[::1]:11434', True), ('ipv4_localhost', 'http://127.0.0.1:8080/', True), ('remote_proxy', 'https://proxy.example.com', False), )
@mock.patch('requests.post') def test_localhost_detection(self, url, should_warn, mock_post): """Should detect localhost in various URL formats (IPv6, https, etc).""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') _ = ollama.OllamaLanguageModel( model_id='gemma2:2b', model_url=url, api_key='test-key', ) if should_warn: self.assertTrue( any('localhost' in str(warning.message) for warning in w), f'Expected warning for {url}', ) else: self.assertFalse( any('localhost' in str(warning.message) for warning in w), f'Unexpected warning for {url}', ) @mock.patch('requests.post') def test_format_none_not_in_payload(self, mock_post): """Format key should be omitted from payload when None (not sent as null).""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': 'plain text'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id='gemma2:2b', ) model.format_type = None _ = model._ollama_query( prompt='test prompt', model='gemma2:2b', structured_output_format=None, ) call_args = mock_post.call_args payload = call_args.kwargs['json'] self.assertNotIn('format', payload, 'format=None should not be in payload') @mock.patch('requests.post') def test_reserved_kwargs_not_in_options(self, mock_post): """Reserved keys (e.g. stop) belong at the payload top level, not in options; other kwargs still go to options.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response model = ollama.OllamaLanguageModel( model_id='gemma2:2b', stop=['END'], temperature=0.5, custom_param='value', ) list(model.infer(['test prompt'])) call_args = mock_post.call_args payload = call_args.kwargs['json'] options = payload['options']
self.assertEqual(payload.get('stop'), ['END']) self.assertNotIn( 'stop', options, 'stop should be at top level, not in options' ) self.assertEqual(options.get('temperature'), 0.5) self.assertEqual(options.get('custom_param'), 'value') @mock.patch('requests.post') def test_api_key_without_localhost_warning(self, mock_post): """Should not warn when using auth with remote/proxied Ollama instances.""" mock_response = mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {'response': '{"test": "value"}'} mock_post.return_value = mock_response with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') model = ollama.OllamaLanguageModel( model_id='gemma2:2b', model_url='https://proxy.example.com', api_key='necessary-key', ) self.assertFalse( any('localhost' in str(warning.message) for warning in w) ) list(model.infer(['test prompt'])) headers = mock_post.call_args.kwargs.get('headers', {}) self.assertEqual(headers.get('Authorization'), 'Bearer necessary-key') if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_live_api.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Live API integration tests that require real API keys. These tests are skipped if API keys are not available in the environment. They should run in CI after all other tests pass.
""" import functools import json import os import re import textwrap import time from typing import Any import unittest from unittest import mock import uuid import dotenv import google.auth import google.auth.exceptions import google.genai.errors import pytest from langextract import data import langextract as lx from langextract.core import tokenizer as tokenizer_lib from langextract.providers import gemini_batch as gb dotenv.load_dotenv(override=True) DEFAULT_GEMINI_MODEL = "gemini-2.5-flash" DEFAULT_OPENAI_MODEL = "gpt-4o" GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get( "LANGEXTRACT_API_KEY" ) OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") VERTEX_PROJECT = os.environ.get("VERTEX_PROJECT") or os.environ.get( "GOOGLE_CLOUD_PROJECT" ) VERTEX_LOCATION = os.environ.get("VERTEX_LOCATION", "us-central1") def has_vertex_ai_credentials(): """Check if Vertex AI credentials are available.""" if not VERTEX_PROJECT: return False try: credentials, _ = google.auth.default() return credentials is not None except (ImportError, google.auth.exceptions.DefaultCredentialsError): return False skip_if_no_gemini = pytest.mark.skipif( not GEMINI_API_KEY, reason=( "Gemini API key not available (set GEMINI_API_KEY or" " LANGEXTRACT_API_KEY)" ), ) skip_if_no_openai = pytest.mark.skipif( not OPENAI_API_KEY, reason="OpenAI API key not available (set OPENAI_API_KEY)", ) skip_if_no_vertex = pytest.mark.skipif( not has_vertex_ai_credentials(), reason=( "Vertex AI credentials not available (set GOOGLE_CLOUD_PROJECT and" " configure gcloud auth)" ), ) live_api = pytest.mark.live_api GEMINI_MODEL_PARAMS = { "temperature": 0.0, "top_p": 0.0, "max_output_tokens": 256, } OPENAI_MODEL_PARAMS = { "temperature": 0.0, } # Extraction Classes _CLASS_MEDICATION = "medication" _CLASS_DOSAGE = "dosage" _CLASS_ROUTE = "route" _CLASS_FREQUENCY = "frequency" _CLASS_DURATION = "duration" _CLASS_CONDITION = "condition" INITIAL_RETRY_DELAY = 1.0 MAX_RETRY_DELAY = 8.0 def 
retry_on_transient_errors(max_retries=3, backoff_factor=2.0): """Decorator to retry tests on transient API errors with exponential backoff. Args: max_retries (int): Maximum number of retry attempts backoff_factor (float): Multiplier for exponential backoff (e.g., 2.0 = 1s, 2s, 4s) """ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): last_exception = None delay = INITIAL_RETRY_DELAY for attempt in range(max_retries + 1): try: return func(*args, **kwargs) except ( lx.exceptions.LangExtractError, google.genai.errors.ClientError, ConnectionError, TimeoutError, OSError, RuntimeError, ) as e: last_exception = e if attempt < max_retries: print( f"\nRetryable error ({type(e).__name__}) on attempt" f" {attempt + 1}/{max_retries + 1}: {e}" ) time.sleep(delay) delay = min(delay * backoff_factor, MAX_RETRY_DELAY) continue raise raise last_exception return wrapper return decorator @pytest.fixture(autouse=True) def add_delay_between_tests(): """Add a small delay between tests to avoid rate limiting.""" yield time.sleep(0.5) def get_basic_medication_examples(): """Get example data for basic medication extraction.""" return [ lx.data.ExampleData( text="Patient was given 250 mg IV Cefazolin TID for one week.", extractions=[ lx.data.Extraction( extraction_class=_CLASS_DOSAGE, extraction_text="250 mg" ), lx.data.Extraction( extraction_class=_CLASS_ROUTE, extraction_text="IV" ), lx.data.Extraction( extraction_class=_CLASS_MEDICATION, extraction_text="Cefazolin", ), lx.data.Extraction( extraction_class=_CLASS_FREQUENCY, extraction_text="TID", # TID = three times a day ), lx.data.Extraction( extraction_class=_CLASS_DURATION, extraction_text="for one week", ), ], ) ] def get_relationship_examples(): """Get example data for medication relationship extraction.""" return [ lx.data.ExampleData( text=( "Patient takes Aspirin 100mg daily for heart health and" " Simvastatin 20mg at bedtime." 
), extractions=[ # First medication group lx.data.Extraction( extraction_class=_CLASS_MEDICATION, extraction_text="Aspirin", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class=_CLASS_DOSAGE, extraction_text="100mg", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class=_CLASS_FREQUENCY, extraction_text="daily", attributes={"medication_group": "Aspirin"}, ), lx.data.Extraction( extraction_class=_CLASS_CONDITION, extraction_text="heart health", attributes={"medication_group": "Aspirin"}, ), # Second medication group lx.data.Extraction( extraction_class=_CLASS_MEDICATION, extraction_text="Simvastatin", attributes={"medication_group": "Simvastatin"}, ), lx.data.Extraction( extraction_class=_CLASS_DOSAGE, extraction_text="20mg", attributes={"medication_group": "Simvastatin"}, ), lx.data.Extraction( extraction_class=_CLASS_FREQUENCY, extraction_text="at bedtime", attributes={"medication_group": "Simvastatin"}, ), ], ) ] def extract_by_class(result, extraction_class): """Helper to extract entities by class. Returns a set of extraction texts for the given class. """ return { e.extraction_text for e in result.extractions if e.extraction_class == extraction_class } def assert_extractions_contain(test_case, result, expected_classes): """Assert that result contains all expected extraction classes. Uses unittest assertions for richer error messages. """ actual_classes = {e.extraction_class for e in result.extractions} missing_classes = expected_classes - actual_classes test_case.assertFalse( missing_classes, f"Missing expected classes: {missing_classes}.
Found extractions:" f" {[f'{e.extraction_class}:{e.extraction_text}' for e in result.extractions]}", ) def assert_valid_char_intervals(test_case, result): """Assert that all extractions have valid char intervals and alignment status. Every extraction must have a char_interval and an alignment_status. When ``result`` is an AnnotatedDocument with non-empty text, each interval must also satisfy 0 <= start_pos <= end_pos <= len(result.text). """ # Loop-invariant: the upper bound only exists for documents carrying text. text_length = ( len(result.text) if isinstance(result, lx.data.AnnotatedDocument) and result.text else None ) for extraction in result.extractions: test_case.assertIsNotNone( extraction.char_interval, f"Missing char_interval for extraction: {extraction.extraction_text}", ) test_case.assertIsNotNone( extraction.alignment_status, "Missing alignment_status for extraction:" f" {extraction.extraction_text}", ) if text_length is None: continue interval = extraction.char_interval test_case.assertGreaterEqual( interval.start_pos, 0, f"Invalid start_pos for extraction: {extraction.extraction_text}", ) test_case.assertLessEqual( interval.end_pos, text_length, f"Invalid end_pos for extraction: {extraction.extraction_text}", ) test_case.assertLessEqual( interval.start_pos, interval.end_pos, f"Inverted char_interval for extraction: {extraction.extraction_text}", ) class TestLiveAPIGemini(unittest.TestCase): """Tests using real Gemini API.""" def _check_cached_result(self, result_json: dict[str, Any]) -> bool: """Check if cached result contains expected medication data. Args: result_json: The raw JSON dict from the cache file. Expected format: {"text": "JSON_STRING_OF_RESULT"} Returns: True if the result contains valid medication extractions, False otherwise.
""" try: text_content = result_json.get("text") if not isinstance(text_content, str): return False inner_json = json.loads(text_content) if not isinstance(inner_json, dict): return False extractions_data = inner_json.get(data.EXTRACTIONS_KEY) if not isinstance(extractions_data, list): return False extractions = [] for item in extractions_data: if isinstance(item, dict): clean_item = {k: v for k, v in item.items() if not k.startswith("_")} extractions.append(data.Extraction(**clean_item)) doc = data.AnnotatedDocument( text=inner_json.get("text"), extractions=extractions ) if not doc.extractions: return False # Check for specific content medication_texts = extract_by_class(doc, _CLASS_MEDICATION) dosage_texts = extract_by_class(doc, _CLASS_DOSAGE) has_lisinopril = any("Lisinopril" in t for t in medication_texts) has_10mg = any("10mg" in t for t in dosage_texts) return has_lisinopril and has_10mg except (json.JSONDecodeError, TypeError, ValueError): return False def _verify_gcs_cache_content(self, bucket_name): """Verify that GCS cache contains expected structured results.""" cache = gb.GCSBatchCache(bucket_name, project=VERTEX_PROJECT) found_content = False # Use iter_items() to check cache content items = list(cache.iter_items()) self.assertTrue(len(items) > 0, "No cache files found in GCS bucket") for _, text in items: try: result_json = json.loads(text) if self._check_cached_result(result_json): found_content = True break except (json.JSONDecodeError, TypeError, ValueError): continue self.assertTrue( found_content, "Could not find expected structured result in GCS cache files", ) @skip_if_no_gemini @live_api @retry_on_transient_errors(max_retries=2) def test_medication_extraction(self): """Test medication extraction with entities in order.""" prompt = textwrap.dedent("""\ Extract medication information including medication name, dosage, route, frequency, and duration in the order they appear in the text.""") examples = get_basic_medication_examples() input_text = 
"Patient took 400 mg PO Ibuprofen q4h for two days." result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, api_key=GEMINI_API_KEY, language_model_params=GEMINI_MODEL_PARAMS, ) assert result is not None self.assertIsInstance(result, lx.data.AnnotatedDocument) assert len(result.extractions) > 0 expected_classes = { _CLASS_DOSAGE, _CLASS_ROUTE, _CLASS_MEDICATION, _CLASS_FREQUENCY, _CLASS_DURATION, } assert_extractions_contain(self, result, expected_classes) assert_valid_char_intervals(self, result) # Using regex for precise matching to avoid false positives medication_texts = extract_by_class(result, _CLASS_MEDICATION) self.assertTrue( any( re.search(r"\bIbuprofen\b", text, re.IGNORECASE) for text in medication_texts ), f"No Ibuprofen found in: {medication_texts}", ) dosage_texts = extract_by_class(result, _CLASS_DOSAGE) self.assertTrue( any( re.search(r"\b400\s*mg\b", text, re.IGNORECASE) for text in dosage_texts ), f"No 400mg dosage found in: {dosage_texts}", ) # Route may be reported as the abbreviation (PO) or spelled out (oral). route_texts = extract_by_class(result, _CLASS_ROUTE) self.assertTrue( any( re.search(r"\b(PO|oral)\b", text, re.IGNORECASE) for text in route_texts ), f"No PO/oral route found in: {route_texts}", ) @skip_if_no_gemini @live_api @retry_on_transient_errors(max_retries=2) def test_multilingual_medication_extraction(self): """Test medication extraction with Japanese text.""" text = ( # "The patient takes 10 mg of medication daily." "患者は毎日10mgの薬を服用します。" ) prompt = "Extract medication information including dosage and frequency."
examples = [ lx.data.ExampleData( text="The patient takes 20mg of aspirin twice daily.", extractions=[ lx.data.Extraction( extraction_class=_CLASS_MEDICATION, extraction_text="aspirin", attributes={ _CLASS_DOSAGE: "20mg", _CLASS_FREQUENCY: "twice daily", }, ), ], ) ] # The Japanese input is tokenized with UnicodeTokenizer instead of the default tokenizer. unicode_tokenizer = tokenizer_lib.UnicodeTokenizer() result = lx.extract( text_or_documents=text, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, api_key=GEMINI_API_KEY, language_model_params=GEMINI_MODEL_PARAMS, tokenizer=unicode_tokenizer, ) assert result is not None self.assertIsInstance(result, lx.data.AnnotatedDocument) assert len(result.extractions) > 0 medication_extractions = [ e for e in result.extractions if e.extraction_class == _CLASS_MEDICATION ] assert ( len(medication_extractions) > 0 ), "No medication entities found in Japanese text" assert_valid_char_intervals(self, result) @skip_if_no_gemini @live_api @retry_on_transient_errors(max_retries=2) def test_explicit_provider_gemini(self): """Test using explicit provider with Gemini.""" config = lx.factory.ModelConfig( model_id=DEFAULT_GEMINI_MODEL, provider="GeminiLanguageModel", provider_kwargs={ "api_key": GEMINI_API_KEY, "temperature": 0.0, }, ) model = lx.factory.create_model(config) self.assertEqual(model.__class__.__name__, "GeminiLanguageModel") self.assertEqual(model.model_id, DEFAULT_GEMINI_MODEL) # The short provider alias "gemini" should resolve to the same class. config2 = lx.factory.ModelConfig( model_id=DEFAULT_GEMINI_MODEL, provider="gemini", provider_kwargs={ "api_key": GEMINI_API_KEY, }, ) model2 = lx.factory.create_model(config2) self.assertEqual(model2.__class__.__name__, "GeminiLanguageModel") @skip_if_no_gemini @live_api @retry_on_transient_errors(max_retries=2) def test_medication_relationship_extraction(self): """Test relationship extraction for medications with Gemini.""" input_text = """ The patient was prescribed Lisinopril and Metformin last month.
He takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes. """ prompt = textwrap.dedent(""" Extract medications with their details, using attributes to group related information: 1. Extract entities in the order they appear in the text 2. Each entity must have a 'medication_group' attribute linking it to its medication 3. All details about a medication should share the same medication_group value """) examples = get_relationship_examples() result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, api_key=GEMINI_API_KEY, language_model_params=GEMINI_MODEL_PARAMS, ) assert result is not None assert len(result.extractions) > 0 assert_valid_char_intervals(self, result) # Group extractions by their "medication_group" attribute value. medication_groups = {} for extraction in result.extractions: assert ( extraction.attributes is not None ), f"Missing attributes for {extraction.extraction_text}" assert ( "medication_group" in extraction.attributes ), f"Missing medication_group for {extraction.extraction_text}" group_name = extraction.attributes["medication_group"] medication_groups.setdefault(group_name, []).append(extraction) assert ( len(medication_groups) >= 2 ), f"Expected at least 2 medications, found {len(medication_groups)}" # Allow flexible matching for dosage field (could be "dosage" or "dose") for med_name, extractions in medication_groups.items(): extraction_classes = {e.extraction_class for e in extractions} # At minimum, each group should have the medication itself assert ( _CLASS_MEDICATION in extraction_classes ), f"{med_name} group missing medication entity" # Dosage is expected but might be formatted differently assert any( c in extraction_classes for c in [_CLASS_DOSAGE, "dose"] ), f"{med_name} group missing dosage" @skip_if_no_vertex @live_api @pytest.mark.vertex_ai @mock.patch.object(gb, "infer_batch", wraps=gb.infer_batch, autospec=True) def
test_batch_extraction_vertex_gcs(self, mock_infer_batch): """Test extraction using Vertex AI Batch API with GCS. This test runs a real Vertex AI Batch job and will take time to complete. It is skipped unless a Vertex project (VERTEX_PROJECT or GOOGLE_CLOUD_PROJECT) is set and google.auth.default() finds credentials. We wrap `infer_batch` to verify that: - Batch API is actually called (not falling back to real-time API) - Schema dict is passed (non-None) to the batch function """ prompt = textwrap.dedent("""\ Extract medication information including medication name, dosage, route, frequency, and duration in the order they appear in the text.""") examples = get_basic_medication_examples() documents = [ lx.data.Document( document_id="vx_doc1", text="Patient took 400 mg PO Ibuprofen q4h for two days.", ), lx.data.Document( document_id="vx_doc2", text="Patient was given 250 mg IV Cefazolin TID for one week.", ), lx.data.Document( document_id="vx_doc3", text="Administered 2 mg IV Morphine once for acute pain.", ), lx.data.Document( document_id="vx_doc4", text="Prescribed 500 mg PO Amoxicillin BID for infection.", ), lx.data.Document( document_id="vx_doc5", text="Given 10 mg IM Haloperidol PRN for agitation.", ), ] expected_meds = [ "Ibuprofen", "Cefazolin", "Morphine", "Amoxicillin", "Haloperidol", ] language_model_params = dict(GEMINI_MODEL_PARAMS) language_model_params["vertexai"] = True language_model_params["project"] = VERTEX_PROJECT language_model_params["location"] = VERTEX_LOCATION language_model_params["batch"] = { "enabled": True, "threshold": 2, "poll_interval": 1, # Fast polling for test "timeout": 900, # 15 minutes for actual batch job completion } batch_result = lx.extract( text_or_documents=documents, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, language_model_params=language_model_params, ) mock_infer_batch.assert_called_once() call_args = mock_infer_batch.call_args schema_dict_arg = call_args.kwargs.get("schema_dict") self.assertIsNotNone( schema_dict_arg, "schema_dict should be passed to batch API (not
None)", ) self.assertIsInstance(batch_result, list) self.assertEqual( len(batch_result), len(documents), f"Expected {len(documents)} results from Vertex batch API", ) # Results are matched to expected_meds by position; this assumes the output order follows the input document order. for i, (res, med_name) in enumerate(zip(batch_result, expected_meds)): self.assertIsInstance( res, lx.data.AnnotatedDocument, f"Result {i} should be an AnnotatedDocument, got {type(res)}", ) self.assertTrue( res.extractions, f"No extractions for document {i}", ) for extraction in res.extractions: self.assertIsInstance( extraction, lx.data.Extraction, "Extraction item should be Extraction object, got" f" {type(extraction)}", ) meds = extract_by_class(res, _CLASS_MEDICATION) self.assertTrue( any( re.search(rf"\b{re.escape(med_name)}\b", m, re.IGNORECASE) for m in meds ), f"Expected medication '{med_name}' not found in results: {meds}", ) dosages = extract_by_class(res, _CLASS_DOSAGE) self.assertTrue( dosages, f"No dosage extracted for medication '{med_name}'", ) assert_valid_char_intervals(self, res) @skip_if_no_vertex @live_api @pytest.mark.vertex_ai def test_batch_caching_live(self): """Test batch caching with real Vertex AI Batch API. Verifies that: 1. First run populates GCS cache 2. Second run uses cache (returns same results faster) """ prompt = "Extract the medication: Patient takes 10mg Lisinopril." examples = get_basic_medication_examples() # Use unique IDs to ensure cache isolation between test runs.
run_id = uuid.uuid4().hex[:8] documents = [ lx.data.Document( document_id=f"doc_{i}_{run_id}", text=f"Patient takes 10mg Lisinopril {i} {run_id}.", ) for i in range(2) ] language_model_params = dict(GEMINI_MODEL_PARAMS) language_model_params["vertexai"] = True language_model_params["project"] = VERTEX_PROJECT language_model_params["location"] = VERTEX_LOCATION language_model_params["batch"] = { "enabled": True, "threshold": 2, "poll_interval": 1, "timeout": 900, "enable_caching": True, } print("\nStarting first batch run (API)...") start_time = time.time() results1 = list( lx.extract( text_or_documents=documents, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, language_model_params=language_model_params, ) ) duration1 = time.time() - start_time print(f"First run took {duration1:.2f}s") print("Starting second batch run (Cache)...") start_time = time.time() results2 = list( lx.extract( text_or_documents=documents, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, language_model_params=language_model_params, ) ) duration2 = time.time() - start_time print(f"Second run took {duration2:.2f}s") self.assertEqual(len(results1), len(results2)) for r1, r2 in zip(results1, results2): self.assertEqual(r1.text, r2.text) self.assertEqual(len(r1.extractions), len(r2.extractions)) self.assertLess(duration2, 10.0, "Second run took too long for cache hit") self.assertLess(duration2, 10.0, "Second run took too long for cache hit") print("\nVerifying GCS cache content...") bucket_name = gb._get_bucket_name(VERTEX_PROJECT, VERTEX_LOCATION) print(f"Checking bucket: {bucket_name}") self._verify_gcs_cache_content(bucket_name) class TestCrossChunkContext(unittest.TestCase): """Tests for cross-chunk context feature with real API.""" @skip_if_no_gemini @live_api @retry_on_transient_errors(max_retries=3) def test_context_window_extracts_from_both_chunks(self): """Verify context_window_chars enables extraction across chunk boundaries.""" 
input_text = ( "Dr. Sarah Chen is the lead researcher at the institute. " "She published groundbreaking work on neural networks last year." ) prompt = textwrap.dedent( """\ Extract all person names, roles, and achievements mentioned in the text. Include both explicit names and information associated with pronouns.""" ) examples = [ lx.data.ExampleData( text=( "Professor James Miller leads the physics department. " "He won the Nobel Prize in 2020." ), extractions=[ lx.data.Extraction( extraction_class="person", extraction_text="Professor James Miller", attributes={"role": "leads the physics department"}, ), lx.data.Extraction( extraction_class="achievement", extraction_text="won the Nobel Prize in 2020", ), ], ) ] # max_char_buffer=60 forces the ~120-char input into multiple chunks; context_window_chars=50 presumably exposes preceding text to later chunks (e.g. whom "She" refers to). result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=DEFAULT_GEMINI_MODEL, api_key=GEMINI_API_KEY, language_model_params=GEMINI_MODEL_PARAMS, max_char_buffer=60, context_window_chars=50, ) self.assertIsNotNone(result) self.assertGreater(len(result.extractions), 0) all_extraction_text = " ".join( str(e.extraction_text) + " " + str(e.attributes) for e in result.extractions ).lower() has_chunk1_content = any( term in all_extraction_text for term in ("sarah", "chen", "researcher", "lead") ) has_chunk2_content = any( term in all_extraction_text for term in ("published", "groundbreaking", "neural", "networks") ) self.assertTrue( has_chunk1_content, f"Expected chunk 1 content (Sarah Chen). Got: {result.extractions}", ) self.assertTrue( has_chunk2_content, f"Expected chunk 2 content (publication).
Got: {result.extractions}", ) class TestLiveAPIOpenAI(unittest.TestCase): """Tests using real OpenAI API.""" @skip_if_no_openai @live_api @retry_on_transient_errors(max_retries=2) def test_medication_extraction(self): """Test medication extraction with OpenAI models.""" prompt = textwrap.dedent("""\ Extract medication information including medication name, dosage, route, frequency, and duration in the order they appear in the text.""") examples = get_basic_medication_examples() input_text = "Patient took 400 mg PO Ibuprofen q4h for two days." result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=DEFAULT_OPENAI_MODEL, api_key=OPENAI_API_KEY, use_schema_constraints=False, language_model_params=OPENAI_MODEL_PARAMS, ) assert result is not None self.assertIsInstance(result, lx.data.AnnotatedDocument) assert len(result.extractions) > 0 expected_classes = { _CLASS_DOSAGE, _CLASS_ROUTE, _CLASS_MEDICATION, _CLASS_FREQUENCY, _CLASS_DURATION, } assert_extractions_contain(self, result, expected_classes) assert_valid_char_intervals(self, result) # Using regex for precise matching to avoid false positives medication_texts = extract_by_class(result, _CLASS_MEDICATION) self.assertTrue( any( re.search(r"\bIbuprofen\b", text, re.IGNORECASE) for text in medication_texts ), f"No Ibuprofen found in: {medication_texts}", ) dosage_texts = extract_by_class(result, _CLASS_DOSAGE) self.assertTrue( any( re.search(r"\b400\s*mg\b", text, re.IGNORECASE) for text in dosage_texts ), f"No 400mg dosage found in: {dosage_texts}", ) # Route may be reported as the abbreviation (PO) or spelled out (oral). route_texts = extract_by_class(result, _CLASS_ROUTE) self.assertTrue( any( re.search(r"\b(PO|oral)\b", text, re.IGNORECASE) for text in route_texts ), f"No PO/oral route found in: {route_texts}", ) @skip_if_no_openai @live_api @retry_on_transient_errors(max_retries=2) def test_explicit_provider_selection(self): """Test using explicit provider parameter for disambiguation.""" # Test with explicit model_id and provider
config = lx.factory.ModelConfig( model_id=DEFAULT_OPENAI_MODEL, provider="OpenAILanguageModel", # Explicit provider selection provider_kwargs={ "api_key": OPENAI_API_KEY, "temperature": 0.0, }, ) model = lx.factory.create_model(config) self.assertIsInstance(model, lx.providers.openai.OpenAILanguageModel) self.assertEqual(model.model_id, DEFAULT_OPENAI_MODEL) # Also test using provider without model_id (uses default) config_default = lx.factory.ModelConfig( provider="OpenAILanguageModel", provider_kwargs={ "api_key": OPENAI_API_KEY, }, ) model_default = lx.factory.create_model(config_default) self.assertEqual(model_default.__class__.__name__, "OpenAILanguageModel") # Should use the default model_id from the provider; this test expects gpt-4o-mini, so update it if the provider default changes. self.assertEqual(model_default.model_id, "gpt-4o-mini") @skip_if_no_openai @live_api @retry_on_transient_errors(max_retries=2) def test_medication_relationship_extraction(self): """Test relationship extraction for medications with OpenAI.""" input_text = """ The patient was prescribed Lisinopril and Metformin last month. He takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes. """ prompt = textwrap.dedent(""" Extract medications with their details, using attributes to group related information: 1. Extract entities in the order they appear in the text 2. Each entity must have a 'medication_group' attribute linking it to its medication 3.
All details about a medication should share the same medication_group value """) examples = get_relationship_examples() result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=DEFAULT_OPENAI_MODEL, api_key=OPENAI_API_KEY, use_schema_constraints=False, language_model_params=OPENAI_MODEL_PARAMS, ) assert result is not None assert len(result.extractions) > 0 assert_valid_char_intervals(self, result) # Group extractions by their "medication_group" attribute value. medication_groups = {} for extraction in result.extractions: assert ( extraction.attributes is not None ), f"Missing attributes for {extraction.extraction_text}" assert ( "medication_group" in extraction.attributes ), f"Missing medication_group for {extraction.extraction_text}" group_name = extraction.attributes["medication_group"] medication_groups.setdefault(group_name, []).append(extraction) assert ( len(medication_groups) >= 2 ), f"Expected at least 2 medications, found {len(medication_groups)}" # Allow flexible matching for dosage field (could be "dosage" or "dose") for med_name, extractions in medication_groups.items(): extraction_classes = {e.extraction_class for e in extractions} # At minimum, each group should have the medication itself assert ( _CLASS_MEDICATION in extraction_classes ), f"{med_name} group missing medication entity" # Dosage is expected but might be formatted differently assert any( c in extraction_classes for c in [_CLASS_DOSAGE, "dose"] ), f"{med_name} group missing dosage" ================================================ FILE: tests/test_ollama_integration.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Integration tests for Ollama functionality.""" import socket import pytest import requests import langextract as lx def _ollama_available(): """Check if Ollama is running on localhost:11434.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: result = sock.connect_ex(("localhost", 11434)) return result == 0 @pytest.mark.skipif(not _ollama_available(), reason="Ollama not running") def test_ollama_extraction(): input_text = "Isaac Asimov was a prolific science fiction writer." prompt = "Extract the author's full name and their primary literary genre." examples = [ lx.data.ExampleData( text=( "J.R.R. Tolkien was an English writer, best known for" " high-fantasy." ), extractions=[ lx.data.Extraction( extraction_class="author_details", extraction_text="J.R.R. Tolkien was an English writer...", attributes={ "name": "J.R.R. Tolkien", "genre": "high-fantasy", }, ) ], ) ] model_id = "gemma2:2b" result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=model_id, model_url="http://localhost:11434", temperature=0.3, fence_output=False, use_schema_constraints=False, ) assert len(result.extractions) > 0 extraction = result.extractions[0] assert extraction.extraction_class == "author_details" if extraction.attributes: assert "asimov" in extraction.attributes.get("name", "").lower() @pytest.mark.skipif(not _ollama_available(), reason="Ollama not running") def test_ollama_extraction_with_fence_fallback(): input_text = "Marie Curie was a physicist who won two Nobel prizes." 
prompt = "Extract information about people and their achievements." examples = [ lx.data.ExampleData( text="Albert Einstein developed the theory of relativity.", extractions=[ lx.data.Extraction( extraction_class="person", extraction_text="Albert Einstein", attributes={"achievement": "theory of relativity"}, ) ], ) ] model_id = "gemma2:2b" result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id=model_id, model_url="http://localhost:11434", temperature=0.3, fence_output=True, # Testing that fallback works use_schema_constraints=False, ) assert len(result.extractions) > 0 extraction = result.extractions[0] assert extraction.extraction_class == "person" assert ( "marie" in extraction.extraction_text.lower() or "curie" in extraction.extraction_text.lower() ) def _model_available(model_name): """Check if a specific model is available in Ollama.""" if not _ollama_available(): return False try: response = requests.get("http://localhost:11434/api/tags", timeout=5) models = [m["name"] for m in response.json().get("models", [])] return any(model_name in m for m in models) except (requests.RequestException, KeyError, TypeError): return False @pytest.mark.skipif( not _model_available("deepseek-r1"), reason="DeepSeek-R1 not available in Ollama", ) def test_deepseek_r1_extraction(): """Test extraction with DeepSeek-R1 reasoning model. DeepSeek-R1 outputs tags before JSON when not using format:json. This test verifies the model works correctly with langextract. """ input_text = "John Smith is a software engineer at Google." prompt = "Extract people and their roles." 
examples = [ lx.data.ExampleData( text="Alice works as a designer at Apple.", extractions=[ lx.data.Extraction( extraction_class="person", extraction_text="Alice", attributes={"role": "designer", "company": "Apple"}, ) ], ) ] result = lx.extract( text_or_documents=input_text, prompt_description=prompt, examples=examples, model_id="deepseek-r1:1.5b", model_url="http://localhost:11434", temperature=0.3, ) assert len(result.extractions) > 0 extraction = result.extractions[0] assert extraction.extraction_class == "person" assert "john" in extraction.extraction_text.lower() ================================================ FILE: tests/tokenizer_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import textwrap from absl.testing import absltest from absl.testing import parameterized from langextract.core import tokenizer class TokenizerTest(parameterized.TestCase): # pylint: disable=too-many-public-methods def assertTokenListEqual(self, actual_tokens, expected_tokens, msg=None): self.assertLen(actual_tokens, len(expected_tokens), msg=msg) for i, (expected, actual) in enumerate(zip(expected_tokens, actual_tokens)): expected = tokenizer.Token( index=expected.index, token_type=expected.token_type, first_token_after_newline=expected.first_token_after_newline, ) actual = tokenizer.Token( index=actual.index, token_type=actual.token_type, first_token_after_newline=actual.first_token_after_newline, ) self.assertDataclassEqual( expected, actual, msg=f"Token mismatch at index {i}", ) @parameterized.named_parameters( dict( testcase_name="basic_text", input_text="Hello, world!", expected_tokens=[ tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=1, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=3, token_type=tokenizer.TokenType.PUNCTUATION ), ], ), dict( testcase_name="multiple_spaces_and_numbers", input_text="Age: 25\nWeight=70kg.", expected_tokens=[ tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=1, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token(index=2, token_type=tokenizer.TokenType.NUMBER), tokenizer.Token( index=3, token_type=tokenizer.TokenType.WORD, first_token_after_newline=True, ), tokenizer.Token( index=4, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token(index=5, token_type=tokenizer.TokenType.NUMBER), tokenizer.Token(index=6, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=7, token_type=tokenizer.TokenType.PUNCTUATION ), ], ), dict( testcase_name="multi_line_input", input_text="Line1\nLine2\nLine3", expected_tokens=[ tokenizer.Token(index=0, 
token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=1, token_type=tokenizer.TokenType.NUMBER), tokenizer.Token( index=2, token_type=tokenizer.TokenType.WORD, first_token_after_newline=True, ), tokenizer.Token(index=3, token_type=tokenizer.TokenType.NUMBER), tokenizer.Token( index=4, token_type=tokenizer.TokenType.WORD, first_token_after_newline=True, ), tokenizer.Token(index=5, token_type=tokenizer.TokenType.NUMBER), ], ), dict( testcase_name="only_symbols", input_text="!!!@# $$$%", expected_tokens=[ tokenizer.Token( index=0, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token( index=1, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token( index=2, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token( index=3, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token( index=4, token_type=tokenizer.TokenType.PUNCTUATION ), ], ), dict( testcase_name="empty_string", input_text="", expected_tokens=[], ), dict( testcase_name="non_ascii_text", input_text="café", expected_tokens=[ tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD), ], ), dict( testcase_name="mixed_punctuation", input_text="?!", expected_tokens=[ tokenizer.Token( index=0, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token( index=1, token_type=tokenizer.TokenType.PUNCTUATION ), ], ), ) def test_tokenize_various_inputs(self, input_text, expected_tokens): tokenized = tokenizer.tokenize(input_text) self.assertTokenListEqual( tokenized.tokens, expected_tokens, msg=f"Tokens mismatch for input: {input_text!r}", ) def test_first_token_after_newline_flag(self): input_text = "Line1\nLine2\nLine3" tokenized = tokenizer.tokenize(input_text) expected_tokens = [ tokenizer.Token( index=0, token_type=tokenizer.TokenType.WORD, ), tokenizer.Token( index=1, token_type=tokenizer.TokenType.NUMBER, ), tokenizer.Token( index=2, token_type=tokenizer.TokenType.WORD, first_token_after_newline=True, ), tokenizer.Token( index=3, token_type=tokenizer.TokenType.NUMBER, ), 
tokenizer.Token( index=4, token_type=tokenizer.TokenType.WORD, first_token_after_newline=True, ), tokenizer.Token( index=5, token_type=tokenizer.TokenType.NUMBER, ), ] self.assertTokenListEqual( tokenized.tokens, expected_tokens, msg="Newline flags mismatch", ) def test_performance_optimization_no_crash(self): """Verify that tokenization handles empty strings and newlines without error.""" tok = tokenizer.RegexTokenizer() text = "" tokenized = tok.tokenize(text) self.assertEmpty(tokenized.tokens) text = "\n" tokenized = tok.tokenize(text) self.assertEmpty(tokenized.tokens) text = "A\nB" tokenized = tok.tokenize(text) self.assertLen(tokenized.tokens, 2) self.assertTrue(tokenized.tokens[1].first_token_after_newline) def test_underscore_handling(self): """Verify that underscores are preserved as punctuation/symbols.""" # RegexTokenizer should now capture underscores explicitly. tok = tokenizer.RegexTokenizer() text = "user_id" tokenized = tok.tokenize(text) # Expecting: "user", "_", "id" self.assertLen(tokenized.tokens, 3) self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD) self.assertEqual( tokenized.tokens[1].token_type, tokenizer.TokenType.PUNCTUATION ) self.assertEqual(tokenized.tokens[2].token_type, tokenizer.TokenType.WORD) class UnicodeTokenizerTest(parameterized.TestCase): # pylint: disable=too-many-public-methods def assertTokenListEqual(self, actual_tokens, expected_tokens, msg=None): self.assertLen(actual_tokens, len(expected_tokens), msg=msg) for i, (expected, actual) in enumerate(zip(expected_tokens, actual_tokens)): expected_tok = tokenizer.Token( index=expected.index, token_type=expected.token_type, first_token_after_newline=expected.first_token_after_newline, ) actual_tok = tokenizer.Token( index=actual.index, token_type=actual.token_type, first_token_after_newline=actual.first_token_after_newline, ) self.assertDataclassEqual( expected_tok, actual_tok, msg=f"Token mismatch at index {i}", ) @parameterized.named_parameters( dict( 
testcase_name="japanese_text", input_text="こんにちは、世界!", expected_tokens=[ tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=1, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=3, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=4, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=5, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token(index=6, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=7, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=8, token_type=tokenizer.TokenType.PUNCTUATION ), ], ), dict( testcase_name="english_text", input_text="Hello, world!", expected_tokens=[ tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=1, token_type=tokenizer.TokenType.PUNCTUATION ), tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD), tokenizer.Token( index=3, token_type=tokenizer.TokenType.PUNCTUATION ), ], ), dict( testcase_name="mixed_text", input_text="Hello 世界 123", expected_tokens=[ tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=1, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD), tokenizer.Token(index=3, token_type=tokenizer.TokenType.NUMBER), ], ), ) def test_tokenize_various_inputs(self, input_text, expected_tokens): tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(input_text) self.assertTokenListEqual( tokenized.tokens, expected_tokens, msg=f"Tokens mismatch for input: {input_text!r}", ) @parameterized.named_parameters( dict( testcase_name="mixed_digit_han_same_type_grouping", input_text="10毫克", # "10 milligrams" expected_tokens=[ ("10", tokenizer.TokenType.NUMBER), ("毫", tokenizer.TokenType.WORD), ("克", tokenizer.TokenType.WORD), ], expected_first_after_newline=[False, False, False], ), dict( testcase_name="underscore_word_separator", input_text="hello_world", 
expected_tokens=[ ("hello", tokenizer.TokenType.WORD), ("_", tokenizer.TokenType.PUNCTUATION), ("world", tokenizer.TokenType.WORD), ], expected_first_after_newline=[False, False, False], ), dict( testcase_name="leading_trailing_underscores", input_text="_test_case_", expected_tokens=[ ("_", tokenizer.TokenType.PUNCTUATION), ("test", tokenizer.TokenType.WORD), ("_", tokenizer.TokenType.PUNCTUATION), ("case", tokenizer.TokenType.WORD), ("_", tokenizer.TokenType.PUNCTUATION), ], expected_first_after_newline=[False, False, False, False, False], ), ) def test_special_unicode_and_punctuation_handling( self, input_text, expected_tokens, expected_first_after_newline ): """Test special Unicode sequences, punctuation grouping, and script handling edge cases.""" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(input_text) self.assertLen( tokenized.tokens, len(expected_tokens), f"Expected {len(expected_tokens)} tokens for edge case test, but got" f" {len(tokenized.tokens)}", ) for i, ( token, (expected_text, expected_type), expected_newline, ) in enumerate( zip(tokenized.tokens, expected_tokens, expected_first_after_newline) ): actual_text = input_text[ token.char_interval.start_pos : token.char_interval.end_pos ] self.assertEqual( actual_text, expected_text, msg=f"Token {i} text mismatch.", ) self.assertEqual( token.token_type, expected_type, msg=f"Token {i} type mismatch.", ) self.assertEqual( token.first_token_after_newline, expected_newline, msg=f"Token {i} newline flag mismatch.", ) def test_first_token_after_newline_parity(self): """Test that UnicodeTokenizer matches RegexTokenizer for newline detection.""" input_text = "a\n b" regex_tok = tokenizer.RegexTokenizer() regex_tokens = regex_tok.tokenize(input_text).tokens self.assertTrue(regex_tokens[1].first_token_after_newline) unicode_tok = tokenizer.UnicodeTokenizer() unicode_tokens = unicode_tok.tokenize(input_text).tokens self.assertTrue( unicode_tokens[1].first_token_after_newline, "UnicodeTokenizer failed 
to detect newline in gap 'a\\n b'", ) def test_expanded_cjk_detection(self): """Test detection of CJK characters in extended ranges.""" input_text = "\u4e00\u3400\U00020000" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(input_text) self.assertLen(tokenized.tokens, 3) for token in tokenized.tokens: self.assertEqual(token.token_type, tokenizer.TokenType.WORD) def test_mixed_script_and_emoji(self): """Test mixed script and emoji handling.""" input_text = "Hello👋🏼世界123" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(input_text) expected_tokens = [ ("Hello", tokenizer.TokenType.WORD), ( "👋🏼", tokenizer.TokenType.PUNCTUATION, ), ("世", tokenizer.TokenType.WORD), ("界", tokenizer.TokenType.WORD), ("123", tokenizer.TokenType.NUMBER), ] self.assertLen(tokenized.tokens, len(expected_tokens)) for i, (expected_text, expected_type) in enumerate(expected_tokens): token = tokenized.tokens[i] actual_text = tokenized.text[ token.char_interval.start_pos : token.char_interval.end_pos ] self.assertEqual(actual_text, expected_text) self.assertEqual(token.token_type, expected_type) def test_script_boundary_grouping(self): """Test that we do NOT group characters from different scripts.""" tok = tokenizer.UnicodeTokenizer() text = "HelloПривет" tokenized = tok.tokenize(text) self.assertLen(tokenized.tokens, 2, "Should be split into 2 tokens") self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD) self.assertEqual(tokenized.tokens[1].token_type, tokenizer.TokenType.WORD) t1_text = text[ tokenized.tokens[0] .char_interval.start_pos : tokenized.tokens[0] .char_interval.end_pos ] t2_text = text[ tokenized.tokens[1] .char_interval.start_pos : tokenized.tokens[1] .char_interval.end_pos ] self.assertEqual(t1_text, "Hello") self.assertEqual(t2_text, "Привет") def test_non_spaced_scripts_no_grouping(self): """Test that non-spaced scripts (Thai, Lao, etc.) 
are NOT grouped into a single word.""" tok = tokenizer.UnicodeTokenizer() text = "สวัสดี" tokenized = tok.tokenize(text) self.assertGreater( len(tokenized.tokens), 1, "Should not be grouped into a single token" ) self.assertLen(tokenized.tokens, 4) def test_cjk_detection_regex(self): """Test that CJK characters are detected and not grouped.""" tok = tokenizer.UnicodeTokenizer() text = "你好" tokenized = tok.tokenize(text) self.assertLen(tokenized.tokens, 2) self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD) self.assertEqual(tokenized.tokens[1].token_type, tokenizer.TokenType.WORD) def test_newline_simplification(self): """Test that newline handling works correctly with the simplified logic.""" tok = tokenizer.UnicodeTokenizer() text = "LineA\nLineB" tokenized = tok.tokenize(text) self.assertLen(tokenized.tokens, 2) self.assertEqual(tokenized.tokens[0].first_token_after_newline, False) self.assertTrue(tokenized.tokens[1].first_token_after_newline) def test_newline_simplification_start(self): """Test newline at start of text.""" tok = tokenizer.UnicodeTokenizer() text = "\nLineA" tokenized = tok.tokenize(text) self.assertLen(tokenized.tokens, 1) self.assertTrue(tokenized.tokens[0].first_token_after_newline) def test_mixed_line_endings(self): """Test mixed line endings (\\r\\n).""" # \\r\\n should be treated as a single newline for the purpose of the flag, # or at least trigger it. text = "LineOne\r\nLineTwo" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(text) self.assertLen(tokenized.tokens, 2) self.assertTrue(tokenized.tokens[1].first_token_after_newline) def test_mixed_uncommon_scripts_no_grouping(self): """Test that adjacent unknown scripts are NOT merged.""" tok = tokenizer.UnicodeTokenizer() # Armenian "Բարև" + Georgian "გამარჯობა". # Both are "unknown" to _COMMON_SCRIPTS, so should not be grouped together. text = "Բարևგამარჯობა" tokenized = tok.tokenize(text) # Unknown scripts are fragmented into characters for safety. 
self.assertLen( tokenized.tokens, 13, "Should be fragmented into characters for safety (13 tokens)", ) self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD) self.assertEqual(tokenized.tokens[1].token_type, tokenizer.TokenType.WORD) def test_unknown_script_merging_edge_case(self): # Verify that adjacent IDENTICAL unknown scripts are fragmented for safety. # Armenian "Բարև" + Armenian "Բարև". tok = tokenizer.UnicodeTokenizer() text = "ԲարևԲարև" tokenized = tok.tokenize(text) # Should be fragmented into 8 characters self.assertLen(tokenized.tokens, 8) self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD) def test_find_sentence_range_empty_input(self): # Ensure robustness against empty input, which previously caused a crash. interval = tokenizer.find_sentence_range("", [], 0) self.assertEqual(interval, tokenizer.TokenInterval(0, 0)) def test_normalization_indices_match_input(self): """Test that token indices match the ORIGINAL input, not normalized text.""" # "e" + combining acute accent (2 chars) -> NFC "é" (1 char) nfd_text = "e\u0301" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(nfd_text) # We want indices to match input, so CharInterval should be [0, 2). self.assertEqual(tokenized.text, nfd_text) self.assertLen(tokenized.tokens, 1) self.assertEqual(tokenized.tokens[0].char_interval.start_pos, 0) self.assertEqual(tokenized.tokens[0].char_interval.end_pos, 2) def test_acronym_inconsistency(self): """Test that RegexTokenizer does NOT produce ACRONYM tokens (standardization).""" tok = tokenizer.RegexTokenizer() text = "A/B" tokenized = tok.tokenize(text) # Ensure parity with UnicodeTokenizer by splitting acronyms into constituent parts. 
self.assertLen(tokenized.tokens, 3) self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD) self.assertEqual( tokenized.tokens[1].token_type, tokenizer.TokenType.PUNCTUATION ) self.assertEqual(tokenized.tokens[2].token_type, tokenizer.TokenType.WORD) def test_consecutive_punctuation_grouping(self): """Test that consecutive punctuation is grouped into a single token.""" input_text = "Hello!! World..." expected_tokens = ["Hello", "!!", "World", "..."] tokens = tokenizer.UnicodeTokenizer().tokenize(input_text).tokens self.assertEqual( [ input_text[t.char_interval.start_pos : t.char_interval.end_pos] for t in tokens ], expected_tokens, ) def test_punctuation_merging_identical_only(self): """Test that only identical punctuation is merged.""" input_text = "Hello!! World..." expected_tokens = ["Hello", "!!", "World", "..."] tokens = tokenizer.UnicodeTokenizer().tokenize(input_text).tokens self.assertEqual( [ input_text[t.char_interval.start_pos : t.char_interval.end_pos] for t in tokens ], expected_tokens, ) input_text_mixed = 'End."' expected_tokens_mixed = ["End", ".", '"'] tokens_mixed = ( tokenizer.UnicodeTokenizer().tokenize(input_text_mixed).tokens ) self.assertEqual( [ input_text_mixed[ t.char_interval.start_pos : t.char_interval.end_pos ] for t in tokens_mixed ], expected_tokens_mixed, ) def test_distinct_unknown_scripts_do_not_merge(self): """Verify that distinct unknown scripts (e.g. 
Bengali vs Devanagari) are not merged.""" # Bengali "অ" (U+0985) and Devanagari "अ" (U+0905) text = "অअ" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(text) # Should be 2 tokens because scripts are different self.assertLen(tokenized.tokens, 2) self.assertEqual(tokenized.tokens[0].char_interval.start_pos, 0) self.assertEqual(tokenized.tokens[0].char_interval.end_pos, 1) self.assertEqual(tokenized.tokens[1].char_interval.start_pos, 1) self.assertEqual(tokenized.tokens[1].char_interval.end_pos, 2) def test_identical_unknown_scripts_merge(self): """Verify that identical unknown scripts merge into a single token.""" # Bengali "অ" (U+0985) and Bengali "আ" (U+0986) text = "অআ" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(text) # Identical unknown scripts are not merged to avoid expensive lookups. self.assertLen(tokenized.tokens, 2) self.assertEqual(tokenized.tokens[0].char_interval.start_pos, 0) self.assertEqual(tokenized.tokens[0].char_interval.end_pos, 1) self.assertEqual(tokenized.tokens[1].char_interval.start_pos, 1) self.assertEqual(tokenized.tokens[1].char_interval.end_pos, 2) class ExceptionTest(absltest.TestCase): """Test custom exception types and error conditions.""" def test_invalid_token_interval_errors(self): """Test that InvalidTokenIntervalError is raised for invalid intervals.""" text = "Hello, world!" 
tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(text) with self.assertRaisesRegex( tokenizer.InvalidTokenIntervalError, "Invalid token interval.*start_index=-1", ): tokenizer.tokens_text( tokenized, tokenizer.TokenInterval(start_index=-1, end_index=1) ) with self.assertRaisesRegex( tokenizer.InvalidTokenIntervalError, "Invalid token interval.*end_index=999", ): tokenizer.tokens_text( tokenized, tokenizer.TokenInterval(start_index=0, end_index=999) ) with self.assertRaisesRegex( tokenizer.InvalidTokenIntervalError, "Invalid token interval.*start_index=2.*end_index=1", ): tokenizer.tokens_text( tokenized, tokenizer.TokenInterval(start_index=2, end_index=1) ) def test_sentence_range_errors(self): """Test that SentenceRangeError is raised for invalid start positions.""" text = "Hello world." tok = tokenizer.UnicodeTokenizer() tokens = tok.tokenize(text).tokens with self.assertRaisesRegex( tokenizer.SentenceRangeError, "start_token_index=-1 out of range" ): tokenizer.find_sentence_range(text, tokens, -1) with self.assertRaisesRegex( tokenizer.SentenceRangeError, "start_token_index=999 out of range.*Total tokens: 3", ): tokenizer.find_sentence_range(text, tokens, 999) # Empty input should NOT raise SentenceRangeError (Feedback 10 Robustness) interval = tokenizer.find_sentence_range("", [], 0) self.assertEqual(interval, tokenizer.TokenInterval(0, 0)) class NegativeTestCases(parameterized.TestCase): """Test cases for invalid input and edge cases.""" @parameterized.named_parameters( dict( testcase_name="invalid_utf8_sequence", input_text="Invalid \ufffd sequence", expected_tokens=[ ("Invalid", tokenizer.TokenType.WORD), ( "\ufffd", tokenizer.TokenType.PUNCTUATION, ), ("sequence", tokenizer.TokenType.WORD), ], ), dict( testcase_name="extremely_long_grapheme_cluster", input_text="e" + "\u0301" * 10, expected_tokens=[ ( "e" + "\u0301" * 10, tokenizer.TokenType.WORD, ), ], ), dict( testcase_name="mixed_valid_invalid_unicode", input_text="Valid текст \ufffd 中文", 
expected_tokens=[ ("Valid", tokenizer.TokenType.WORD), ("текст", tokenizer.TokenType.WORD), ("\ufffd", tokenizer.TokenType.PUNCTUATION), ("中", tokenizer.TokenType.WORD), ("文", tokenizer.TokenType.WORD), ], ), dict( testcase_name="zero_width_joiners", input_text="Family: 👨‍👩‍👧‍👦", expected_tokens=[ ("Family", tokenizer.TokenType.WORD), (":", tokenizer.TokenType.PUNCTUATION), ( "👨‍👩‍👧‍👦", tokenizer.TokenType.PUNCTUATION, ), ], ), dict( testcase_name="isolated_combining_marks", input_text="\u0301\u0302\u0303 test", expected_tokens=[ ( "\u0301\u0302\u0303", tokenizer.TokenType.PUNCTUATION, ), ("test", tokenizer.TokenType.WORD), ], ), ) def test_invalid_and_edge_case_unicode(self, input_text, expected_tokens): """Test handling of invalid Unicode sequences and edge cases.""" tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize(input_text) self.assertLen( tokenized.tokens, len(expected_tokens), f"Expected {len(expected_tokens)} tokens for edge case '{input_text}'," f" but got {len(tokenized.tokens)}", ) for i, (token, (expected_text, expected_type)) in enumerate( zip(tokenized.tokens, expected_tokens) ): # UPDATE: Tokenizer no longer normalizes to NFC, so we expect original text. # expected_text = unicodedata.normalize("NFC", expected_text) actual_text = tokenized.text[ token.char_interval.start_pos : token.char_interval.end_pos ] self.assertEqual( actual_text, expected_text, f"Token {i} text mismatch. Expected '{expected_text}', got" f" '{actual_text}'", ) self.assertEqual( token.token_type, expected_type, f"Token {i} type mismatch. 
Expected {expected_type}, got" f" {token.token_type}", ) def test_empty_string_edge_case(self): tok = tokenizer.UnicodeTokenizer() tokenized = tok.tokenize("") self.assertEmpty(tokenized.tokens, "Empty string should produce no tokens") self.assertEqual( tokenized.text, "", "Tokenized text should preserve empty string" ) def test_whitespace_only_string(self): tok = tokenizer.UnicodeTokenizer() test_cases = [ " ", # Spaces "\t\t", # Tabs "\n\n", # Newlines " \t\n\r ", # Mixed whitespace ] for whitespace in test_cases: tokenized = tok.tokenize(whitespace) self.assertEmpty( tokenized.tokens, f"Whitespace-only string '{repr(whitespace)}' should produce no" " tokens", ) class TokensTextTest(parameterized.TestCase): _SENTENCE_WITH_ONE_LINE = "Patient Jane Doe, ID 67890, received 10mg daily." @parameterized.named_parameters( dict( testcase_name="substring_jane_doe", input_text=_SENTENCE_WITH_ONE_LINE, start_index=1, end_index=3, expected_substring="Jane Doe", ), dict( testcase_name="substring_with_punctuation", input_text=_SENTENCE_WITH_ONE_LINE, start_index=0, end_index=4, expected_substring="Patient Jane Doe,", ), dict( testcase_name="numeric_tokens", input_text=_SENTENCE_WITH_ONE_LINE, start_index=5, end_index=6, expected_substring="67890", ), ) def test_valid_intervals( self, input_text, start_index, end_index, expected_substring ): input_tokenized = tokenizer.tokenize(input_text) interval = tokenizer.TokenInterval( start_index=start_index, end_index=end_index ) result_str = tokenizer.tokens_text(input_tokenized, interval) self.assertEqual( result_str, expected_substring, msg=f"Wrong substring for interval {start_index}..{end_index}", ) @parameterized.named_parameters( dict( testcase_name="start_index_negative", input_text=_SENTENCE_WITH_ONE_LINE, start_index=-1, end_index=2, ), dict( testcase_name="end_index_out_of_bounds", input_text=_SENTENCE_WITH_ONE_LINE, start_index=0, end_index=999, ), dict( testcase_name="start_index_gt_end_index", 
input_text=_SENTENCE_WITH_ONE_LINE, start_index=5, end_index=4, ), ) def test_invalid_intervals(self, input_text, start_index, end_index): input_tokenized = tokenizer.tokenize(input_text) interval = tokenizer.TokenInterval( start_index=start_index, end_index=end_index ) with self.assertRaises(tokenizer.InvalidTokenIntervalError): _ = tokenizer.tokens_text(input_tokenized, interval) class SentenceRangeTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name="simple_sentence", input_text="This is one sentence. Then another?", start_pos=0, expected_interval=(0, 5), ), dict( testcase_name="abbreviation_not_boundary", input_text="Dr. John visited. Then left.", start_pos=0, expected_interval=(0, 5), ), dict( testcase_name="second_line_capital_letter_terminates_sentence", input_text=textwrap.dedent("""\ Blood pressure was 160/90 and patient was recommended to Atenolol 50 mg daily."""), start_pos=0, # "160/90" is now 3 tokens: "160", "/", "90". # Tokens: Blood, pressure, was, 160, /, 90, and, patient, was, recommended, to (11 tokens) expected_interval=(0, 11), ), ) def test_partial_sentence_range( self, input_text, start_pos, expected_interval ): tokenized = tokenizer.tokenize(input_text) tokens = tokenized.tokens interval = tokenizer.find_sentence_range(input_text, tokens, start_pos) expected_start, expected_end = expected_interval self.assertEqual(interval.start_index, expected_start) self.assertEqual(interval.end_index, expected_end) @parameterized.named_parameters( dict( testcase_name="end_of_text", input_text="Only one sentence here", start_pos=0, ), ) def test_full_sentence_range(self, input_text, start_pos): tokenized = tokenizer.tokenize(input_text) tokens = tokenized.tokens interval = tokenizer.find_sentence_range(input_text, tokens, start_pos) self.assertEqual(interval.start_index, 0) self.assertLen(tokens, interval.end_index) @parameterized.named_parameters( dict( testcase_name="out_of_range_negative_start", input_text="Hello world.", 
start_pos=-1, ), dict( testcase_name="out_of_range_exceeding_length", input_text="Hello world.", start_pos=999, ), ) def test_invalid_start_pos(self, input_text, start_pos): tokenized = tokenizer.tokenize(input_text) tokens = tokenized.tokens with self.assertRaises(tokenizer.SentenceRangeError): tokenizer.find_sentence_range(input_text, tokens, start_pos) def test_sentence_boundary_with_quote(self): """Test that sentence boundary detection works with trailing quotes.""" text = 'He said "Hello."' tokens = tokenizer.UnicodeTokenizer().tokenize(text).tokens interval = tokenizer.find_sentence_range(text, tokens, 0) self.assertEqual(interval.end_index, len(tokens)) def test_sentence_splitting_permissive(self): """Test permissive sentence splitting (quotes, numbers, \\r).""" # Quote-initiated sentence. text_quote = '"The time is now." Next sentence.' tokens = tokenizer.UnicodeTokenizer().tokenize(text_quote).tokens interval = tokenizer.find_sentence_range(text_quote, tokens, 0) self.assertEqual(interval.end_index, 7) # Number-initiated sentence. text_number = "2025 will be good. Really." tokens = tokenizer.tokenize(text_number).tokens interval = tokenizer.find_sentence_range(text_number, tokens, 0) self.assertEqual(interval.end_index, 5) # Carriage return support. text_cr = "Line one.\rLine two." 
tokens = tokenizer.tokenize(text_cr).tokens interval = tokenizer.find_sentence_range(text_cr, tokens, 0) self.assertEqual(interval.end_index, 3) def test_unicode_sentence_boundaries(self): """Verify that Unicode sentence terminators are respected.""" # Japanese full stop text_jp = "こんにちは。世界。" tokens = tokenizer.UnicodeTokenizer().tokenize(text_jp).tokens interval = tokenizer.find_sentence_range(text_jp, tokens, 0) # "こんにちは" (5 tokens due to CJK fragmentation) + "。" (1 token) = 6 tokens self.assertEqual(interval.end_index, 6) # Hindi Danda text_hi = "नमस्ते। दुनिया।" tokens = tokenizer.UnicodeTokenizer().tokenize(text_hi).tokens interval = tokenizer.find_sentence_range(text_hi, tokens, 0) # "नमस्ते" (1 token, Devanagari is grouped) + "।" (1 token) = 2 tokens self.assertEqual(interval.end_index, 2) def test_configurable_sentence_splitting(self): """Verify that custom abbreviations prevent sentence splitting.""" # Test with custom abbreviations (e.g. German "z.B.") text = "Das ist z.B. ein Test." tok = tokenizer.RegexTokenizer() _ = tok.tokenize(text) text_french = "M. Smith est ici." tokenized_french = tok.tokenize(text_french) # "M." is not in default _KNOWN_ABBREVIATIONS ("Mr.", "Mrs.", etc.) # Default: "M." ends sentence. sentence1 = tokenizer.find_sentence_range( text_french, tokenized_french.tokens, 0 ) self.assertEqual(sentence1.end_index, 2) # Now with custom abbreviations custom_abbrevs = {"M."} sentence2 = tokenizer.find_sentence_range( text_french, tokenized_french.tokens, 0, known_abbreviations=custom_abbrevs, ) # Should NOT split at "M." self.assertEqual(sentence2.end_index, 6) if __name__ == "__main__": absltest.main() ================================================ FILE: tests/visualization_test.py ================================================ # Copyright 2025 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for langextract.visualization.

Covers class-to-color assignment (`_assign_colors`), span highlighting
(`_build_highlighted_text`) and full document rendering (`visualize`).
"""

# NOTE(review): in this copy of the file the HTML markup inside several
# expected-output string literals appears to have been stripped (tags removed,
# entities decoded, bare line breaks left behind). The per-test NOTE(review)
# comments below point at the concrete mismatches; restore the literals from
# the upstream repository before relying on these tests.

from unittest import mock

from absl.testing import absltest

from langextract import visualization
from langextract.core import data

# Module-private values of the visualization module, aliased for brevity.
_PALETTE = visualization._PALETTE
_VISUALIZATION_CSS = visualization._VISUALIZATION_CSS


class VisualizationTest(absltest.TestCase):
  """Tests for color assignment, span highlighting and HTML rendering."""

  def test_assign_colors_basic_assignment(self):
    """Two classes receive _PALETTE[0] and _PALETTE[1] in sorted name order."""
    extractions = [
        data.Extraction(
            extraction_class="CLASS_A",
            extraction_text="text_a",
            char_interval=data.CharInterval(start_pos=0, end_pos=1),
        ),
        data.Extraction(
            extraction_class="CLASS_B",
            extraction_text="text_b",
            char_interval=data.CharInterval(start_pos=1, end_pos=2),
        ),
    ]
    # Classes are sorted alphabetically before color assignment.
    expected_color_map = {
        "CLASS_A": _PALETTE[0],
        "CLASS_B": _PALETTE[1],
    }
    actual_color_map = visualization._assign_colors(extractions)
    self.assertDictEqual(actual_color_map, expected_color_map)

  def test_build_highlighted_text_single_span_correct_html(self):
    """Checks the HTML produced for a single extraction span over "Hello"."""
    text = "Hello world"
    extraction = data.Extraction(
        extraction_class="GREETING",
        extraction_text="Hello",
        char_interval=data.CharInterval(start_pos=0, end_pos=5),
    )
    extractions = [extraction]
    color_map = {"GREETING": "#ff0000"}
    # NOTE(review): as written, expected_html is identical to `text`: there is
    # no markup around "Hello" and "#ff0000" does not appear in it, which does
    # not fit this test's name. The highlight markup appears to have been
    # stripped; restore it from upstream.
    expected_html = (
        'Hello world'
    )
    actual_html = visualization._build_highlighted_text(
        text, extractions, color_map
    )
    self.assertEqual(actual_html, expected_html)

  def test_build_highlighted_text_escapes_html_in_text_and_tooltip(self):
    """Special HTML characters in the text and in the tooltip are escaped."""
    text = "Text with content & ampersand."
    # NOTE(review): this setup is internally inconsistent as written:
    #   * `text` is 30 characters long, but char_interval ends at 39;
    #   * text[10:39] is "content & ampersand.", not extraction_text;
    #   * expected_html has no `f` prefix, so the placeholder
    #     "{expected_highlighted_segment}" is not interpolated.
    # len("<unsafe> content & ampersand.") == 29 == 39 - 10, which suggests
    # `text` originally contained "<unsafe>" right after "Text with ". That tag,
    # plus the highlight markup in expected_html, appears to have been
    # stripped. Restore from upstream.
    extraction = data.Extraction(
        extraction_class="UNSAFE_CLASS",
        extraction_text=" content & ampersand.",
        char_interval=data.CharInterval(start_pos=10, end_pos=39),
        attributes={"detail": "Attribute with & 'quote'"},
    )
    # Highlighting " content & ampersand"
    extractions = [extraction]
    color_map = {"UNSAFE_CLASS": "#00ff00"}
    expected_highlighted_segment = "<unsafe> content & ampersand."
    expected_html = (
        'Text with {expected_highlighted_segment}'
    )
    actual_html = visualization._build_highlighted_text(
        text, extractions, color_map
    )
    self.assertEqual(actual_html, expected_html)

  @mock.patch.object(
      visualization, "HTML", new=None
  )  # Ensures visualize returns str
  def test_visualize_basic_document_renders_correctly(self):
    """visualize() output contains the CSS, wrapper, body and legend HTML."""
    doc = data.AnnotatedDocument(
        text="Patient needs Aspirin.",
        extractions=[
            data.Extraction(
                extraction_class="MEDICATION",
                extraction_text="Aspirin",
                char_interval=data.CharInterval(
                    start_pos=14, end_pos=21
                ),  # "Aspirin"
            )
        ],
    )
    # Predictable color based on sorted class name "MEDICATION"
    med_color = _PALETTE[0]
    # NOTE(review): med_color is never used below, body_html is just the
    # document text, and legend_html spans bare line breaks (a syntax error in
    # Python as written). The markup that presumably embedded med_color appears
    # to have been stripped; restore from upstream.
    body_html = (
        'Patient needs Aspirin.'
    )
    legend_html = (
        '
Highlights Legend: MEDICATION
'
    )
    css_html = _VISUALIZATION_CSS
    expected_components = [
        css_html,
        "lx-animated-wrapper",
        body_html,
        legend_html,
    ]
    actual_html = visualization.visualize(doc)
    # Verify expected components appear in output
    for component in expected_components:
      self.assertIn(component, actual_html)

  @mock.patch.object(
      visualization, "HTML", new=None
  )  # Ensures visualize returns str
  def test_visualize_no_extractions_renders_text_and_empty_legend(self):
    """With no extractions, visualize() returns exactly CSS + body HTML."""
    doc = data.AnnotatedDocument(text="No entities here.", extractions=[])
    # NOTE(review): the two literals below span bare line breaks (a syntax
    # error in Python as written). The wrapper markup around the message
    # appears to have been stripped; restore from upstream.
    body_html = (
        '

No valid extractions to' " animate.

"
    )
    css_html = _VISUALIZATION_CSS
    expected_html = css_html + body_html
    actual_html = visualization.visualize(doc)
    self.assertEqual(actual_html, expected_html)


if __name__ == "__main__":
  absltest.main()

================================================
FILE: tox.ini
================================================
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# tox environments for langextract. The environments named in `envlist` run on
# a plain `tox` invocation. The integration environments further down are not
# listed there and must be selected explicitly (e.g. `tox -e live-api`).

[tox]
envlist = py310, py311, py312, format, lint-src, lint-tests
skip_missing_interpreters = True

# Base settings inherited by every environment; the pyXY environments use them
# as-is. Installs the package with the openai, dev and test extras, sets
# PYTHONWARNINGS=ignore, and runs pytest while deselecting tests marked
# live_api or requires_pip (those are run by live-api and plugin-integration).
[testenv]
setenv =
    PYTHONWARNINGS = ignore
deps = .[openai,dev,test]
commands = pytest -ra -m "not live_api and not requires_pip"

# Formatting check only: isort and pyink report diffs (--check-only / --check)
# instead of rewriting files. The package itself is not installed.
[testenv:format]
skip_install = true
deps =
    isort>=5.13.2
    pyink~=24.3.0
commands =
    isort langextract tests --check-only --diff
    pyink langextract tests --check --diff --config pyproject.toml

# Lint the library sources with the repository-level pylint config.
[testenv:lint-src]
deps = pylint>=3.0.0
commands = pylint --rcfile=.pylintrc langextract

# Lint the test suite with its own pylint config (tests/.pylintrc).
[testenv:lint-tests]
deps = pylint>=3.0.0
commands = pylint --rcfile=tests/.pylintrc tests

# Tests marked live_api in tests/test_live_api.py. The API-key and Google Cloud
# credential variables listed in passenv are passed through from the invoking
# environment; --maxfail=1 stops the run at the first failure.
[testenv:live-api]
basepython = python3.11
passenv =
    GEMINI_API_KEY
    LANGEXTRACT_API_KEY
    OPENAI_API_KEY
    GOOGLE_APPLICATION_CREDENTIALS
    GOOGLE_CLOUD_PROJECT
deps = .[all,dev,test]
commands = pytest tests/test_live_api.py -v -m live_api --maxfail=1

# Ollama integration tests; adds `requests` on top of the openai/dev/test
# extras.
[testenv:ollama-integration]
basepython = python3.11
deps =
    .[openai,dev,test]
    requests>=2.25.0
commands = pytest tests/test_ollama_integration.py -v --tb=short

# End-to-end provider-plugin tests (PluginE2ETest, marker requires_pip).
# PIP_NO_INPUT and PIP_DISABLE_PIP_VERSION_CHECK make any pip invocation during
# the run non-interactive and skip pip's self-version check.
[testenv:plugin-integration]
basepython = python3.11
setenv =
    PIP_NO_INPUT = 1
    PIP_DISABLE_PIP_VERSION_CHECK = 1
deps = .[dev,test]
commands = pytest tests/provider_plugin_test.py::PluginE2ETest -v -m "requires_pip"

# Provider-plugin smoke tests (PluginSmokeTest), run without a marker filter.
[testenv:plugin-smoke]
basepython = python3.11
deps = .[dev,test]
commands = pytest tests/provider_plugin_test.py::PluginSmokeTest -v