[
  {
    "path": ".github/ISSUE_TEMPLATE/1-bug.md",
    "content": "---\nname: Bug Report\nabout: Create a bug report to help us improve\ntitle: 'Bug: <brief title of your issue>'\nlabels: 'bug, needs triage'\nassignees: ''\n---\n\n## Describe the overall issue and situation\n\nProvide a clear summary of what the issue is about, the area of the project you\nfound it in, and what you were trying to do.\n\n## Expected behavior\n\nProvide a clear and concise description of what you expected to happen.\n\n## Actual behavior\n\nProvide a clear and concise description of what actually happened.\n\n## Steps to reproduce the issue\n\nProvide a sequence of steps we can use to reproduce the issue.\n\n1.  <First step...>\n2.  <Second step...>\n3.  <Third step...>\n\n## Additional context\n\nDescribe your environment or any other setup details that might help us\nreproduce the issue.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/2-feature-request.md",
    "content": "---\nname: Feature Request\nabout: Suggest an idea or improvement\ntitle: 'Request: <brief title of your feature request>'\nlabels: 'enhancement, needs triage'\nassignees: ''\n---\n\n## Describe the overall idea and motivation\n\nProvide a clear summary of the idea and what use cases it's addressing.\n\n## Related to an issue?\n\nIs this addressing a known or documented issue? If so, which one?\n\n## Possible solutions and alternatives\n\nDo you already have an idea of how the solution should work? If so, document\nthat here.\n\nAlso, if there are alternatives, please document those as well.\n\n## Priority and timeline considerations\n\nIs this time-sensitive? Is it a nice-to-have? Please describe what priority you\nfeel this should have and why. We'll take this into account as we go through\nour internal prioritization process.\n\n## Additional context\n\nIs there anything else to consider that wasn't covered by the above?\n\nWould you like to contribute to the project and work on this request?\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Allow users to create issues that don't follow the templates since they don't cover all use cases\nblank_issues_enabled: true\n\n# Redirect users to other channels for general support or security issues\ncontact_links:\n  - name: Community Support\n    url: https://github.com/google/langextract/discussions\n    about: Please ask and answer questions here.\n  - name: Security Bug Reporting\n    url: https://g.co/vulnz\n    about: >\n      To report a security issue, please use https://g.co/vulnz. The Google Security Team will\n      respond within 5 working days of your report on https://g.co/vulnz.\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE/pull_request_template.md",
    "content": "# Description\n\nReplace this with a clear and concise change description.\n\n<!--- Important: All PRs must be linked to at least one issue (except for\n  extremely trivial and straightforward changes). --->\n\n<!--- This issue (or issues) should document the motivation, context,\n  alternatives considered, risks (such as breaking backwards compatibility), and\n  any new dependencies. --->\n\n<!--- Use \"Fixes #123\" to auto-close the issue when merged (for bug fixes/implementations) -->\n<!--- Use \"Related to #123\" or \"Addresses #123\" for documentation updates or partial solutions -->\nFixes/Related to #[issue number]\n\nChoose one: (Bug fix | Feature | Documentation | Testing | Code health | Other)\n\n# How Has This Been Tested?\n\nReplace this with a description of the tests that you ran to verify your\nchanges. If executing the existing test suite without customization, simply\npaste the command line used.\n\n```\n$ python -m unittest discover ...\n```\n\n# Checklist:\n\n<!--- Put an `x` in the box if you did the task -->\n\n<!--- If you forgot a task please follow the instructions below -->\n\n-   [ ] I have read and acknowledged Google's Open Source\n    [Code of Conduct](https://opensource.google/conduct).\n-   [ ] I have read the\n    [Contributing](https://github.com/google/langextract/blob/main/CONTRIBUTING.md)\n    page, and I either signed the Google\n    [Individual CLA](https://cla.developers.google.com/about/google-individual)\n    or am covered by my company's\n    [Corporate CLA](https://cla.developers.google.com/about/google-corporate).\n-   [ ] I have discussed my proposed solution with code owners in the linked\n    issue(s) and we have agreed upon the general approach.\n-   [ ] I have made any needed documentation changes, or noted in the linked\n    issue(s) that documentation elsewhere needs updating.\n-   [ ] I have added tests, or I have ensured existing tests cover the changes.\n-   [ ] I have followed\n    [Google's Python Style Guide](https://google.github.io/styleguide/pyguide.html)\n    and run `pylint` over the affected code.\n"
  },
  {
    "path": ".github/scripts/add-new-checks.sh",
    "content": "#!/bin/bash\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Script to add new required status checks to an existing branch protection rule.\n# This preserves all your current settings and just adds the new checks\n\necho \"Adding new PR validation checks to existing branch protection...\"\n\n# Add the new checks to existing ones\necho \"Adding new checks: enforce, size, and protect-infrastructure...\"\ngh api repos/:owner/:repo/branches/main/protection/required_status_checks/contexts \\\n  --method POST \\\n  --input - <<< '[\"enforce\", \"size\", \"protect-infrastructure\"]'\n\necho \"\"\necho \"✓ New checks added!\"\necho \"\"\necho \"Updated required status checks will include:\"\necho \"- test (3.10)                    [existing]\"\necho \"- test (3.11)                    [existing]\"\necho \"- test (3.12)                    [existing]\"\necho \"- Validate PR Template           [existing]\"\necho \"- live-api-tests                 [existing]\"\necho \"- ollama-integration-test        [existing]\"\necho \"- enforce                        [NEW - linked issue validation]\"\necho \"- size                           [NEW - PR size limit]\"\necho \"- protect-infrastructure         [NEW - infrastructure file protection]\"\n"
  },
  {
    "path": ".github/scripts/add-size-labels.sh",
    "content": "#!/bin/bash\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Add size labels to open PRs based on their total changed lines (additions + deletions)\n\necho \"Adding size labels to PRs...\"\n\n# Get up to 50 open PRs with their additions and deletions\ngh pr list --limit 50 --json number,additions,deletions --jq '.[]' | while read -r pr_data; do\n    pr_number=$(echo \"$pr_data\" | jq -r '.number')\n    additions=$(echo \"$pr_data\" | jq -r '.additions')\n    deletions=$(echo \"$pr_data\" | jq -r '.deletions')\n    total_changes=$((additions + deletions))\n\n    # Determine size label\n    if [ $total_changes -lt 50 ]; then\n        size_label=\"size/XS\"\n    elif [ $total_changes -lt 150 ]; then\n        size_label=\"size/S\"\n    elif [ $total_changes -lt 600 ]; then\n        size_label=\"size/M\"\n    elif [ $total_changes -lt 1000 ]; then\n        size_label=\"size/L\"\n    else\n        size_label=\"size/XL\"\n    fi\n\n    echo \"PR #$pr_number: $total_changes lines -> $size_label\"\n\n    # Remove any existing size labels first\n    existing_labels=$(gh pr view \"$pr_number\" --json labels --jq '.labels[].name' | grep \"^size/\" || true)\n    if [ -n \"$existing_labels\" ]; then\n        echo \"  Removing existing label: $existing_labels\"\n        gh pr edit \"$pr_number\" --remove-label \"$existing_labels\"\n    fi\n\n    # Add the new size label\n    gh pr edit \"$pr_number\" --add-label \"$size_label\"\n\n    sleep 1  # Avoid rate limiting\ndone\n\necho \"Done adding size labels!\"\n"
  },
  {
    "path": ".github/scripts/revalidate-all-prs.sh",
    "content": "#!/bin/bash\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Revalidate all open PRs\n\necho \"Fetching all open PRs...\"\nPR_NUMBERS=$(gh pr list --limit 50 --json number --jq '.[].number')\nTOTAL=$(echo \"$PR_NUMBERS\" | wc -w | tr -d ' ')\n\necho \"Found $TOTAL open PRs\"\necho \"Starting revalidation...\"\necho \"\"\n\nCOUNT=0\nfor pr in $PR_NUMBERS; do\n    COUNT=$((COUNT + 1))\n    echo \"[$COUNT/$TOTAL] Triggering revalidation for PR #$pr...\"\n    gh workflow run revalidate-pr.yml -f pr_number=$pr\n\n    # Small delay to avoid rate limiting\n    sleep 2\ndone\n\necho \"\"\necho \"All workflows triggered!\"\necho \"\"\necho \"To monitor progress:\"\necho \"  gh run list --workflow=revalidate-pr.yml --limit=$TOTAL\"\necho \"\"\necho \"To see results, check comments on each PR\"\n"
  },
  {
    "path": ".github/scripts/zenodo_publish.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Publish a new version to Zenodo via the REST API.\n\nThis script reads the project name from pyproject.toml to avoid duplication.\nFor subsequent releases, it creates new versions from the existing Zenodo record,\ninheriting most metadata automatically.\n\"\"\"\n\nimport glob\nimport os\nimport sys\nimport tomllib\nimport urllib.request\n\nimport requests\n\nAPI = \"https://zenodo.org/api\"\nTOKEN = os.environ[\"ZENODO_TOKEN\"]\nRECORD_ID = os.environ[\"ZENODO_RECORD_ID\"]\nVERSION = os.environ[\"RELEASE_TAG\"].lstrip(\"v\")\nREPO = os.environ[\"GITHUB_REPOSITORY\"]\nSERVER = os.environ.get(\"GITHUB_SERVER_URL\", \"https://github.com\")\nHEADERS = {\n    \"Authorization\": f\"Bearer {TOKEN}\",\n    \"Content-Type\": \"application/json\",\n}\n\ntry:\n  with open(\"pyproject.toml\", \"rb\") as f:\n    pyproject = tomllib.load(f)\n    PROJECT_META = pyproject[\"project\"]\n    PROJECT = PROJECT_META[\"name\"]\nexcept (KeyError, FileNotFoundError) as e:\n  print(f\"❌ Error loading project metadata: {e}\", file=sys.stderr)\n  sys.exit(1)\n\n\ndef new_version_from_record(record_id: str):\n  \"\"\"Create a new draft that inherits metadata from the latest published record.\"\"\"\n  r = requests.post(\n      f\"{API}/deposit/depositions/{record_id}/actions/newversion\",\n      headers=HEADERS,\n      timeout=30,\n  )\n  r.raise_for_status()\n  # Zenodo returns a link to the draft, not the draft itself\n  latest_draft_url = r.json()[\"links\"][\"latest_draft\"]\n  r = requests.get(latest_draft_url, headers=HEADERS, timeout=30)\n  r.raise_for_status()\n  return r.json()\n\n\ndef upload_file(bucket_url: str, path: str, dest_name: str | None = None):\n  \"\"\"Upload a file to the deposition bucket.\"\"\"\n  dest = dest_name or os.path.basename(path)\n  with open(path, \"rb\") as fp:\n    r = requests.put(\n        f\"{bucket_url}/{dest}\",\n        data=fp,\n        headers={\"Authorization\": f\"Bearer {TOKEN}\"},\n        timeout=60,\n    )\n    r.raise_for_status()\n\n\ndef main():\n  \"\"\"Create, populate, and publish a new Zenodo version.\"\"\"\n  try:\n    draft = new_version_from_record(RECORD_ID)\n\n    bucket = draft[\"links\"][\"bucket\"]\n    dep_id = draft[\"id\"]\n\n    # GitHub auto-generates source archives for tags\n    tarball = f\"/tmp/{PROJECT}-v{VERSION}.tar.gz\"\n    src_url = f\"{SERVER}/{REPO}/archive/refs/tags/v{VERSION}.tar.gz\"\n    urllib.request.urlretrieve(src_url, tarball)\n    upload_file(bucket, tarball, f\"{PROJECT}-{VERSION}.tar.gz\")\n\n    for path in glob.glob(\"dist/*\"):\n      upload_file(bucket, path)\n\n    # Update only version-specific metadata; rest is inherited\n    meta = {\n        \"metadata\": {\n            \"title\": f\"{PROJECT.replace('-', ' ').title()} v{VERSION}\",\n            \"version\": VERSION,\n            \"upload_type\": \"software\",\n        }\n    }\n    r = requests.put(\n        f\"{API}/deposit/depositions/{dep_id}\",\n        headers=HEADERS,\n        json=meta,\n        timeout=30,\n    )\n    r.raise_for_status()\n\n    # Publish to mint DOI\n    r = requests.post(\n        f\"{API}/deposit/depositions/{dep_id}/actions/publish\",\n        headers=HEADERS,\n        timeout=30,\n    )\n    r.raise_for_status()\n    record = r.json()\n\n    doi = record.get(\"doi\")\n    record_id = record.get(\"record_id\")\n\n    print(f\"✅ Published to Zenodo: https://doi.org/{doi}\")\n\n    if \"GITHUB_OUTPUT\" in os.environ:\n      with open(os.environ[\"GITHUB_OUTPUT\"], \"a\") as f:\n        f.write(f\"doi={doi}\\n\")\n        f.write(f\"record_id={record_id}\\n\")\n        f.write(f\"zenodo_url=https://zenodo.org/records/{record_id}\\n\")\n\n    return 0\n\n  except Exception as e:\n    print(f\"❌ Error: {e}\", file=sys.stderr)\n    return 1\n\n\nif __name__ == \"__main__\":\n  sys.exit(main())\n"
  },
  {
    "path": ".github/workflows/auto-update-pr.yaml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: Auto Update PR\n\non:\n  push:\n    branches: [main]\n  schedule:\n    # Run daily at 2 AM UTC to catch stale PRs\n    - cron: '0 2 * * *'\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: 'PR number to update (optional, updates all if not specified)'\n        required: false\n        type: string\n\npermissions:\n  contents: write  # Required for updateBranch API\n  pull-requests: write\n  issues: write\n\njobs:\n  update-prs:\n    runs-on: ubuntu-latest\n    concurrency:\n      group: auto-update-pr-${{ github.event_name }}\n      cancel-in-progress: true\n    steps:\n      - name: Update PRs that are behind main\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const prNumber = context.payload.inputs?.pr_number;\n\n            // Get list of open PRs\n            const prs = prNumber\n              ? [(await github.rest.pulls.get({\n                  owner: context.repo.owner,\n                  repo: context.repo.repo,\n                  pull_number: parseInt(prNumber)\n                })).data]\n              : await github.paginate(github.rest.pulls.list, {\n                  owner: context.repo.owner,\n                  repo: context.repo.repo,\n                  state: 'open',\n                  sort: 'updated',\n                  direction: 'desc'\n                });\n\n            console.log(`Found ${prs.length} open PRs to check`);\n\n            // Constants for comment flood control\n            const UPDATE_COMMENT_COOLDOWN_DAYS = 7;\n            const COOLDOWN_MS = UPDATE_COMMENT_COOLDOWN_DAYS * 24 * 60 * 60 * 1000;\n\n            for (const pr of prs) {\n              // Skip bot PRs and drafts\n              if (pr.user.login.includes('[bot]')) {\n                console.log(`Skipping bot PR #${pr.number} from ${pr.user.login}`);\n                continue;\n              }\n              if (pr.draft) {\n                console.log(`Skipping draft PR #${pr.number}`);\n                continue;\n              }\n\n              try {\n                // Check if PR is behind main (base...head comparison)\n                const { data: comparison } = await github.rest.repos.compareCommits({\n                  owner: context.repo.owner,\n                  repo: context.repo.repo,\n                  base: pr.base.ref,  // main branch\n                  head: `${pr.head.repo.owner.login}:${pr.head.ref}`  // Fully qualified ref for forks\n                });\n\n                if (comparison.behind_by > 0) {\n                  console.log(`PR #${pr.number} is ${comparison.behind_by} commits behind ${pr.base.ref}`);\n\n                  // Check if the PR allows maintainer edits\n                  if (pr.maintainer_can_modify) {\n                    // Try to update the branch\n                    try {\n                      await github.rest.pulls.updateBranch({\n                        owner: context.repo.owner,\n                        repo: context.repo.repo,\n                        pull_number: pr.number\n                      });\n\n                      console.log(`✅ Updated PR #${pr.number}`);\n\n                      // Add a comment\n                      await github.rest.issues.createComment({\n                        owner: context.repo.owner,\n                        repo: context.repo.repo,\n                        issue_number: pr.number,\n                        body: `🔄 **Branch Updated**\\n\\nYour branch was ${comparison.behind_by} commits behind \\`${pr.base.ref}\\` and has been automatically updated. CI checks will re-run shortly.`\n                      });\n                    } catch (updateError) {\n                      console.log(`Could not auto-update PR #${pr.number}: ${updateError.message}`);\n\n                      // Determine the reason for failure\n                      let failureReason = '';\n                      if (updateError.status === 409 || updateError.message.includes('merge conflict')) {\n                        failureReason = '\\n\\n**Note:** Automatic update failed due to merge conflicts. Please resolve them manually.';\n                      } else if (updateError.status === 422) {\n                        failureReason = '\\n\\n**Note:** Cannot push to fork. Please update manually.';\n                      }\n\n                      // Notify the contributor to update manually\n                      await github.rest.issues.createComment({\n                        owner: context.repo.owner,\n                        repo: context.repo.repo,\n                        issue_number: pr.number,\n                        body: `⚠️ **Branch Update Required**\\n\\nYour branch is ${comparison.behind_by} commits behind \\`${pr.base.ref}\\`.${failureReason}\\n\\nPlease update your branch:\\n\\n\\`\\`\\`bash\\ngit fetch origin ${pr.base.ref}\\ngit merge origin/${pr.base.ref}\\ngit push\\n\\`\\`\\`\\n\\nOr use GitHub's \"Update branch\" button if available.`\n                      });\n                    }\n                  } else {\n                    // Can't modify, just notify\n                    console.log(`PR #${pr.number} doesn't allow maintainer edits`);\n\n                    // Check if we already commented recently (within last 7 days)\n                    const { data: comments } = await github.rest.issues.listComments({\n                      owner: context.repo.owner,\n                      repo: context.repo.repo,\n                      issue_number: pr.number,\n                      since: new Date(Date.now() - COOLDOWN_MS).toISOString()\n                    });\n\n                    const hasRecentUpdateComment = comments.some(c =>\n                      c.body?.includes('Branch Update Required') &&\n                      c.user?.login === 'github-actions[bot]'\n                    );\n\n                    if (!hasRecentUpdateComment) {\n                      await github.rest.issues.createComment({\n                        owner: context.repo.owner,\n                        repo: context.repo.repo,\n                        issue_number: pr.number,\n                        body: `⚠️ **Branch Update Required**\\n\\nYour branch is ${comparison.behind_by} commits behind \\`${pr.base.ref}\\`. Please update your branch to ensure CI checks run with the latest code:\\n\\n\\`\\`\\`bash\\ngit fetch origin ${pr.base.ref}\\ngit merge origin/${pr.base.ref}\\ngit push\\n\\`\\`\\`\\n\\nNote: Enable \"Allow edits by maintainers\" to allow automatic updates.`\n                      });\n                    }\n                  }\n                } else {\n                  console.log(`PR #${pr.number} is up to date`);\n                }\n              } catch (error) {\n                console.error(`Error processing PR #${pr.number}:`, error.message);\n              }\n            }\n\n            // Log rate limit status\n            const { data: rateLimit } = await github.rest.rateLimit.get();\n            console.log(`API rate limit remaining: ${rateLimit.rate.remaining}/${rateLimit.rate.limit}`);\n"
  },
  {
    "path": ".github/workflows/check-infrastructure-changes.yml",
    "content": "name: Protect Infrastructure Files\n\non:\n  pull_request_target:\n    types: [opened, synchronize, reopened]\n  workflow_dispatch:\n\npermissions:\n  contents: read\n  pull-requests: write\n\njobs:\n  protect-infrastructure:\n    if: github.event_name == 'workflow_dispatch' || github.event.pull_request.draft == false\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Check for infrastructure file changes\n        if: github.event_name == 'pull_request_target'\n        uses: actions/github-script@v7\n        with:\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          script: |\n            // Get the PR author and check if they're a maintainer\n            const prAuthor = context.payload.pull_request.user.login;\n            const { data: authorPermission } = await github.rest.repos.getCollaboratorPermissionLevel({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              username: prAuthor\n            });\n\n            // The legacy `permission` field maps the maintain role to 'write'; `role_name` reports 'maintain'\n            const authorRole = authorPermission.role_name || authorPermission.permission;\n            const isMaintainer = ['admin', 'maintain'].includes(authorRole);\n\n            // Get all files changed in the PR (paginated; the API returns 30 files per page by default)\n            const files = await github.paginate(github.rest.pulls.listFiles, {\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              pull_number: context.payload.pull_request.number,\n              per_page: 100\n            });\n\n            // Check for infrastructure file changes\n            const infrastructureFiles = files.filter(file =>\n              file.filename.startsWith('.github/') ||\n              file.filename === 'pyproject.toml' ||\n              file.filename === 'tox.ini' ||\n              file.filename === '.pre-commit-config.yaml' ||\n              file.filename === '.pylintrc' ||\n              file.filename === 'Dockerfile' ||\n              file.filename === 'autoformat.sh' ||\n              file.filename === '.gitignore' ||\n              file.filename === 'CONTRIBUTING.md' ||\n              file.filename === 'LICENSE' ||\n              file.filename === 'CITATION.cff'\n            );\n\n            if (infrastructureFiles.length > 0 && !isMaintainer) {\n              // Check if changes are only formatting/whitespace\n              let hasStructuralChanges = false;\n              for (const file of infrastructureFiles) {\n                const additions = file.additions || 0;\n                const deletions = file.deletions || 0;\n                const changes = file.changes || 0;\n\n                // If file has significant changes (not just whitespace), consider it structural\n                if (additions > 5 || deletions > 5 || changes > 10) {\n                  hasStructuralChanges = true;\n                  break;\n                }\n              }\n\n              const fileList = infrastructureFiles.map(f => `  - ${f.filename} (${f.changes} changes)`).join('\\n');\n\n              // Post a comment explaining the issue\n              await github.rest.issues.createComment({\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n                issue_number: context.payload.pull_request.number,\n                body: `❌ **Infrastructure File Protection**\\n\\n` +\n                      `This PR modifies protected infrastructure files:\\n\\n${fileList}\\n\\n` +\n                      `Only repository maintainers are allowed to modify infrastructure files (including \\`.github/\\`, build configuration, and repository documentation).\\n\\n` +\n                      `**Note**: If these are only formatting changes, please:\\n` +\n                      `1. Revert changes to \\`.github/\\` files\\n` +\n                      `2. Use \\`./autoformat.sh\\` to format only source code directories\\n` +\n                      `3. Avoid running formatters on infrastructure files\\n\\n` +\n                      `If structural changes are necessary:\\n` +\n                      `1. Open an issue describing the needed infrastructure changes\\n` +\n                      `2. A maintainer will review and implement the changes if approved\\n\\n` +\n                      `For more information, see our [Contributing Guidelines](https://github.com/google/langextract/blob/main/CONTRIBUTING.md).`\n              });\n\n              core.setFailed(\n                `This PR modifies ${infrastructureFiles.length} protected infrastructure file(s). ` +\n                `Only maintainers can modify these files. ` +\n                `Use ./autoformat.sh to format code without touching infrastructure.`\n              );\n            } else if (infrastructureFiles.length > 0 && isMaintainer) {\n              core.info(`PR modifies ${infrastructureFiles.length} infrastructure file(s) - allowed for maintainer ${prAuthor}`);\n            } else {\n              core.info('No infrastructure files modified');\n            }\n"
  },
  {
    "path": ".github/workflows/check-linked-issue.yml",
    "content": "name: Require linked issue with community support\n\non:\n  pull_request_target:\n    types: [opened, edited, synchronize, reopened, ready_for_review]\n\npermissions:\n  contents: read\n  issues: write\n  pull-requests: write\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  enforce:\n    if: github.event_name == 'pull_request_target' && !github.event.pull_request.draft\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Check linked issue and community support\n        uses: actions/github-script@v7\n        with:\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          script: |\n            // Strip code blocks and inline code to avoid false matches\n            const stripCode = txt =>\n              txt.replace(/```[\\s\\S]*?```/g, '').replace(/`[^`]*`/g, '');\n\n            // Combine title + body for comprehensive search\n            const prText = stripCode(`${context.payload.pull_request.title || ''}\\n${context.payload.pull_request.body || ''}`);\n\n            // Issue reference pattern: #123, org/repo#123, or full URL (with http/https and optional www)\n            const issueRef = String.raw`(?:#(?<num>\\d+)|(?<o1>[\\w.-]+)\\/(?<r1>[\\w.-]+)#(?<n1>\\d+)|https?:\\/\\/(?:www\\.)?github\\.com\\/(?<o2>[\\w.-]+)\\/(?<r2>[\\w.-]+)\\/issues\\/(?<n2>\\d+))`;\n\n            // Keywords - supporting common variants\n            const closingRe = new RegExp(String.raw`\\b(?:close[sd]?|fix(?:e[sd])?|resolve[sd]?)\\b\\s*:?\\s+${issueRef}`, 'gi');\n            const referenceRe = new RegExp(String.raw`\\b(?:related\\s+to|relates\\s+to|refs?|part\\s+of|addresses|see(?:\\s+also)?|depends\\s+on|blocked\\s+by|supersedes)\\b\\s*:?\\s+${issueRef}`, 'gi');\n\n            // Gather all matches\n            const closings = [...prText.matchAll(closingRe)];\n            const references = [...prText.matchAll(referenceRe)];\n            const first = closings[0] || references[0];\n\n            // Check for draft PRs and bots\n            const pr = context.payload.pull_request;\n            const isDraft = !!pr.draft;\n            const login = pr.user.login;\n            const isBot = pr.user.type === 'Bot' || /\\[bot\\]$/.test(login);\n\n            if (isDraft || isBot) {\n              core.info('Draft or bot PR – skipping enforcement');\n              return;\n            }\n\n            // Check if PR author is a maintainer\n            let authorPerm = 'none';\n            try {\n              const { data } = await github.rest.repos.getCollaboratorPermissionLevel({\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n                username: pr.user.login,\n              });\n              // Prefer role_name: the legacy `permission` field reports the maintain role as 'write'\n              authorPerm = data.role_name || data.permission || 'none';\n            } catch (_) {\n              // User might not have any permissions\n            }\n\n            core.info(`Author permission: ${authorPerm}`);\n            const isMaintainer = ['admin', 'maintain'].includes(authorPerm);  // Removed 'write' for stricter maintainer definition\n\n            // Maintainers bypass entirely\n            if (isMaintainer) {\n              core.info(`Maintainer ${pr.user.login} - bypassing linked issue requirement`);\n              return;\n            }\n\n            if (!first) {\n              // Check for existing comment to avoid duplicates\n              const MARKER = '<!-- linkcheck:missing-issue -->';\n              const existing = await github.paginate(github.rest.issues.listComments, {\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n                issue_number: context.payload.pull_request.number,\n                per_page: 100,\n              });\n              const alreadyLeft = existing.some(c => c.body && c.body.includes(MARKER));\n\n              if (!alreadyLeft) {\n                const contribUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/blob/main/CONTRIBUTING.md#pull-request-guidelines`;\n                const commentBody = [\n                  'No linked issues found. Please link an issue in your pull request description or title.',\n                  '',\n                  `Per our [Contributing Guidelines](${contribUrl}), all PRs must:`,\n                  '- Reference an issue with one of:',\n                  '  - **Closing keywords**: `Fixes #123`, `Closes #123`, `Resolves #123` (auto-closes on merge in the same repository)',\n                  '  - **Reference keywords**: `Related to #123`, `Refs #123`, `Part of #123`, `See #123` (links without closing)',\n                  '- The linked issue should have 5+ 👍 reactions from unique users (excluding bots and the PR author)',\n                  '- Include discussion demonstrating the importance of the change',\n                  '',\n                  'You can also use cross-repo references like `owner/repo#123` or full URLs.',\n                  '',\n                  MARKER\n                ].join('\\n');\n\n                await github.rest.issues.createComment({\n                  owner: context.repo.owner,\n                  repo: context.repo.repo,\n                  issue_number: context.payload.pull_request.number,\n                  body: commentBody\n                });\n              }\n\n              core.setFailed('No linked issue found. Use \"Fixes #123\" to close an issue or \"Related to #123\" to reference it.');\n              return;\n            }\n\n            // Resolve owner/repo/number, defaulting to the current repo\n            const groups = first.groups || {};\n            const owner = groups.o1 || groups.o2 || context.repo.owner;\n            const repo = groups.r1 || groups.r2 || context.repo.repo;\n            const issue_number = Number(groups.num || groups.n1 || groups.n2);\n\n            // Validate issue number\n            if (!Number.isInteger(issue_number) || issue_number <= 0) {\n              core.setFailed(\n                'Found a potential issue link but no valid number. ' +\n                'Use \"Fixes #123\" or \"Related to owner/repo#123\".'\n              );\n              return;\n            }\n\n            core.info(`Found linked issue: ${owner}/${repo}#${issue_number}`);\n\n            // Count unique users who reacted with 👍 on the linked issue (excluding bots and PR author)\n            try {\n              const reactions = await github.paginate(github.rest.reactions.listForIssue, {\n                owner,\n                repo,\n                issue_number,\n                per_page: 100,\n              });\n\n              const prAuthorId = pr.user.id;\n              const uniqueThumbs = new Set(\n                reactions\n                  .filter(r =>\n                    r.content === '+1' &&\n                    r.user &&\n                    r.user.id !== prAuthorId &&\n                    r.user.type !== 'Bot' &&\n                    !String(r.user.login || '').endsWith('[bot]')\n                  )\n                  .map(r => r.user.id)\n              ).size;\n\n              core.info(`Issue ${owner}/${repo}#${issue_number} has ${uniqueThumbs} unique 👍 reactions`);\n\n              const REQUIRED_THUMBS_UP = 5;\n              if (uniqueThumbs < REQUIRED_THUMBS_UP) {\n                core.setFailed(`Linked issue
${owner}/${repo}#${issue_number} has only ${uniqueThumbs} 👍 (need ${REQUIRED_THUMBS_UP}).`);\n                return;\n              }\n            } catch (error) {\n              const isSameRepo = owner === context.repo.owner && repo === context.repo.repo;\n              if (error.status === 404 || error.status === 403) {\n                if (!isSameRepo) {\n                  core.setFailed(\n                    `Linked issue ${owner}/${repo}#${issue_number} is not accessible. ` +\n                    `Please link to an issue in ${context.repo.owner}/${context.repo.repo} or a public repo.`\n                  );\n                } else {\n                  core.info(`Cannot access reactions for ${owner}/${repo}#${issue_number}; skipping enforcement for same-repo issue.`);\n                }\n                return;\n              }\n\n              // Any other error should fail to prevent accidental bypass\n              const msg = (error && error.message) ? String(error.message).toLowerCase() : '';\n              const isRateLimit = msg.includes('rate limit') || error?.headers?.['x-ratelimit-remaining'] === '0';\n\n              if (isRateLimit) {\n                core.setFailed(`Rate limit while checking reactions for ${owner}/${repo}#${issue_number}. Please retry the workflow.`);\n              } else {\n                core.setFailed(`Unexpected error checking reactions for ${owner}/${repo}#${issue_number}: ${error?.message || error}`);\n              }\n            }\n"
  },
  {
    "path": ".github/workflows/check-pr-size.yml",
    "content": "name: Check PR size\n\non:\n  pull_request_target:\n    types: [opened, synchronize, reopened]\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: 'PR number to check (optional)'\n        required: false\n        type: string\n\npermissions:\n  contents: read\n  pull-requests: write\n  issues: write\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  size:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Get PR data for manual trigger\n        if: github.event_name == 'workflow_dispatch' && github.event.inputs.pr_number\n        id: get_pr\n        uses: actions/github-script@v7\n        with:\n          result-encoding: string\n          script: |\n            const { data } = await github.rest.pulls.get({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              pull_number: ${{ github.event.inputs.pr_number }}\n            });\n            return JSON.stringify(data);\n\n      - name: Evaluate PR size\n        if: github.event_name == 'pull_request_target' || (github.event_name == 'workflow_dispatch' && github.event.inputs.pr_number)\n        uses: actions/github-script@v7\n        env:\n          PR_JSON: ${{ steps.get_pr.outputs.result }}\n        with:\n          script: |\n            const pr = context.payload.pull_request || JSON.parse(process.env.PR_JSON || '{}');\n            if (!pr || !pr.number) {\n              core.setFailed('Unable to resolve PR data. 
For workflow_dispatch, pass a valid pr_number.');\n              return;\n            }\n\n            // Check for draft PRs and bots\n            const isDraft = !!pr.draft;\n            const login = pr.user.login;\n            const isBot = pr.user.type === 'Bot' || /\\[bot\\]$/.test(login);\n\n            if (isDraft || isBot) {\n              core.info('Draft or bot PR – skipping size enforcement');\n              return;\n            }\n\n            const totalChanges = pr.additions + pr.deletions;\n            core.info(`PR contains ${pr.additions} additions and ${pr.deletions} deletions (${totalChanges} total)`);\n\n            const sizeLabel =\n              totalChanges < 50   ? 'size/XS' :\n              totalChanges < 150  ? 'size/S'  :\n              totalChanges < 600  ? 'size/M'  :\n              totalChanges < 1000 ? 'size/L'  : 'size/XL';\n\n            // Re-fetch labels to avoid acting on stale payload data\n            const { data: freshIssue } = await github.rest.issues.get({\n              ...context.repo,\n              issue_number: pr.number\n            });\n            const currentLabels = (freshIssue.labels || []).map(l => l.name);\n\n            // Remove old size labels before adding new one\n            const allSizeLabels = ['size/XS', 'size/S', 'size/M', 'size/L', 'size/XL'];\n            const toRemove = currentLabels.filter(name => allSizeLabels.includes(name) && name !== sizeLabel);\n\n            for (const name of toRemove) {\n              try {\n                await github.rest.issues.removeLabel({\n                  ...context.repo,\n                  issue_number: pr.number,\n                  name\n                });\n              } catch (_) {\n                // Ignore if already removed\n              }\n            }\n\n            await github.rest.issues.addLabels({\n              ...context.repo,\n              issue_number: pr.number,\n              labels: [sizeLabel]\n            });\n\n            // 
Check if PR author is a maintainer\n            let authorPerm = 'none';\n            try {\n              const { data } = await github.rest.repos.getCollaboratorPermissionLevel({\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n                username: pr.user.login,\n              });\n              authorPerm = data.permission || 'none';\n            } catch (_) {\n              // User might not have any permissions\n            }\n\n            core.info(`Author permission: ${authorPerm}`);\n            const isMaintainer = ['admin', 'maintain'].includes(authorPerm); // Stricter maintainer definition\n\n            // Check for bypass label (using fresh labels)\n            const hasBypass = currentLabels.includes('bypass:size-limit');\n\n            const MAX_LINES = 1000;\n            if (totalChanges > MAX_LINES) {\n              if (isMaintainer || hasBypass) {\n                core.info(`${isMaintainer ? 'Maintainer' : 'Bypass label'} - allowing large PR with ${totalChanges} lines`);\n              } else {\n                core.setFailed(\n                  `This PR contains ${totalChanges} lines of changes, which exceeds the maximum of ${MAX_LINES} lines. ` +\n                  `Please split this into smaller, focused pull requests.`\n                );\n              }\n            }\n"
  },
  {
    "path": ".github/workflows/check-pr-up-to-date.yaml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: Check PR Up-to-Date\n\non:\n  pull_request:\n    types: [opened, synchronize]\n\npermissions:\n  contents: read\n  pull-requests: write\n\njobs:\n  check-up-to-date:\n    runs-on: ubuntu-latest\n    # Skip for bot PRs\n    if: ${{ !contains(github.actor, '[bot]') }}\n    concurrency:\n      group: check-pr-${{ github.event.pull_request.number }}\n      cancel-in-progress: true\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 2  # Sufficient for rev-list comparison\n\n      - name: Check if PR is up-to-date with main\n        id: check\n        run: |\n          # Fetch the latest main branch\n          git fetch origin main\n\n          # Check how many commits behind main\n          BEHIND=$(git rev-list --count HEAD..origin/main)\n\n          echo \"commits_behind=$BEHIND\" >> $GITHUB_OUTPUT\n\n          if [ \"$BEHIND\" -gt 0 ]; then\n            echo \"::warning::PR is $BEHIND commits behind main\"\n            exit 0  # Don't fail the check, just warn\n          else\n            echo \"PR is up-to-date with main\"\n          fi\n\n      - name: Comment if PR needs update\n        if: ${{ steps.check.outputs.commits_behind != '0' }}\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const behind = ${{ steps.check.outputs.commits_behind }};\n        
    const COMMENT_COOLDOWN_HOURS = 24;\n            const COOLDOWN_MS = COMMENT_COOLDOWN_HOURS * 60 * 60 * 1000;\n\n            // Check for recent similar comments. The API lists comments oldest-first,\n            // so paginate instead of reading only the first page.\n            const comments = await github.paginate(github.rest.issues.listComments, {\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              issue_number: context.payload.pull_request.number,\n              per_page: 100\n            });\n\n            const hasRecentComment = comments.some(c =>\n              c.body?.includes('commits behind `main`') &&\n              c.user?.login === 'github-actions[bot]' &&\n              new Date(c.created_at) > new Date(Date.now() - COOLDOWN_MS)\n            );\n\n            if (!hasRecentComment) {\n              await github.rest.issues.createComment({\n                owner: context.repo.owner,\n                repo: context.repo.repo,\n                issue_number: context.payload.pull_request.number,\n                body: `📊 **PR Status**: ${behind} commits behind \\`main\\`\\n\\nConsider updating your branch for the most accurate CI results:\\n\\n**Option 1**: Use GitHub's \"Update branch\" button (if available)\\n\\n**Option 2**: Update locally:\\n\\`\\`\\`bash\\ngit fetch origin main\\ngit merge origin/main\\ngit push\\n\\`\\`\\`\\n\\n*Note: If you use a different remote name (e.g., upstream), adjust the commands accordingly.*\\n\\nThis ensures your changes are tested against the latest code.`\n              });\n            }\n"
  },
  {
    "path": ".github/workflows/ci.yaml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: CI\n\non:\n  workflow_dispatch:\n  push:\n    branches: [\"main\"]\n  pull_request:\n    branches: [\"main\"]\n  pull_request_target:\n    types: [labeled]\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\npermissions:\n  contents: read\n\njobs:\n  format-check:\n    runs-on: ubuntu-latest\n    if: github.event_name == 'pull_request'\n    permissions:\n      contents: read\n      issues: write\n    steps:\n      - name: Checkout PR branch\n        uses: actions/checkout@v4\n        with:\n          repository: ${{ github.event.pull_request.head.repo.full_name }}\n          ref: ${{ github.event.pull_request.head.ref }}\n          persist-credentials: false\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n\n      - name: Install format tools\n        run: |\n          python -m pip install --upgrade pip\n          pip install -e \".[dev]\"\n\n      - name: Check formatting\n        id: format-check\n        env:\n          GITHUB_TOKEN: \"\"\n        run: |\n          set -euo pipefail\n          pyink --check --diff .\n          isort --check-only --diff .\n\n      - name: Check import structure\n        id: import-check\n        env:\n          GITHUB_TOKEN: \"\"\n        run: |\n          set -euo 
pipefail\n          lint-imports --config pyproject.toml\n\n      - name: Comment on PR if formatting fails\n        if: failure() && steps.format-check.outcome == 'failure'\n        uses: actions/github-script@v7\n        continue-on-error: true\n        with:\n          script: |\n            github.rest.issues.createComment({\n              issue_number: context.payload.pull_request.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: '❌ **Formatting Check Failed**\\n\\nYour PR has formatting issues. Please run the following command locally and push the changes:\\n\\n```bash\\n./autoformat.sh\\n```\\n\\nThis will automatically fix all formatting issues using pyink (Google\\'s Python formatter) and isort.'\n            }).catch(err => {\n              console.log('Comment posting failed:', err.message);\n            });\n\n  test:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.10\", \"3.11\", \"3.12\"]\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          persist-credentials: false\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install tox\n          pip install -e \".[dev,test]\"\n\n      - name: Run unit tests and linting\n        run: |\n          PY_VERSION=$(echo \"${{ matrix.python-version }}\" | tr -d '.')\n          # Format check is handled by separate job for better isolation\n          tox -e py${PY_VERSION},lint-src,lint-tests\n\n  live-api-tests:\n    needs: test\n    runs-on: ubuntu-latest\n    if: |\n      github.event_name == 'push' ||\n      (github.event_name == 'pull_request' &&\n       github.event.pull_request.head.repo.full_name == github.repository)\n\n    steps:\n      - uses: 
actions/checkout@v4\n        with:\n          persist-credentials: false\n\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install tox\n          pip install -e \".[dev,test]\"\n\n      - name: Run live API tests\n        env:\n          GITHUB_TOKEN: \"\"\n        run: |\n          set -euo pipefail\n          if [[ -z '${{ secrets.GEMINI_API_KEY }}' && -z '${{ secrets.OPENAI_API_KEY }}' ]]; then\n            echo \"::notice::Live API tests skipped - API keys not configured\"\n            exit 0\n          fi\n          GEMINI_API_KEY=\"${{ secrets.GEMINI_API_KEY }}\" \\\n          LANGEXTRACT_API_KEY=\"${{ secrets.GEMINI_API_KEY }}\" \\\n          OPENAI_API_KEY=\"${{ secrets.OPENAI_API_KEY }}\" \\\n          tox -e live-api\n\n  plugin-integration-test:\n    needs: test\n    runs-on: ubuntu-latest\n    if: github.event_name == 'pull_request'\n    permissions:\n      contents: read\n      pull-requests: read\n\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          persist-credentials: false\n          fetch-depth: 0\n\n      - name: Detect provider-related changes\n        id: provider-changes\n        uses: tj-actions/changed-files@v46\n        with:\n          files: |\n            langextract/providers/**\n            langextract/factory.py\n            langextract/inference.py\n            tests/provider_plugin_test.py\n            pyproject.toml\n            .github/workflows/ci.yaml\n\n      - name: Skip if no provider changes\n        if: steps.provider-changes.outputs.any_changed == 'false'\n        run: |\n          echo \"No provider-related changes detected – skipping plugin integration test.\"\n          exit 0\n\n      - name: Set up Python 3.11\n        if: steps.provider-changes.outputs.any_changed == 'true'\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n\n      - name: Install dependencies\n        if: steps.provider-changes.outputs.any_changed == 'true'\n        run: |\n          python -m pip install --upgrade pip\n          pip install tox\n\n      - name: Run plugin smoke test\n        if: steps.provider-changes.outputs.any_changed == 'true'\n        run: tox -e plugin-smoke\n\n      - name: Run plugin integration test\n        if: steps.provider-changes.outputs.any_changed == 'true'\n        run: tox -e plugin-integration\n\n  ollama-integration-test:\n    needs: test\n    runs-on: ubuntu-latest\n    if: github.event_name == 'pull_request'\n    permissions:\n      contents: read\n      pull-requests: read\n\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          persist-credentials: false\n          fetch-depth: 0\n\n      - name: Detect file changes\n        id: changes\n        uses: tj-actions/changed-files@v46\n        with:\n          files: |\n            langextract/inference.py\n            examples/ollama/**\n            tests/test_ollama_integration.py\n            .github/workflows/ci.yaml\n\n      - name: Skip if no Ollama changes\n        if: steps.changes.outputs.any_changed == 'false'\n        run: |\n          echo \"No Ollama-related changes detected – skipping job.\"\n          exit 0\n\n      - name: Set up Python 3.11\n        if: steps.changes.outputs.any_changed == 'true'\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n\n      - name: Launch Ollama container\n        if: steps.changes.outputs.any_changed == 'true'\n        run: |\n          docker run -d --name ollama \\\n            -p 127.0.0.1:11434:11434 \\\n            -v ollama:/root/.ollama \\\n            ollama/ollama:0.5.4\n          for i in {1..20}; do\n            curl -fs http://localhost:11434/api/version && break\n            sleep 3\n          done\n\n      - name: Pull gemma2 model\n        if: steps.changes.outputs.any_changed == 'true'\n        run: docker exec ollama ollama pull gemma2:2b || true\n\n      - name: Install tox\n        if: steps.changes.outputs.any_changed == 'true'\n        run: |\n          python -m pip install --upgrade pip\n          pip install tox\n\n      - name: Run Ollama integration tests\n        if: steps.changes.outputs.any_changed == 'true'\n        run: tox -e ollama-integration\n\n  test-fork-pr:\n    runs-on: ubuntu-latest\n    timeout-minutes: 30\n    environment:\n      name: live-keys\n    # Triggered when 
a maintainer adds 'ready-to-merge' label to fork PRs only\n    if: |\n      github.event_name == 'pull_request_target' &&\n      github.event.action == 'labeled' &&\n      github.event.label.name == 'ready-to-merge' &&\n      github.event.pull_request.head.repo.full_name != github.repository\n\n    permissions:\n      contents: read\n      issues: write\n\n    steps:\n      - name: Check if user is maintainer\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const { data: permission } = await github.rest.repos.getCollaboratorPermissionLevel({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              username: context.actor\n            });\n\n            const isMaintainer = ['admin', 'maintain'].includes(permission.permission);\n            if (!isMaintainer) {\n              throw new Error(`User ${context.actor} does not have maintainer permissions.`);\n            }\n\n      - name: Pin commit SHA for security\n        id: sha-pin\n        run: |\n          SHA_TO_TEST=\"${{ github.event.pull_request.head.sha }}\"\n          echo \"SHA_TO_TEST=${SHA_TO_TEST}\" >> $GITHUB_OUTPUT\n          echo \"::notice title=Security::Pinned commit SHA for testing: ${SHA_TO_TEST}\"\n\n      - name: Checkout base repo\n        uses: actions/checkout@v4\n        with:\n          ref: main\n          fetch-depth: 0\n          persist-credentials: false\n\n      - name: Fetch and verify exact PR commit\n        run: |\n          set -euo pipefail\n          EXPECTED_SHA=\"${STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST}\"\n          echo \"Fetching exact commit: $EXPECTED_SHA\"\n\n          # Fetch the specific commit SHA\n          git fetch --no-tags --prune --no-recurse-submodules origin \"$EXPECTED_SHA\" || {\n            echo \"::error::Failed to fetch PR commit $EXPECTED_SHA. 
The commit may have been deleted.\"\n            exit 1\n          }\n\n          git checkout -b pr-to-test \"$EXPECTED_SHA\"\n\n          # Verify checkout\n          ACTUAL_SHA=\"$(git rev-parse HEAD)\"\n          if [ \"$ACTUAL_SHA\" != \"$EXPECTED_SHA\" ]; then\n            echo \"::error::SHA verification failed! Expected $EXPECTED_SHA but got $ACTUAL_SHA\"\n            exit 1\n          fi\n\n          echo \"::notice title=Security::Successfully verified commit SHA: $ACTUAL_SHA\"\n        env:\n          STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST: ${{ steps.sha-pin.outputs.SHA_TO_TEST }}\n\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n\n      - name: Install format tools\n        run: |\n          python -m pip install --upgrade pip\n          # Install formatter tools with pinned versions\n          pip install pyink==24.3.0 isort==5.13.2 lint-imports==0.3.1\n\n      - name: Validate PR formatting\n        run: |\n          set -euo pipefail\n          echo \"Validating code formatting...\"\n          pyink --check --diff . || {\n            echo \"::error::Code formatting (pyink) does not meet project standards. Please run ./autoformat.sh locally and push the changes.\"\n            exit 1\n          }\n          isort --check-only --diff . || {\n            echo \"::error::Import sorting (isort) does not meet project standards. 
Please run ./autoformat.sh locally and push the changes.\"\n            exit 1\n          }\n\n      - name: Checkout main branch\n        uses: actions/checkout@v4\n        with:\n          ref: main\n          fetch-depth: 0\n          persist-credentials: false\n\n      - name: Merge verified PR commit\n        run: |\n          set -euo pipefail\n          git config user.name \"github-actions[bot]\"\n          git config user.email \"github-actions[bot]@users.noreply.github.com\"\n\n          SHA_TO_MERGE=\"${STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST}\"\n          echo \"Merging verified commit: $SHA_TO_MERGE\"\n\n          git fetch --no-tags --prune --no-recurse-submodules origin \"$SHA_TO_MERGE\"\n          git merge --no-ff --no-edit \"$SHA_TO_MERGE\" || {\n            echo \"::error::Failed to merge commit $SHA_TO_MERGE\"\n            exit 1\n          }\n\n          echo \"::notice title=Security::Successfully merged verified commit\"\n        env:\n          STEPS_SHA_PIN_OUTPUTS_SHA_TO_TEST: ${{ steps.sha-pin.outputs.SHA_TO_TEST }}\n\n      - name: Add status comment\n        uses: actions/github-script@v7\n        with:\n          script: |\n            github.rest.issues.createComment({\n              issue_number: context.payload.pull_request.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'Preparing to run live API tests (pending environment approval and API key availability)...'\n            });\n\n      - name: Run live API tests\n        env:\n          GITHUB_TOKEN: \"\"\n        run: |\n          set -euo pipefail\n          if [[ -z '${{ secrets.GEMINI_API_KEY }}' && -z '${{ secrets.OPENAI_API_KEY }}' ]]; then\n            echo \"::notice::Live API tests skipped - API keys not configured\"\n            exit 0\n          fi\n          python -m pip install --upgrade pip\n          pip install tox\n          pip install -e \".[dev,test]\"\n          GEMINI_API_KEY=\"${{ 
secrets.GEMINI_API_KEY }}\" \\\n          LANGEXTRACT_API_KEY=\"${{ secrets.GEMINI_API_KEY }}\" \\\n          OPENAI_API_KEY=\"${{ secrets.OPENAI_API_KEY }}\" \\\n          tox -e live-api\n\n      - name: Report success\n        if: success()\n        uses: actions/github-script@v7\n        with:\n          script: |\n            await github.rest.issues.createComment({\n              issue_number: context.payload.pull_request.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: '✅ Live API test step completed successfully (it exits early if API keys are not configured). See the workflow logs for details.'\n            });\n\n      - name: Report failure\n        if: failure()\n        uses: actions/github-script@v7\n        with:\n          script: |\n            await github.rest.issues.createComment({\n              issue_number: context.payload.pull_request.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: '❌ Live API tests failed. Please check the workflow logs for details.'\n            });\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: Publish to PyPI\n\non:\n  release:\n    types: [published]\n\npermissions:\n  contents: read\n  id-token: write\n\njobs:\n  pypi-publish:\n    name: Publish to PyPI\n    runs-on: ubuntu-latest\n    environment: pypi\n    permissions:\n      id-token: write\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.11'\n\n      - name: Install build dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install build\n\n      - name: Build package\n        run: python -m build\n\n      - name: Verify build artifacts\n        run: |\n          ls -la dist/\n          pip install twine\n          twine check dist/*\n\n      - name: Publish to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n"
  },
  {
    "path": ".github/workflows/revalidate-pr.yml",
    "content": "name: Revalidate PR\n\non:\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: 'PR number to validate'\n        required: true\n        type: string\n\npermissions:\n  contents: read\n  pull-requests: write\n  issues: write\n  checks: write\n  statuses: write\n\njobs:\n  revalidate:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Get PR data\n        id: pr_data\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const { data: pr } = await github.rest.pulls.get({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              pull_number: ${{ inputs.pr_number }}\n            });\n\n            core.info(`Validating PR #${pr.number}: ${pr.title}`);\n            core.info(`Author: ${pr.user.login}`);\n            core.info(`Changes: +${pr.additions} -${pr.deletions}`);\n\n            // Store head SHA for creating status\n            core.setOutput('head_sha', pr.head.sha);\n\n            return pr;\n\n      - name: Create pending status\n        uses: actions/github-script@v7\n        with:\n          script: |\n            await github.rest.repos.createCommitStatus({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              sha: '${{ steps.pr_data.outputs.head_sha }}',\n              state: 'pending',\n              context: 'Manual Validation',\n              description: 'Running validation checks...'\n            });\n\n      - name: Validate PR\n        id: validate\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const pr = ${{ steps.pr_data.outputs.result }};\n            const errors = [];\n            let passed = true;\n\n            // Check size\n            const totalChanges = pr.additions + pr.deletions;\n            const MAX_LINES = 1000;\n            if (totalChanges > MAX_LINES) {\n              errors.push(`PR size (${totalChanges} lines) exceeds 
${MAX_LINES} line limit`);\n              passed = false;\n            }\n\n            // Check template\n            const body = pr.body || '';\n            const requiredSections = [\"# Description\", \"Fixes #\", \"# How Has This Been Tested?\", \"# Checklist\"];\n            const missingSections = requiredSections.filter(section => !body.includes(section));\n\n            if (missingSections.length > 0) {\n              errors.push(`Missing PR template sections: ${missingSections.join(', ')}`);\n              passed = false;\n            }\n\n            if (body.match(/Replace this with|Choose one:|Fixes #\\[issue number\\]/i)) {\n              errors.push('PR template contains unmodified placeholders');\n              passed = false;\n            }\n\n            // Check linked issue\n            const issueMatch = body.match(/(?:Fixes|Closes|Resolves)\\s+#(\\d+)/i);\n            if (!issueMatch) {\n              errors.push('No linked issue found');\n              passed = false;\n            }\n\n            // Store results\n            core.setOutput('passed', passed);\n            core.setOutput('errors', errors.join('; '));\n            core.setOutput('totalChanges', totalChanges);\n            core.setOutput('hasTemplate', missingSections.length === 0);\n            core.setOutput('hasIssue', !!issueMatch);\n\n            if (!passed) {\n              core.setFailed(errors.join('; '));\n            }\n\n      - name: Update commit status\n        if: always()\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const passed = ${{ steps.validate.outputs.passed }};\n            const errors = '${{ steps.validate.outputs.errors }}';\n\n            await github.rest.repos.createCommitStatus({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              sha: '${{ steps.pr_data.outputs.head_sha }}',\n              state: passed ? 
'success' : 'failure',\n              context: 'Manual Validation',\n              description: passed ? 'All validation checks passed' : errors.substring(0, 140),\n              target_url: `https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`\n            });\n\n      - name: Add validation comment\n        if: always()\n        uses: actions/github-script@v7\n        with:\n          script: |\n            const pr = ${{ steps.pr_data.outputs.result }};\n            const passed = ${{ steps.validate.outputs.passed }};\n            const totalChanges = ${{ steps.validate.outputs.totalChanges }};\n            const hasTemplate = ${{ steps.validate.outputs.hasTemplate }};\n            const hasIssue = ${{ steps.validate.outputs.hasIssue }};\n            const errors = '${{ steps.validate.outputs.errors }}'.split('; ').filter(e => e);\n\n            let body = `### Manual Validation Results\\n\\n`;\n            body += `**Status**: ${passed ? '✅ Passed' : '❌ Failed'}\\n\\n`;\n            body += `| Check | Status | Details |\\n`;\n            body += `|-------|--------|----------|\\n`;\n            body += `| PR Size | ${totalChanges <= 1000 ? '✅' : '❌'} | ${totalChanges} lines ${totalChanges > 1000 ? '(exceeds 1000 limit)' : ''} |\\n`;\n            body += `| Template | ${hasTemplate ? '✅' : '❌'} | ${hasTemplate ? 'Complete' : 'Missing required sections'} |\\n`;\n            body += `| Linked Issue | ${hasIssue ? '✅' : '❌'} | ${hasIssue ? 
'Found' : 'Missing Fixes/Closes #XXX'} |\\n`;\n\n            if (errors.length > 0) {\n              body += `\\n**Errors:**\\n`;\n              errors.forEach(error => {\n                body += `- ❌ ${error}\\n`;\n              });\n            }\n\n            body += `\\n[View workflow run](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId})`;\n\n            await github.rest.issues.createComment({\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              issue_number: pr.number,\n              body: body\n            });\n"
  },
  {
    "path": ".github/workflows/validate-community-providers.yaml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: Validate Community Providers\n\non:\n  pull_request:\n    paths:\n      - 'COMMUNITY_PROVIDERS.md'\n      - 'scripts/validate_community_providers.py'\n\npermissions:\n  contents: read\n  pull-requests: read\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  validate:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.11'\n\n      - name: Validate table format\n        run: |\n          python scripts/validate_community_providers.py COMMUNITY_PROVIDERS.md\n"
  },
  {
    "path": ".github/workflows/validate_pr_template.yaml",
    "content": "name: Validate PR template\n\non:\n  pull_request_target:\n    types: [opened, edited, synchronize, reopened]\n  workflow_dispatch:\n\npermissions:\n  contents: read\n  pull-requests: read\n\njobs:\n  check:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Check PR author permissions\n        id: check\n        if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false\n        uses: actions/github-script@v7\n        with:\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          script: |\n            const pr = context.payload.pull_request;\n            const {owner, repo} = context.repo;\n            const actor = pr.user.login;\n            const authorType = pr.user.type;\n\n            // Check if PR author is a bot (e.g., Dependabot)\n            if (authorType === 'Bot') {\n              core.setOutput('skip_validation', 'true');\n              console.log(`Skipping validation for bot-authored PR: ${actor}`);\n              return;\n            }\n\n            // Check if this is a community provider PR (only modifies COMMUNITY_PROVIDERS.md)\n            const { data: files } = await github.rest.pulls.listFiles({\n              owner, repo,\n              pull_number: pr.number\n            });\n\n            const isCommunityProviderPR = files.length === 1 &&\n                                          files[0].filename === 'COMMUNITY_PROVIDERS.md';\n\n            if (isCommunityProviderPR) {\n              core.setOutput('is_community_provider', 'true');\n              console.log('Community provider PR detected - relaxed validation will apply');\n            } else {\n              core.setOutput('is_community_provider', 'false');\n            }\n\n            // Get permission level\n            try {\n              const { data } = await github.rest.repos.getCollaboratorPermissionLevel({\n                owner, repo, username: actor\n              });\n\n              const permission = 
data.permission; // admin|maintain|write|triage|read|none\n              console.log(`Actor ${actor} has permission level: ${permission}`);\n\n              // Check if user has write+ permissions\n              if (['admin', 'maintain', 'write'].includes(permission)) {\n                core.setOutput('skip_validation', 'true');\n                console.log(`Skipping validation for maintainer: ${actor} (${permission})`);\n              } else {\n                core.setOutput('skip_validation', 'false');\n                console.log(`Validation required for: ${actor} (${permission})`);\n              }\n            } catch (e) {\n              // If we can't determine permissions, require validation\n              core.setOutput('skip_validation', 'false');\n              core.warning(`Permission lookup failed: ${e.message}`);\n            }\n\n      - name: Validate PR template\n        if: |\n          github.event_name == 'pull_request_target' &&\n          github.event.pull_request.draft == false &&\n          steps.check.outputs.skip_validation != 'true'\n        env:\n          PR_BODY: ${{ github.event.pull_request.body }}\n          IS_COMMUNITY_PROVIDER: ${{ steps.check.outputs.is_community_provider }}\n        run: |\n          printf '%s\\n' \"$PR_BODY\" | tr -d '\\r' > body.txt\n\n          # Required sections from the template\n          required=( \"# Description\" \"# How Has This Been Tested?\" \"# Checklist\" )\n          err=0\n\n          # Check for required sections\n          for h in \"${required[@]}\"; do\n            grep -Fq \"$h\" body.txt || { echo \"::error::$h missing\"; err=1; }\n          done\n\n          # Check for issue reference - relaxed for community provider PRs\n          if [ \"$IS_COMMUNITY_PROVIDER\" = \"true\" ]; then\n            # For community provider PRs, accept either \"Fixes #\" or \"Related to #\" (case-insensitive)\n            if ! 
grep -Eiq '(Fixes #[0-9]+|Related to #[0-9]+)' body.txt; then\n              echo \"::error::Issue reference missing (need 'Fixes #NNN' or 'Related to #NNN')\"\n              err=1\n            fi\n          else\n            # For other PRs, require \"Fixes #\" with a number\n            if ! grep -Eq 'Fixes #[0-9]+' body.txt; then\n              echo \"::error::Missing 'Fixes #NNN' reference\"\n              err=1\n            fi\n          fi\n\n          # Check for placeholder text that should be replaced\n          grep -Eiq 'Replace this with|Choose one:' body.txt && {\n            echo \"::error::Template placeholders still present\"; err=1;\n          }\n\n          # Also check for the unmodified issue number placeholder\n          grep -Fq 'Fixes #[issue number]' body.txt && {\n            echo \"::error::Issue number placeholder not updated\"; err=1;\n          }\n\n          exit $err\n\n      - name: Log skip reason\n        if: |\n          github.event_name == 'pull_request_target' &&\n          (github.event.pull_request.draft == true ||\n           steps.check.outputs.skip_validation == 'true')\n        run: |\n          echo \"Skipping PR template validation. Draft: ${{ github.event.pull_request.draft }}; skip_validation: ${{ steps.check.outputs.skip_validation || 'N/A' }}\"\n"
  },
  {
    "path": ".github/workflows/zenodo-publish.yml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nname: Publish to Zenodo\non:\n  release:\n    types: [published]\n\nconcurrency:\n  group: zenodo-${{ github.ref }}\n  cancel-in-progress: false\n\njobs:\n  zenodo:\n    # Only run on releases from the main repository, not forks\n    # Skip pre-releases to avoid creating DOIs for test releases\n    if: ${{ !github.event.release.prerelease && github.repository == 'google/langextract' }}\n    runs-on: ubuntu-latest\n    timeout-minutes: 15\n    permissions:\n      contents: read\n    env:\n      ZENODO_TOKEN: ${{ secrets.ZENODO_TOKEN }}\n      ZENODO_RECORD_ID: ${{ secrets.ZENODO_RECORD_ID }}\n      RELEASE_TAG: ${{ github.ref_name }}\n      GITHUB_REPOSITORY: ${{ github.repository }}\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: '3.11'\n\n      - name: Build distributions\n        run: |\n          python -m pip install --upgrade pip build\n          python -m build\n\n      - name: Install dependencies\n        run: python -m pip install requests\n\n      - name: Publish new Zenodo version\n        run: python .github/scripts/zenodo_publish.py\n"
  },
  {
    "path": ".gitignore",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Byte-compiled / Cache files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# Distribution / Packaging\nbuild/\ndist/\n*.egg-info/\n.eggs/\neggs/\n\n# Virtual Environments\n.env\n.venv\nenv/\nvenv/\nENV/\n*_env/\n\n# Test & Coverage Reports\n.pytest_cache/\n.tox/\nhtmlcov/\n.coverage\n.coverage.*\n\n# Generated Output & Data\n# LangExtract outputs are defaulted to test_output/\n/test_output/\n\n# Sphinx documentation build output\ndocs/_build/\n\n# IDE / Editor specific\n.idea/\n.vscode/\n*.swp\n*.swo\n*~\n.*.swp\n.*.swo\n\n# OS-specific\n.DS_Store\nThumbs.db\nehthumbs.db\nDesktop.ini\n$RECYCLE.BIN/\n*.cab\n*.msi\n*.msm\n*.msp\n*.lnk\n\n# Development tools & environments\n.python-version\n.pytype/\n.mypy_cache/\n.dmypy.json\ndmypy.json\n.pyre/\n.ruff_cache/\n*.sage.py\n.hypothesis/\n.scrapy\n\n# Jupyter Notebooks\n.ipynb_checkpoints\n*/.ipynb_checkpoints/*\nprofile_default/\nipython_config.py\n\n# Logs and databases\n*.log\n*.sql\n*.sqlite\n*.sqlite3\ndb.sqlite3\ndb.sqlite3-journal\nlogs/\n*.pid\n\n# Security and secrets\n*.key\n*.pem\n*.crt\n*.csr\n.env.local\n.env.production\n.env.*.local\nsecrets/\ncredentials/\n\n# AI tooling\nCLAUDE.md\n.claude/settings.local.json\n.aider.chat.history.*\n.aider.input.history\n.gemini/\nGEMINI.md\n\n# Package 
managers\npip-log.txt\npip-delete-this-directory.txt\nnode_modules/\nnpm-debug.log*\nyarn-debug.log*\nyarn-error.log*\n.pnpm-debug.log*\npackage-lock.json\nyarn.lock\npnpm-lock.yaml\n\n# Local development\nlocal_settings.py\ninstance/\n.webassets-cache\n.sass-cache/\n*.css.map\n*.js.map\n.dev/\n\n# Temporary files\ntmp/\ntemp/\ncache/\n*.tmp\n*.bak\n*.backup\n*.orig\n.~lock.*#\n\n# Archives\n*.tar\n*.tar.gz\n*.zip\n*.rar\n*.7z\n*.dmg\n*.iso\n*.jar\n\n# Media files\n*.mp4\n*.avi\n*.mov\n*.wmv\n*.flv\n*.mp3\n*.wav\n*.ogg\n\n# Benchmark results and local environment\nlangextract_env/\nbenchmarks/benchmark_results\n\n# Benchmark results in root\nbenchmark_results/**/*.json\nbenchmark_results/**/*.jsonl\nbenchmark_results/**/*.html\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Pre-commit hooks for LangExtract\n# Install with: pre-commit install\n# Run manually: pre-commit run --all-files\n\nrepos:\n  - repo: https://github.com/PyCQA/isort\n    rev: 6.0.0\n    hooks:\n      - id: isort\n        name: isort (import sorting)\n        # Configuration is in pyproject.toml\n\n  - repo: https://github.com/google/pyink\n    rev: 24.3.0\n    hooks:\n      - id: pyink\n        name: pyink (Google's Black fork)\n        args: [\"--config\", \"pyproject.toml\"]\n\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n      - id: end-of-file-fixer\n        exclude: \\.gif$|\\.svg$\n      - id: trailing-whitespace\n      - id: check-yaml\n      - id: check-added-large-files\n        args: ['--maxkb=1000']\n      - id: check-merge-conflict\n      - id: check-case-conflict\n      - id: mixed-line-ending\n        args: ['--fix=lf']\n"
  },
  {
    "path": ".pylintrc",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n[MASTER]\n\n\n# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the\n# number of processors available to use.\njobs=0\n\n# Pickle collected data for later comparisons.\npersistent=yes\n\n# List of plugins (as comma-separated values of Python module names) to load,\n# usually to register additional checkers.\n# Note: These plugins require Pylint >= 3.0\nload-plugins=\n    pylint.extensions.docparams,\n    pylint.extensions.typing\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Enable the message, report, category or checker with the given id(s). You can\n# either give multiple identifiers separated by commas (,) or put this option\n# multiple times.\nenable=\n    useless-suppression\n\n# Disable the message, report, category or checker with the given id(s). 
You\n# can either give multiple identifiers separated by commas (,) or put this option\n# multiple times (only on the command line, not in the configuration file where\n# it should appear only once).\ndisable=\n    abstract-method,          # Protocol/ABC classes often have abstract methods\n    too-few-public-methods,   # Valid for data classes with minimal interface\n    fixme,                    # TODO/FIXME comments are useful for tracking work\n    # --- Code style and formatting ---\n    line-too-long,            # Handled by pyink formatter\n    bad-indentation,          # Pyink uses 2-space indentation\n    # --- Design complexity ---\n    too-many-positional-arguments,\n    too-many-locals,\n    too-many-arguments,\n    too-many-branches,\n    too-many-statements,\n    too-many-nested-blocks,\n    # --- Style preferences ---\n    no-else-return,\n    no-else-raise,\n    # --- Documentation ---\n    missing-function-docstring,\n    missing-class-docstring,\n    missing-raises-doc,\n    # --- Gradual improvements ---\n    deprecated-typing-alias,  # For typing.Type etc.\n    unspecified-encoding\n\n\n[REPORTS]\n\n# Set the output format. Available formats are text, parseable, colorized, msvs\n# (Visual Studio) and html.\noutput-format=text\n\n# Tells whether to display a full report or only the messages.\nreports=no\n\n# Activate the evaluation score.\nscore=no\n\n\n[REFACTORING]\n\n# Maximum number of nested blocks for function / method body.\nmax-nested-blocks=5\n\n# Complete names of functions that never return. 
When checking for\n# inconsistent-return-statements, if a never-returning function is called, then\n# it will be considered as an explicit return statement and no message will be\n# printed.\nnever-returning-functions=sys.exit\n\n\n[BASIC]\n\n# Naming style matching correct argument names.\nargument-naming-style=snake_case\n\n# Naming style matching correct attribute names.\nattr-naming-style=snake_case\n\n# Bad variable names which should always be refused, separated by a comma.\nbad-names=foo,bar,baz,toto,tutu,tata\n\n# Naming style matching correct class attribute names.\nclass-attribute-naming-style=any\n\n# Naming style matching correct class names.\nclass-naming-style=PascalCase\n\n# Naming style matching correct constant names.\nconst-naming-style=UPPER_CASE\n\n# Minimum line length for functions/classes that require docstrings; shorter\n# ones are exempt.\ndocstring-min-length=-1\n\n# Naming style matching correct function names.\nfunction-naming-style=snake_case\n\n# Good variable names which should always be accepted, separated by a comma.\ngood-names=i,j,k,ex,Run,_,id,ok\n\n# Good variable name regexes, separated by a comma. If names match any regex,\n# they will always be accepted.\ngood-names-rgxs=^T[A-Z][a-zA-Z]*$\n\n# Include a hint for the correct naming format with invalid-name.\ninclude-naming-hint=no\n\n# Naming style matching correct inline iteration names.\ninlinevar-naming-style=any\n\n# Naming style matching correct method names.\nmethod-naming-style=snake_case\n\n# Naming style matching correct module names.\nmodule-naming-style=snake_case\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=^_\n\n# List of decorators that produce properties, such as abc.abstractproperty. 
Add\n# to this list to register other decorators that produce valid properties.\n# These decorators are taken into consideration only for invalid-name.\nproperty-classes=abc.abstractproperty\n\n# Naming style matching correct variable names.\nvariable-naming-style=snake_case\n\n\n[FORMAT]\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=LF\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=^\\s*(# )?<?https?://\\S+>?$\n\n# Number of spaces of indent required inside a hanging or continued line.\nindent-after-paren=2\n\n# String used as indentation unit. This is usually \"    \" (4 spaces) or \"\\t\" (1\n# tab).\nindent-string=\"  \"\n\n# Maximum number of characters on a single line.\nmax-line-length=80\n\n# Maximum number of lines in a module.\nmax-module-lines=2000\n\n# Allow the body of a class to be on the same line as the declaration if the body\n# contains a single statement.\nsingle-line-class-stmt=no\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=no\n\n\n[LOGGING]\n\n# The type of string formatting that logging methods do. `old` means using %\n# formatting, `new` is for `{}` formatting.\nlogging-format-style=old\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format.\nlogging-modules=logging\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take into consideration, separated by a comma.\nnotes=FIXME,XXX,TODO\n\n\n[SIMILARITIES]\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n# Minimum number of lines in a similarity.\nmin-similarity-lines=6\n\n\n[SPELLING]\n\n# Limits the number of emitted suggestions for spelling mistakes.\nmax-spelling-suggestions=4\n\n# Spelling dictionary name. 
Available dictionaries: none. To make it work,\n# install the python-enchant package.\nspelling-dict=\n\n# List of comma-separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains a private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words in the private dictionary given by the\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager\n\n# List of members which are set dynamically and missed by the pylint inference\n# system, and so shouldn't trigger E1101 when accessed. Python regular\n# expressions are accepted.\ngenerated-members=\n\n# Tells whether missing members accessed in mixin classes should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case-insensitive).\nignore-mixin-members=yes\n\n# Tells whether to warn about missing members when the owner of the attribute\n# is inferred to be None.\nignore-none=yes\n\n# This flag controls whether pylint should warn about no-member and similar\n# checks whenever an opaque object is returned when inferring. The inference\n# can return multiple potential results while evaluating a Python object, but\n# some branches might not be evaluated, which results in partial inference. In\n# that case, it might be useful to still emit no-member and other checks for\n# the rest of the inferred objects.\nignore-on-opaque-inference=yes\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). 
This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local,dataclasses.InitVar,typing.Any\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis). It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=dotenv,absl,more_itertools,pandas,requests,pydantic,yaml,IPython.display,\n                tqdm,numpy,google,langfun,typing_extensions\n\n# Show a hint with possible names when a member name was not found. Finding\n# the hint is based on edit distance.\nmissing-member-hint=yes\n\n# The minimum edit distance a name should have in order to be considered a\n# similar match for a missing member name.\nmissing-member-hint-distance=1\n\n# The total number of similar names that should be taken into consideration when\n# showing a hint for a missing member.\nmissing-member-max-choices=1\n\n# List of decorators that change the signature of a decorated function.\nsignature-mutators=\n\n\n[VARIABLES]\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid defining new builtins when possible.\nadditional-builtins=\n\n# Tells whether unused global variables should be treated as a violation.\nallow-global-unused-variables=yes\n\n# List of strings which can identify a callback function by name. A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,_cb\n\n# A regular expression matching the name of dummy variables (i.e. expected to\n# not be used).\ndummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_\n\n# Argument names that match this expression will be ignored. 
Defaults to names\n# with a leading underscore.\nignored-argument-names=_.*|^ignored_|^unused_\n\n# Tells whether we should check for unused imports in __init__ files.\ninit-import=no\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp,\n                      __post_init__\n\n# List of member names that should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=cls\n\n\n[DESIGN]\n\n# Maximum number of arguments for function / method.\nmax-args=7\n\n# Maximum number of attributes for a class (see R0902).\nmax-attributes=10\n\n# Maximum number of boolean expressions in an if statement.\nmax-bool-expr=5\n\n# Maximum number of branches for function / method body.\nmax-branches=12\n\n# Maximum number of locals for function / method body.\nmax-locals=15\n\n# Maximum number of parents for a class (see R0901).\nmax-parents=7\n\n# Maximum number of public methods for a class (see R0904).\nmax-public-methods=20\n\n# Maximum number of returns / yields for function / method body.\nmax-returns=6\n\n# Maximum number of statements in function / method body.\nmax-statements=50\n\n# Minimum number of public methods for a class (see R0903).\nmin-public-methods=0\n\n\n[IMPORTS]\n\n# Allow wildcard imports from modules that define __all__.\nallow-wildcard-with-all=yes\n\n# Analyse import fallback blocks. 
This can be used to support both Python 2- and\n# 3-compatible code, which means that the block might have code that exists\n# only in one interpreter or the other, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n# Deprecated modules which should not be used, separated by a comma.\ndeprecated-modules=optparse,tkinter.tix\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled).\next-import-graph=\n\n# Create a graph of all (i.e. internal and external) dependencies in the\n# given file (report RP0402 must not be disabled).\nimport-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled).\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third-party library.\nknown-third-party=enchant,numpy,pandas,torch,langfun,pyglove\n\n# Couples of modules and preferred modules, separated by a comma.\npreferred-modules=\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. Defaults to\n# \"BaseException, Exception\".\novergeneral-exceptions=BaseException,\n                       Exception\n"
  },
  {
    "path": "CITATION.cff",
    "content": "# SPDX-FileCopyrightText: 2025 Google LLC\n# SPDX-License-Identifier: Apache-2.0\n#\n# This file contains citation metadata for LangExtract.\n# For more information visit: https://citation-file-format.github.io/\n\ncff-version: 1.2.0\ntitle: \"LangExtract\"\nmessage: \"If you use this software, please cite it as below.\"\ntype: software\nauthors:\n  - given-names: Akshay\n    family-names: Goel\n    email: goelak@google.com\n    affiliation: Google LLC\nrepository-code: \"https://github.com/google/langextract\"\nurl: \"https://github.com/google/langextract\"\nrepository: \"https://github.com/google/langextract\"\nabstract: \"LangExtract: LLM-powered structured information extraction from text with source grounding\"\nkeywords:\n  - language-models\n  - structured-data-extraction\n  - nlp\n  - machine-learning\n  - python\nlicense: Apache-2.0\nversion: 1.1.1\ndate-released: 2025-11-27\n\ndoi: \"10.5281/zenodo.17015089\"\nidentifiers:\n  - type: doi\n    value: \"10.5281/zenodo.17015089\"\n    description: \"Concept DOI for LangExtract\"\n"
  },
  {
    "path": "COMMUNITY_PROVIDERS.md",
    "content": "# Community Provider Plugins\n\nCommunity-developed provider plugins that extend LangExtract with additional model backends.\n\n**Supporting the Community:** Star plugin repositories you find useful and add 👍 reactions to their tracking issues to support maintainers' efforts.\n\n**⚠️ Important:** These are community-maintained packages. Please review the [safety guidelines](#safety-disclaimer) before use.\n\n## Plugin Registry\n\n| Plugin Name | PyPI Package | Maintainer | GitHub Repo | Description | Issue Link |\n|-------------|--------------|------------|-------------|-------------|------------|\n| AWS Bedrock | `langextract-bedrock` | [@andyxhadji](https://github.com/andyxhadji) | [andyxhadji/langextract-bedrock](https://github.com/andyxhadji/langextract-bedrock) | AWS Bedrock provider for LangExtract, supports all models & inference profiles | [#148](https://github.com/google/langextract/issues/148) |\n| LiteLLM | `langextract-litellm` | [@JustStas](https://github.com/JustStas) | [JustStas/langextract-litellm](https://github.com/JustStas/langextract-litellm) | LiteLLM provider for LangExtract, supports all models covered in LiteLLM, including OpenAI, Azure, Anthropic, etc. See [LiteLLM's supported models](https://docs.litellm.ai/docs/providers) | [#187](https://github.com/google/langextract/issues/187) |\n| Llama.cpp | `langextract-llamacpp` | [@fgarnadi](https://github.com/fgarnadi) | [fgarnadi/langextract-llamacpp](https://github.com/fgarnadi/langextract-llamacpp) | Llama.cpp provider for LangExtract, supports GGUF models from HuggingFace and local files | [#199](https://github.com/google/langextract/issues/199) |\n| Outlines | `langextract-outlines` | [@RobinPicard](https://github.com/RobinPicard) | [dottxt-ai/langextract-outlines](https://github.com/dottxt-ai/langextract-outlines) | Outlines provider for LangExtract, supports structured generation for various local and API-based models | 
[#101](https://github.com/google/langextract/issues/101) |\n| vLLM | `langextract-vllm` | [@wuli666](https://github.com/wuli666) | [wuli666/langextract-vllm](https://github.com/wuli666/langextract-vllm) | vLLM provider for LangExtract, supports local and distributed model serving | [#236](https://github.com/google/langextract/issues/236) |\n<!-- ADD NEW PLUGINS ABOVE THIS LINE -->\n\n## How to Add Your Plugin (PR Checklist)\n\nCopy this row template, replace placeholders, and insert **above** the marker line:\n\n```markdown\n| Your Plugin | `langextract-provider-yourname` | [@yourhandle](https://github.com/yourhandle) | [yourorg/yourrepo](https://github.com/yourorg/yourrepo) | Brief description (min 10 chars) | [#456](https://github.com/google/langextract/issues/456) |\n```\n\n**Before submitting your PR:**\n- [ ] PyPI package name starts with `langextract-` (recommended: `langextract-provider-<name>`)\n- [ ] PyPI package is published (or will be soon) and listed in backticks\n- [ ] Maintainer(s) listed as GitHub profile links (comma-separated if multiple)\n- [ ] Repository link points to public GitHub repo\n- [ ] Description clearly explains what your provider does\n- [ ] Issue Link points to a tracking issue in the LangExtract repository for integration and usage feedback (plugin-specific features and discussions can optionally happen in the plugin's repository)\n- [ ] Entries are sorted alphabetically by Plugin Name\n\n## Documentation\n\nFor detailed plugin development instructions, see the [Custom Provider Plugin Example](examples/custom_provider_plugin/README.md).\n\n## Safety Disclaimer\n\nCommunity plugins are independently developed and maintained. 
While we encourage community contributions, the LangExtract team cannot guarantee the safety, security, or functionality of third-party packages.\n\n**Before installing any plugin, we recommend:**\n\n- **Review the code** - Examine the source code and dependencies on GitHub\n- **Check community feedback** - Read issues and discussions for user experiences\n- **Verify the maintainer** - Look for active maintenance and responsive support\n- **Test safely** - Try plugins in isolated environments before production use\n- **Assess security needs** - Consider your specific security requirements\n\nCommunity plugins are used at your own discretion. When in doubt, reach out to the community through the plugin's issue tracker or the main LangExtract discussions.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# How to Contribute\n\nWe would love to accept your patches and contributions to this project.\n\n## Before you begin\n\n### Sign our Contributor License Agreement\n\nContributions to this project must be accompanied by a\n[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).\nYou (or your employer) retain the copyright to your contribution; this simply\ngives us permission to use and redistribute your contributions as part of the\nproject.\n\nIf you or your current employer have already signed the Google CLA (even if it\nwas for a different project), you probably don't need to do it again.\n\nVisit <https://cla.developers.google.com/> to see your current agreements or to\nsign a new one.\n\n### Review our Community Guidelines\n\nThis project follows HAI-DEF's\n[Community guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines)\n\n## Reporting Issues\n\nIf you encounter a bug or have a feature request, please open an issue on GitHub.\nWe have templates to help guide you:\n\n- **[Bug Report](.github/ISSUE_TEMPLATE/1-bug.md)**: For reporting bugs or unexpected behavior\n- **[Feature Request](.github/ISSUE_TEMPLATE/2-feature-request.md)**: For suggesting new features or improvements\n\nWhen creating an issue, GitHub will prompt you to choose the appropriate template.\nPlease provide as much detail as possible to help us understand and address your concern.\n\n## Contribution Process\n\n### 1. Development Setup\n\nTo get started, clone the repository and install the necessary dependencies for development and testing. Detailed instructions can be found in the [Installation from Source](https://github.com/google/langextract#from-source) section of the `README.md`.\n\n**Windows Users**: The formatting scripts use bash. Please use one of:\n- Git Bash (comes with Git for Windows)\n- WSL (Windows Subsystem for Linux)\n- PowerShell with bash-compatible commands\n\n### 2. 
Code Style and Formatting\n\nThis project uses automated tools to maintain a consistent code style. Before submitting a pull request, please format your code:\n\n```bash\n# Run the auto-formatter\n./autoformat.sh\n```\n\nThis script uses:\n- `isort` to organize imports with Google style (single-line imports)\n- `pyink` (Google's fork of Black) to format code according to Google's Python Style Guide\n\nYou can also run the formatters manually:\n```bash\nisort langextract tests\npyink langextract tests --config pyproject.toml\n```\n\nNote: The formatters target only `langextract` and `tests` directories by default to avoid\nformatting virtual environments or other non-source directories.\n\n### 3. Pre-commit Hooks (Recommended)\n\nFor automatic formatting checks before each commit:\n\n```bash\n# Install pre-commit\npip install pre-commit\n\n# Install the git hooks\npre-commit install\n\n# Run manually on all files\npre-commit run --all-files\n```\n\n### 4. Linting and Testing\n\nAll contributions must pass linting checks and unit tests. Please run these locally before submitting your changes:\n\n```bash\n# Run linting with Pylint 3.x\npylint --rcfile=.pylintrc langextract tests\n\n# Run tests\npytest tests\n```\n\n**Note on Pylint Configuration**: We use a modern, minimal configuration that:\n- Only disables truly noisy checks (not entire categories)\n- Keeps critical error detection enabled\n- Uses plugins for enhanced docstring and type checking\n- Aligns with our pyink formatter (80-char lines, 2-space indents)\n\nFor full testing across Python versions:\n```bash\ntox  # runs pylint + pytest on Python 3.10 and 3.11\n```\n\n### 5. Adding Custom Model Providers\n\nIf you want to add support for a new LLM provider, please refer to the [Provider System Documentation](langextract/providers/README.md). The recommended approach is to create an external plugin package rather than modifying the core library. 
This allows for:\n- Independent versioning and releases\n- Faster iteration without core review cycles\n- Custom dependencies without affecting core users\n\n### 6. Submit Your Pull Request\n\nAll submissions, including submissions by project members, require review. We\nuse [GitHub pull requests](https://docs.github.com/articles/about-pull-requests)\nfor this purpose.\n\nWhen you create a pull request, GitHub will automatically populate it with our\n[pull request template](.github/PULL_REQUEST_TEMPLATE/pull_request_template.md).\nPlease fill out all sections of the template to help reviewers understand your changes.\n\n#### Pull Request Guidelines\n\n- **Keep PRs focused and small**: Each PR should address a single issue and contain one cohesive change. PRs are automatically labeled by size to help reviewers:\n  - **size/XS**: < 50 lines — Small fixes and documentation updates\n  - **size/S**: 50-150 lines — Typical features or bug fixes\n  - **size/M**: 150-600 lines — Larger features that remain well-scoped\n  - **size/L**: 600-1000 lines — Consider splitting into smaller PRs if possible\n  - **size/XL**: > 1000 lines — Requires strong justification and may need special review\n- **Reference related issues**: All PRs must include \"Fixes #123\" or \"Closes #123\" in the description. The linked issue should have at least 5 👍 reactions from the community and include discussion that demonstrates the importance of and need for the change.\n- **No infrastructure changes**: Contributors cannot modify infrastructure files, build configuration, or core documentation. These files are protected and can only be changed by maintainers. Use `./autoformat.sh` to format code without affecting infrastructure files. In special circumstances, build configuration updates may be considered if they include discussion and evidence of robust testing, ideally with community support.\n- **Single-change commits**: A PR should typically comprise a single git commit. 
Squash multiple commits before submitting.\n- **Clear description**: Explain what your change does and why it's needed.\n- **Ensure all tests pass**: Check that both formatting and tests are green before requesting review.\n- **Respond to feedback promptly**: Address reviewer comments in a timely manner.\n\nIf your change is large or complex, consider:\n- Opening an issue first to discuss the approach\n- Breaking it into multiple smaller PRs\n- Clearly explaining in the PR description why a larger change is necessary\n\nFor more details, read HAI-DEF's\n[Contributing guidelines](https://developers.google.com/health-ai-developer-foundations/community-guidelines#contributing)\n"
  },
  {
    "path": "Dockerfile",
    "content": "# Production Dockerfile for LangExtract\nFROM python:3.10-slim\n\n# Set working directory\nWORKDIR /app\n\n# Install LangExtract from PyPI\nRUN pip install --no-cache-dir langextract\n\n# Set default command\nCMD [\"python\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <a href=\"https://github.com/google/langextract\">\n    <img src=\"https://raw.githubusercontent.com/google/langextract/main/docs/_static/logo.svg\" alt=\"LangExtract Logo\" width=\"128\" />\n  </a>\n</p>\n\n# LangExtract\n\n[![PyPI version](https://img.shields.io/pypi/v/langextract.svg)](https://pypi.org/project/langextract/)\n[![GitHub stars](https://img.shields.io/github/stars/google/langextract.svg?style=social&label=Star)](https://github.com/google/langextract)\n![Tests](https://github.com/google/langextract/actions/workflows/ci.yaml/badge.svg)\n[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17015089.svg)](https://doi.org/10.5281/zenodo.17015089)\n\n## Table of Contents\n\n- [Introduction](#introduction)\n- [Why LangExtract?](#why-langextract)\n- [Quick Start](#quick-start)\n- [Installation](#installation)\n- [API Key Setup for Cloud Models](#api-key-setup-for-cloud-models)\n- [Adding Custom Model Providers](#adding-custom-model-providers)\n- [Using OpenAI Models](#using-openai-models)\n- [Using Local LLMs with Ollama](#using-local-llms-with-ollama)\n- [More Examples](#more-examples)\n  - [*Romeo and Juliet* Full Text Extraction](#romeo-and-juliet-full-text-extraction)\n  - [Medication Extraction](#medication-extraction)\n  - [Radiology Report Structuring: RadExtract](#radiology-report-structuring-radextract)\n- [Community Providers](#community-providers)\n- [Contributing](#contributing)\n- [Testing](#testing)\n- [Disclaimer](#disclaimer)\n\n## Introduction\n\nLangExtract is a Python library that uses LLMs to extract structured information from unstructured text documents based on user-defined instructions. It processes materials such as clinical notes or reports, identifying and organizing key details while ensuring the extracted data corresponds to the source text.\n\n## Why LangExtract?\n\n1.  
**Precise Source Grounding:** Maps every extraction to its exact location in the source text, enabling visual highlighting for easy traceability and verification.\n2.  **Reliable Structured Outputs:** Enforces a consistent output schema based on your few-shot examples, leveraging controlled generation in supported models like Gemini to guarantee robust, structured results.\n3.  **Optimized for Long Documents:** Overcomes the \"needle-in-a-haystack\" challenge of large document extraction by using an optimized strategy of text chunking, parallel processing, and multiple passes for higher recall.\n4.  **Interactive Visualization:** Instantly generates a self-contained, interactive HTML file to visualize and review thousands of extracted entities in their original context.\n5.  **Flexible LLM Support:** Supports your preferred models, from cloud-based LLMs like the Google Gemini family to local open-source models via the built-in Ollama interface.\n6.  **Adaptable to Any Domain:** Define extraction tasks for any domain using just a few examples. LangExtract adapts to your needs without requiring any model fine-tuning.\n7.  **Leverages LLM World Knowledge:** Use precise prompt wording and few-shot examples to control how much the extraction relies on the LLM's world knowledge. The accuracy of any inferred information and its adherence to the task specification are contingent upon the selected LLM, the complexity of the task, the clarity of the prompt instructions, and the nature of the prompt examples.\n\n## Quick Start\n\n> **Note:** Using cloud-hosted models like Gemini requires an API key. See the [API Key Setup](#api-key-setup-for-cloud-models) section for instructions on how to get and configure your key.\n\nExtract structured information with just a few lines of code.\n\n### 1. Define Your Extraction Task\n\nFirst, create a prompt that clearly describes what you want to extract. 
Then, provide a high-quality example to guide the model.\n\n```python\nimport langextract as lx\nimport textwrap\n\n# 1. Define the prompt and extraction rules\nprompt = textwrap.dedent(\"\"\"\\\n    Extract characters, emotions, and relationships in order of appearance.\n    Use exact text for extractions. Do not paraphrase or overlap entities.\n    Provide meaningful attributes for each entity to add context.\"\"\")\n\n# 2. Provide a high-quality example to guide the model\nexamples = [\n    lx.data.ExampleData(\n        text=\"ROMEO. But soft! What light through yonder window breaks? It is the east, and Juliet is the sun.\",\n        extractions=[\n            lx.data.Extraction(\n                extraction_class=\"character\",\n                extraction_text=\"ROMEO\",\n                attributes={\"emotional_state\": \"wonder\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"emotion\",\n                extraction_text=\"But soft!\",\n                attributes={\"feeling\": \"gentle awe\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"relationship\",\n                extraction_text=\"Juliet is the sun\",\n                attributes={\"type\": \"metaphor\"}\n            ),\n        ]\n    )\n]\n```\n\n> **Note:** Examples drive model behavior. Each `extraction_text` should ideally be verbatim from the example's `text` (no paraphrasing), listed in order of appearance. LangExtract raises `Prompt alignment` warnings by default if examples don't follow this pattern—resolve these for best results.\n\n### 2. 
Run the Extraction\n\nProvide your input text and the prompt materials to the `lx.extract` function.\n\n```python\n# The input text to be processed\ninput_text = \"Lady Juliet gazed longingly at the stars, her heart aching for Romeo\"\n\n# Run the extraction\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"gemini-2.5-flash\",\n)\n```\n\n> **Model Selection**: `gemini-2.5-flash` is the recommended default, offering an excellent balance of speed, cost, and quality. For highly complex tasks requiring deeper reasoning, `gemini-2.5-pro` may provide superior results. For large-scale or production use, a Tier 2 Gemini quota is suggested to increase throughput and avoid rate limits. See the [rate-limit documentation](https://ai.google.dev/gemini-api/docs/rate-limits#tier-2) for details.\n>\n> **Model Lifecycle**: Note that Gemini models have a lifecycle with defined retirement dates. Users should consult the [official model version documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions) to stay informed about the latest stable and legacy versions.\n\n### 3. Visualize the Results\n\nThe extractions can be saved to a `.jsonl` file, a popular format for working with language model data. 
LangExtract can then generate an interactive HTML visualization from this file to review the entities in context.\n\n```python\n# Save the results to a JSONL file\nlx.io.save_annotated_documents([result], output_name=\"extraction_results.jsonl\", output_dir=\".\")\n\n# Generate the visualization from the file\nhtml_content = lx.visualize(\"extraction_results.jsonl\")\nwith open(\"visualization.html\", \"w\") as f:\n    if hasattr(html_content, 'data'):\n        f.write(html_content.data)  # For Jupyter/Colab\n    else:\n        f.write(html_content)\n```\n\nThis creates an animated and interactive HTML file:\n\n![Romeo and Juliet Basic Visualization ](https://raw.githubusercontent.com/google/langextract/main/docs/_static/romeo_juliet_basic.gif)\n\n> **Note on LLM Knowledge Utilization:** This example demonstrates extractions that stay close to the text evidence - extracting \"longing\" for Lady Juliet's emotional state and identifying \"yearning\" from \"gazed longingly at the stars.\" The task could be modified to generate attributes that draw more heavily from the LLM's world knowledge (e.g., adding `\"identity\": \"Capulet family daughter\"` or `\"literary_context\": \"tragic heroine\"`). 
The balance between text-evidence and knowledge-inference is controlled by your prompt instructions and example attributes.\n\n### Scaling to Longer Documents\n\nFor larger texts, you can process entire documents directly from URLs, using parallel processing and multiple extraction passes for higher recall:\n\n```python\n# Process Romeo & Juliet directly from Project Gutenberg\nresult = lx.extract(\n    text_or_documents=\"https://www.gutenberg.org/files/1513/1513-0.txt\",\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"gemini-2.5-flash\",\n    extraction_passes=3,    # Improves recall through multiple passes\n    max_workers=20,         # Parallel processing for speed\n    max_char_buffer=1000    # Smaller contexts for better accuracy\n)\n```\n\nThis approach can extract hundreds of entities from a full novel while maintaining high accuracy. The interactive visualization handles result sets of this size, so you can explore every entity in the output JSONL file. **[See the full *Romeo and Juliet* extraction example →](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)** for detailed results and performance insights.\n\n### Vertex AI Batch Processing\n\nSave costs on large-scale tasks by enabling the Vertex AI Batch API: `language_model_params={\"vertexai\": True, \"batch\": {\"enabled\": True}}`.\n\nSee the [Vertex AI Batch API example](docs/examples/batch_api_example.md) for usage details.\n\n## Installation\n\n### From PyPI\n\n```bash\npip install langextract\n```\n\n*Recommended for most users. 
For isolated environments, consider using a virtual environment:*\n\n```bash\npython -m venv langextract_env\nsource langextract_env/bin/activate  # On Windows: langextract_env\\Scripts\\activate\npip install langextract\n```\n\n### From Source\n\nLangExtract uses modern Python packaging with `pyproject.toml` for dependency management:\n\n*Installing with `-e` puts the package in development mode, allowing you to modify the code without reinstalling.*\n\n```bash\ngit clone https://github.com/google/langextract.git\ncd langextract\n\n# For basic installation:\npip install -e .\n\n# For development (includes linting tools):\npip install -e \".[dev]\"\n\n# For testing (includes pytest):\npip install -e \".[test]\"\n```\n\n### Docker\n\n```bash\ndocker build -t langextract .\n# Mount the current directory at /app (the image's working directory) so the container can find your_script.py\ndocker run --rm -v \"$(pwd)\":/app -e LANGEXTRACT_API_KEY=\"your-api-key\" langextract python your_script.py\n```\n\n## API Key Setup for Cloud Models\n\nWhen using LangExtract with cloud-hosted models (like Gemini or OpenAI), you'll need to\nset up an API key. On-device models don't require an API key. 
For developers\nusing local LLMs, LangExtract offers built-in support for Ollama and can be\nextended to other third-party APIs by updating the inference endpoints.\n\n### API Key Sources\n\nGet API keys from:\n\n*   [AI Studio](https://aistudio.google.com/app/apikey) for Gemini models\n*   [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/sdks/overview) for enterprise use\n*   [OpenAI Platform](https://platform.openai.com/api-keys) for OpenAI models\n\n### Setting up API key in your environment\n\n**Option 1: Environment Variable**\n\n```bash\nexport LANGEXTRACT_API_KEY=\"your-api-key-here\"\n```\n\n**Option 2: .env File (Recommended)**\n\nAdd your API key to a `.env` file:\n\n```bash\n# Add API key to .env file\ncat >> .env << 'EOF'\nLANGEXTRACT_API_KEY=your-api-key-here\nEOF\n\n# Keep your API key secure\necho '.env' >> .gitignore\n```\n\nIn your Python code:\n```python\nimport langextract as lx\n\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=\"Extract information...\",\n    examples=[...],\n    model_id=\"gemini-2.5-flash\"\n)\n```\n\n**Option 3: Direct API Key (Not Recommended for Production)**\n\nYou can also provide the API key directly in your code, though this is not recommended for production use:\n\n```python\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=\"Extract information...\",\n    examples=[...],\n    model_id=\"gemini-2.5-flash\",\n    api_key=\"your-api-key-here\"  # Only use this for testing/development\n)\n```\n\n**Option 4: Vertex AI (Service Accounts)**\n\nUse [Vertex AI](https://cloud.google.com/vertex-ai/docs/start/introduction-unified-platform) for authentication with service accounts:\n\n```python\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=\"Extract information...\",\n    examples=[...],\n    model_id=\"gemini-2.5-flash\",\n    language_model_params={\n        \"vertexai\": True,\n        \"project\": 
\"your-project-id\",\n        \"location\": \"global\"  # or regional endpoint\n    }\n)\n```\n\n## Adding Custom Model Providers\n\nLangExtract supports custom LLM providers via a lightweight plugin system. You can add support for new models without changing core code.\n\n- Add new model support independently of the core library\n- Distribute your provider as a separate Python package\n- Keep custom dependencies isolated\n- Override or extend built-in providers via priority-based resolution\n\nSee the detailed guide in [Provider System Documentation](langextract/providers/README.md) to learn how to:\n\n- Register a provider with `@registry.register(...)`\n- Publish an entry point for discovery\n- Optionally provide a schema with `get_schema_class()` for structured output\n- Integrate with the factory via `create_model(...)`\n\n## Using OpenAI Models\n\nLangExtract supports OpenAI models (requires optional dependency: `pip install langextract[openai]`):\n\n```python\nimport langextract as lx\n\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"gpt-4o\",  # Automatically selects OpenAI provider\n    api_key=os.environ.get('OPENAI_API_KEY'),\n    fence_output=True,\n    use_schema_constraints=False\n)\n```\n\nNote: OpenAI models require `fence_output=True` and `use_schema_constraints=False` because LangExtract doesn't implement schema constraints for OpenAI yet.\n\n## Using Local LLMs with Ollama\nLangExtract supports local inference using Ollama, allowing you to run models without API keys:\n\n```python\nimport langextract as lx\n\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"gemma2:2b\",  # Automatically selects Ollama provider\n    model_url=\"http://localhost:11434\",\n    fence_output=False,\n    use_schema_constraints=False\n)\n```\n\n**Quick setup:** Install Ollama from 
[ollama.com](https://ollama.com/), run `ollama pull gemma2:2b`, then `ollama serve`.\n\nFor detailed installation, Docker setup, and examples, see [`examples/ollama/`](examples/ollama/).\n\n## More Examples\n\nAdditional examples of LangExtract in action:\n\n### *Romeo and Juliet* Full Text Extraction\n\nLangExtract can process complete documents directly from URLs. This example demonstrates extraction from the full text of *Romeo and Juliet* from Project Gutenberg (147,843 characters), showing parallel processing, sequential extraction passes, and performance optimization for long document processing.\n\n**[View *Romeo and Juliet* Full Text Example →](https://github.com/google/langextract/blob/main/docs/examples/longer_text_example.md)**\n\n### Medication Extraction\n\n> **Disclaimer:** This demonstration is for illustrative purposes of LangExtract's baseline capability only. It does not represent a finished or approved product, is not intended to diagnose or suggest treatment of any disease or condition, and should not be used for medical advice.\n\nLangExtract excels at extracting structured medical information from clinical text. These examples demonstrate both basic entity recognition (medication names, dosages, routes) and relationship extraction (connecting medications to their attributes), showing LangExtract's effectiveness for healthcare applications.\n\n**[View Medication Examples →](https://github.com/google/langextract/blob/main/docs/examples/medication_examples.md)**\n\n### Radiology Report Structuring: RadExtract\n\nExplore RadExtract, a live interactive demo on HuggingFace Spaces that shows how LangExtract can automatically structure radiology reports. Try it directly in your browser with no setup required.\n\n**[View RadExtract Demo →](https://huggingface.co/spaces/google/radextract)**\n\n## Community Providers\n\nExtend LangExtract with custom model providers! 
Check out our [Community Provider Plugins](COMMUNITY_PROVIDERS.md) registry to discover providers created by the community or add your own.\n\nFor detailed instructions on creating a provider plugin, see the [Custom Provider Plugin Example](examples/custom_provider_plugin/).\n\n## Contributing\n\nContributions are welcome! See [CONTRIBUTING.md](https://github.com/google/langextract/blob/main/CONTRIBUTING.md) to get started\nwith development, testing, and pull requests. You must sign a\n[Contributor License Agreement](https://cla.developers.google.com/about)\nbefore submitting patches.\n\n\n\n## Testing\n\nTo run tests locally from the source:\n\n```bash\n# Clone the repository\ngit clone https://github.com/google/langextract.git\ncd langextract\n\n# Install with test dependencies\npip install -e \".[test]\"\n\n# Run all tests\npytest tests\n```\n\nOr reproduce the full CI matrix locally with tox:\n\n```bash\ntox  # runs pylint + pytest on Python 3.10 and 3.11\n```\n\n### Ollama Integration Testing\n\nIf you have Ollama installed locally, you can run integration tests:\n\n```bash\n# Test Ollama integration (requires Ollama running with gemma2:2b model)\ntox -e ollama-integration\n```\n\nThis test will automatically detect if Ollama is available and run real inference tests.\n\n## Development\n\n### Code Formatting\n\nThis project uses automated formatting tools to maintain consistent code style:\n\n```bash\n# Auto-format all code\n./autoformat.sh\n\n# Or run formatters separately\nisort langextract tests --profile google --line-length 80\npyink langextract tests --config pyproject.toml\n```\n\n### Pre-commit Hooks\n\nFor automatic formatting checks:\n```bash\npre-commit install  # One-time setup\npre-commit run --all-files  # Manual run\n```\n\n### Linting\n\nRun linting before submitting PRs:\n\n```bash\npylint --rcfile=.pylintrc langextract tests\n```\n\nSee [CONTRIBUTING.md](CONTRIBUTING.md) for full development guidelines.\n\n## Disclaimer\n\nThis is not an 
officially supported Google product. If you use\nLangExtract in production or publications, please cite accordingly and\nacknowledge usage. Use is subject to the [Apache 2.0 License](https://github.com/google/langextract/blob/main/LICENSE).\nFor health-related applications, use of LangExtract is also subject to the\n[Health AI Developer Foundations Terms of Use](https://developers.google.com/health-ai-developer-foundations/terms).\n\n---\n\n**Happy Extracting!**\n"
  },
  {
    "path": "autoformat.sh",
    "content": "#!/bin/bash\n# Copyright 2025 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Autoformat LangExtract codebase\n#\n# Usage: ./autoformat.sh [target_directory ...]\n#        If no target is specified, formats the langextract and tests directories\n#\n# This script runs:\n# 1. isort for import sorting\n# 2. pyink (Google's Black fork) for code formatting\n# 3. pre-commit hooks for additional formatting (trailing whitespace, end-of-file, etc.)\n\nset -e\n\necho \"LangExtract Auto-formatter\"\necho \"==========================\"\necho\n\n# Check for required tools\ncheck_tool() {\n    if ! command -v \"$1\" &> /dev/null; then\n        echo \"Error: $1 not found. 
Please install with: pip install $1\"\n        exit 1\n    fi\n}\n\ncheck_tool \"isort\"\ncheck_tool \"pyink\"\ncheck_tool \"pre-commit\"\n\n# Parse command line arguments\nshow_usage() {\n    echo \"Usage: $0 [target_directory ...]\"\n    echo\n    echo \"Formats Python code using isort and pyink according to Google style.\"\n    echo\n    echo \"Arguments:\"\n    echo \"  target_directory    One or more directories to format (default: langextract tests)\"\n    echo\n    echo \"Examples:\"\n    echo \"  $0                  # Format langextract and tests directories\"\n    echo \"  $0 langextract      # Format only langextract directory\"\n    echo \"  $0 src tests        # Format multiple specific directories\"\n}\n\n# Check for help flag\nif [ \"$1\" = \"-h\" ] || [ \"$1\" = \"--help\" ]; then\n    show_usage\n    exit 0\nfi\n\n# Determine target directories\nif [ $# -eq 0 ]; then\n    TARGETS=\"langextract tests\"\n    echo \"No target specified. Formatting default directories: langextract tests\"\nelse\n    TARGETS=\"$@\"\n    echo \"Formatting targets: $TARGETS\"\nfi\n\n# Find pyproject.toml relative to script location\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" && pwd )\"\nCONFIG_FILE=\"${SCRIPT_DIR}/pyproject.toml\"\n\nif [ ! 
-f \"$CONFIG_FILE\" ]; then\n    echo \"Warning: pyproject.toml not found at ${CONFIG_FILE}\"\n    echo \"Using default configuration.\"\n    CONFIG_ARG=\"\"\nelse\n    CONFIG_ARG=\"--config $CONFIG_FILE\"\nfi\n\necho\n\n# Run isort\necho \"Running isort to organize imports...\"\nif isort $TARGETS; then\n    echo \"Import sorting complete\"\nelse\n    echo \"Import sorting failed\"\n    exit 1\nfi\n\necho\n\n# Run pyink\necho \"Running pyink to format code (Google style, 80 chars)...\"\nif pyink $TARGETS $CONFIG_ARG; then\n    echo \"Code formatting complete\"\nelse\n    echo \"Code formatting failed\"\n    exit 1\nfi\n\necho\n\n# Run pre-commit hooks for additional formatting\necho \"Running pre-commit hooks for additional formatting...\"\nif pre-commit run --all-files; then\n    echo \"Pre-commit hooks passed\"\nelse\n    echo \"Pre-commit hooks made changes - please review\"\n    # Exit with success since formatting was applied\n    exit 0\nfi\n\necho\necho \"All formatting complete!\"\necho\necho \"Next steps:\"\necho \"  - Run: pylint --rcfile=${SCRIPT_DIR}/.pylintrc $TARGETS\"\necho \"  - Commit your changes\"\n"
  },
  {
    "path": "benchmarks/benchmark.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"LangExtract benchmark suite for performance and quality testing.\n\nMeasures tokenization speed and extraction quality across multiple languages\nand text types. Automatically downloads test texts from Project Gutenberg\nand generates comparative visualizations.\n\nUsage:\n  # Run diverse text type benchmark (default)\n  python benchmarks/benchmark.py\n\n  # Test with specific model\n  python benchmarks/benchmark.py --model gemini-2.5-flash\n  python benchmarks/benchmark.py --model gemma2:2b  # Local model via Ollama\n\n  # Generate comparison plots from existing results\n  python benchmarks/benchmark.py --compare\n\nRequirements:\n  - Set GEMINI_API_KEY for cloud models\n  - Install Ollama for local model testing\n  - Results saved to benchmark_results/\n\"\"\"\n\nimport argparse\nfrom datetime import datetime\nimport json\nimport os\nfrom pathlib import Path\nimport time\nfrom typing import Any\nimport urllib.error\n\nimport dotenv\n\nfrom benchmarks import config\nfrom benchmarks import plotting\nfrom benchmarks import utils\nimport langextract\nfrom langextract import core\nfrom langextract import data\nfrom langextract import visualize\nimport langextract.io as lio\n\n# Load API key from environment\ndotenv.load_dotenv(override=True)\nGEMINI_API_KEY = os.environ.get(\n    \"GEMINI_API_KEY\", 
os.environ.get(\"LANGEXTRACT_API_KEY\")\n)\n\n\nclass BenchmarkRunner:\n  \"\"\"Orchestrates benchmark execution and result collection.\"\"\"\n\n  def __init__(self):\n    \"\"\"Initialize runner with timestamp and git metadata.\"\"\"\n    self.timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n    self.git_info = utils.get_git_info()\n    self.tokenizer = core.tokenizer.RegexTokenizer()\n\n  def set_tokenizer(self, tokenizer_type: str):\n    \"\"\"Set the tokenizer to use.\"\"\"\n    if tokenizer_type.lower() == \"unicode\":\n      self.tokenizer = core.tokenizer.UnicodeTokenizer()\n      print(\"Using UnicodeTokenizer\")\n    else:\n      self.tokenizer = core.tokenizer.RegexTokenizer()\n      print(\"Using RegexTokenizer (default)\")\n\n  def print_header(self):\n    \"\"\"Print benchmark header.\"\"\"\n    print(\"=\" * config.DISPLAY.separator_width)\n    print(\"LANGEXTRACT BENCHMARK\")\n    print(\"=\" * config.DISPLAY.separator_width)\n    print(\n        f\"Branch: {self.git_info['branch']} | Commit: {self.git_info['commit']}\"\n    )\n    print(\"-\" * config.DISPLAY.separator_width)\n\n  def benchmark_tokenization(self) -> list[dict[str, Any]]:\n    \"\"\"Measure tokenization throughput at different text sizes.\n\n    Returns:\n      List of dicts with words, tokens, timing, and throughput metrics.\n    \"\"\"\n    print(\"\\nTokenization Performance\")\n    print(\"-\" * config.DISPLAY.subseparator_width)\n\n    results = []\n\n    for word_count in config.TOKENIZATION.default_text_sizes:\n      text = \" \".join([\"word\"] * word_count)\n\n      _ = self.tokenizer.tokenize(text)\n\n      times = []\n      for _ in range(config.TOKENIZATION.benchmark_iterations):\n        start = time.perf_counter()\n        tokenized = self.tokenizer.tokenize(text)\n        elapsed = time.perf_counter() - start\n        times.append(elapsed)\n\n      avg_time = sum(times) / len(times)\n      avg_ms = avg_time * 1000\n      num_tokens = len(tokenized.tokens)\n      
tokens_per_sec = num_tokens / avg_time if avg_time > 0 else 0\n\n      word_str = (\n          f\"{word_count//1000:,}k\" if word_count >= 1000 else f\"{word_count:,}\"\n      )\n\n      print(\n          f\"{word_str:>6} words: {avg_ms:7.2f}ms  \"\n          f\"({tokens_per_sec/1e6:.1f}M tokens/sec)\"\n      )\n\n      results.append({\n          \"words\": word_count,\n          \"tokens\": num_tokens,\n          \"avg_ms\": avg_ms,\n          \"tokens_per_sec\": tokens_per_sec,\n      })\n\n    return results\n\n  def test_single_extraction(\n      self,\n      model_id: str = config.MODELS.default_model,\n      text_type: config.TextTypes = config.TextTypes.ENGLISH,\n  ) -> dict[str, Any]:\n    \"\"\"Execute extraction test.\n\n    Args:\n      model_id: Model identifier (e.g., 'gemini-2.5-flash', 'gemma2:2b').\n      text_type: Language/text type to test.\n\n    Returns:\n      Dict with success status, timing, entity counts, and metrics.\n    \"\"\"\n    print(\"\\nExtraction Test\")\n    print(\"-\" * config.DISPLAY.subseparator_width)\n\n    try:\n      # Get test text\n      test_text = utils.get_text_from_gutenberg(text_type)\n      test_text = utils.get_optimal_text_size(test_text, model_id)\n\n      print(f\"   Text: {len(test_text):,} characters ({text_type.value})\")\n      print(f\"   Model: {model_id}\")\n\n      # Analyze tokenization\n      tokenization_analysis = utils.analyze_tokenization(\n          test_text, self.tokenizer\n      )\n      print(\n          \"   Tokenization:\"\n          f\" {utils.format_tokenization_summary(tokenization_analysis)}\"\n      )\n\n      # Get extraction config for text type\n      extraction_config = utils.get_extraction_example(text_type)\n\n      example = data.ExampleData(\n          text=\"MACBETH speaks to LADY MACBETH about Duncan.\",\n          extractions=[\n              data.Extraction(\n                  extraction_text=\"Macbeth\", extraction_class=\"Character\"\n              ),\n              
data.Extraction(\n                  extraction_text=\"Lady Macbeth\", extraction_class=\"Character\"\n              ),\n              data.Extraction(\n                  extraction_text=\"Duncan\", extraction_class=\"Character\"\n              ),\n          ],\n      )\n\n      max_retries = 5\n      retry_delay = 3.0\n\n      # Retry logic for transient network/API failures\n      for attempt in range(max_retries):\n        try:\n          start_time = time.time()\n          result = langextract.extract(\n              text_or_documents=test_text,\n              model_id=model_id,\n              api_key=GEMINI_API_KEY,\n              prompt_description=extraction_config[\"prompt\"],\n              examples=[example],\n              max_workers=config.MODELS.default_max_workers,\n              temperature=config.MODELS.default_temperature,\n              extraction_passes=config.MODELS.default_extraction_passes,\n              tokenizer=self.tokenizer,\n          )\n          elapsed = time.time() - start_time\n          break\n        except (ConnectionError, TimeoutError):\n          if attempt < max_retries - 1:\n            print(f\"   Retrying in {retry_delay}s...\")\n            time.sleep(retry_delay)\n            retry_delay *= 1.5\n            continue\n          raise\n\n      print(f\"Extraction completed in {elapsed:.1f}s\")\n\n      grounded_entities = []\n      ungrounded_entities = []\n\n      if result.extractions:\n        for extraction in result.extractions:\n          is_grounded = (\n              extraction.char_interval\n              and extraction.char_interval.start_pos is not None\n              and extraction.char_interval.end_pos is not None\n          )\n\n          entity_text = extraction.extraction_text\n          if entity_text:\n            if is_grounded:\n              grounded_entities.append(entity_text)\n            else:\n              ungrounded_entities.append(entity_text)\n\n      unique_grounded = 
list(set(grounded_entities))\n      unique_ungrounded = list(set(ungrounded_entities))\n\n      print(f\"Found {len(unique_grounded)} grounded entities\")\n      if unique_ungrounded:\n        print(f\"   ({len(unique_ungrounded)} ungrounded entities ignored)\")\n\n      if unique_grounded:\n        sample = unique_grounded[:5]\n        sample_str = \", \".join(sample) + (\n            \"...\" if len(unique_grounded) > 5 else \"\"\n        )\n        print(f\"   Sample: {sample_str}\")\n\n      return {\n          \"success\": True,\n          \"model\": model_id,\n          \"text_type\": text_type.value,\n          \"time_seconds\": elapsed,\n          \"entity_count\": len(unique_grounded),\n          \"ungrounded_count\": len(unique_ungrounded),\n          \"sample_entities\": unique_grounded[:10],\n          \"tokenization\": tokenization_analysis,\n          config.EXTRACTION_RESULT_KEY: result,\n      }\n\n    except (urllib.error.URLError, RuntimeError) as e:\n      # Handle expected text download failures.\n      print(f\"Failed: {e}\")\n      return {\n          \"success\": False,\n          \"model\": model_id,\n          \"text_type\": text_type.value,\n          \"error\": str(e),\n      }\n\n  def test_diverse_text_types(\n      self, models: list[str] | None = None\n  ) -> list[dict[str, Any]]:\n    \"\"\"Test extraction with diverse text types.\"\"\"\n    print(\"\\n\" + \"=\" * config.DISPLAY.separator_width)\n    print(\"DIVERSE TEXT TYPE MODE\")\n    print(\"=\" * config.DISPLAY.separator_width)\n\n    if models is None:\n      models = [config.MODELS.default_model]\n\n    results = []\n    test_count = 0\n\n    for model_id in models:\n      print(f\"\\nTesting {model_id}\")\n      print(\"-\" * 30)\n\n      for text_type in config.TextTypes:\n        print(f\"\\n  Testing {text_type.value} text...\")\n        result = self.test_single_extraction(model_id, text_type)\n        results.append(result)\n\n        if result.get(\"success\"):\n       
   test_count += 1\n          if test_count % 3 == 0:\n            print(\n                \"   Rate limit delay\"\n                f\" ({config.MODELS.gemini_rate_limit_delay}s)...\"\n            )\n            time.sleep(config.MODELS.gemini_rate_limit_delay)\n\n    print(f\"\\nCompleted {test_count} successful tests\")\n    return results\n\n  def save_results(self, results: dict[str, Any]):\n    \"\"\"Save results and create plots.\"\"\"\n    results[\"timestamp\"] = self.timestamp\n    results[\"git\"] = self.git_info\n\n    json_path = config.PATHS.get_result_path(self.timestamp, \"\").with_suffix(\n        \".json\"\n    )\n\n    viz_dir = json_path.parent / \"visualizations\" / self.timestamp\n    viz_dir.mkdir(parents=True, exist_ok=True)\n\n    if config.RESULTS_KEY in results:\n      print(f\"\\nGenerating visualizations in: {viz_dir}\")\n      for result in results[config.RESULTS_KEY]:\n        if result.get(\"success\") and config.EXTRACTION_RESULT_KEY in result:\n          model_name = result[\"model\"].replace(\"/\", \"_\").replace(\":\", \"_\")\n          text_type = result[\"text_type\"]\n          viz_name = f\"{model_name}_{text_type}\"\n\n          jsonl_path = viz_dir / f\"{viz_name}.jsonl\"\n          lio.save_annotated_documents(\n              [result[config.EXTRACTION_RESULT_KEY]],\n              output_name=jsonl_path.name,\n              output_dir=str(viz_dir),\n          )\n\n          html_content = visualize(str(jsonl_path))\n          html_path = viz_dir / f\"{viz_name}.html\"\n          with open(html_path, \"w\") as f:\n            f.write(getattr(html_content, \"data\", html_content))\n\n    # Remove extraction result objects before saving JSON\n    for result in results.get(config.RESULTS_KEY, []):\n      result.pop(config.EXTRACTION_RESULT_KEY, None)\n\n    with open(json_path, \"w\") as f:\n      json.dump(results, f, indent=2, default=str)\n    print(f\"\\nResults saved to: {json_path}\")\n\n    plot_created = 
plotting.create_diverse_plots(results, json_path)\n\n    if plot_created:\n      print(f\"Plot saved to: {json_path.with_suffix('.png')}\")\n    else:\n      print(f\"Warning: Failed to create plot for {json_path.name}\")\n\n  def run_diverse_benchmark(self, models: list[str] | None = None):\n    \"\"\"Run benchmark.\"\"\"\n    self.print_header()\n\n    tokenization_results = self.benchmark_tokenization()\n    diverse_results = self.test_diverse_text_types(models)\n\n    results = {\n        \"tokenization\": tokenization_results,\n        config.RESULTS_KEY: diverse_results,\n    }\n\n    self.save_results(results)\n\n\ndef main():\n  \"\"\"Main entry point.\"\"\"\n  parser = argparse.ArgumentParser(description=\"LangExtract Benchmark Suite\")\n\n  parser.add_argument(\n      \"--model\",\n      type=str,\n      default=None,\n      help=f\"Model to use (default: {config.MODELS.default_model})\",\n  )\n\n  parser.add_argument(\n      \"--tokenizer\",\n      type=str,\n      choices=[\"regex\", \"unicode\"],\n      default=\"regex\",\n      help=\"Tokenizer to use (default: regex)\",\n  )\n\n  parser.add_argument(\n      \"--compare\",\n      action=\"store_true\",\n      help=\"Generate comparison plots from existing benchmark results\",\n  )\n\n  args = parser.parse_args()\n\n  # Handle comparison mode\n  if args.compare:\n    results_dir = Path(\"benchmark_results\")\n    json_files = sorted(results_dir.glob(\"benchmark_*.json\"))\n\n    if len(json_files) < 2:\n      print(\n          \"Need at least 2 benchmark results for comparison, found\"\n          f\" {len(json_files)}\"\n      )\n      return\n\n    print(f\"Found {len(json_files)} benchmark results to compare\")\n\n    # Use last 10 results or all if less than 10\n    files_to_compare = json_files[-10:]\n    comparison_path = (\n        results_dir\n        / f\"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png\"\n    )\n\n    plotting.create_comparison_plots(files_to_compare, 
comparison_path)\n    print(f\"\\nComparison plot saved to: {comparison_path}\")\n    return\n\n  model_to_test = args.model or config.MODELS.default_model\n  if \"gemini\" in model_to_test.lower() and not GEMINI_API_KEY:\n    print(\n        f\"Error: {model_to_test} requires GEMINI_API_KEY or LANGEXTRACT_API_KEY\"\n    )\n    return\n\n  runner = BenchmarkRunner()\n  runner.set_tokenizer(args.tokenizer)\n  runner.run_diverse_benchmark([args.model] if args.model else None)\n\n\nif __name__ == \"__main__\":\n  main()\n"
  },
  {
    "path": "benchmarks/config.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Benchmark configuration settings and constants.\n\nCentralized configuration for tokenization tests, model parameters,\ndisplay formatting, and test text sources.\n\"\"\"\n\nfrom dataclasses import dataclass\nimport enum\nfrom pathlib import Path\n\n# Result dictionary keys\nRESULTS_KEY = \"results\"\nEXTRACTION_KEY = \"extraction\"\nEXTRACTION_RESULT_KEY = \"extraction_result\"\nTOKENIZATION_KEY = \"tokenization\"\n\n\n@dataclass(frozen=True)\nclass TokenizationConfig:\n  \"\"\"Settings for tokenization performance tests.\"\"\"\n\n  default_text_sizes: tuple[int, ...] 
= (100, 1000, 10000)  # Word counts\n  benchmark_iterations: int = 10  # Iterations per size for averaging\n\n\n@dataclass(frozen=True)\nclass ModelConfig:\n  \"\"\"Model and API configuration.\"\"\"\n\n  default_model: str = \"gemini-2.5-flash\"  # Cloud model default\n  local_model: str = \"gemma2:9b\"  # Ollama model default\n  default_temperature: float = 0.0  # Deterministic output\n  default_max_workers: int = 10  # Parallel processing threads\n  default_extraction_passes: int = 1  # Single pass extraction\n  gemini_rate_limit_delay: float = 8.0  # Seconds between batches\n\n\nclass TextTypes(str, enum.Enum):\n  \"\"\"Supported languages for extraction testing.\"\"\"\n\n  ENGLISH = \"english\"\n  JAPANESE = \"japanese\"\n  FRENCH = \"french\"\n  SPANISH = \"spanish\"\n\n\n# Test texts from Project Gutenberg (similar genres for fair comparison)\nGUTENBERG_TEXTS = {\n    TextTypes.ENGLISH: (\n        \"https://www.gutenberg.org/files/11/11-0.txt\"\n    ),  # Alice's Adventures\n    TextTypes.JAPANESE: (\n        \"https://www.gutenberg.org/files/1982/1982-0.txt\"\n    ),  # Rashomon\n    TextTypes.FRENCH: (\n        \"https://www.gutenberg.org/files/55456/55456-0.txt\"\n    ),  # Alice (French)\n    TextTypes.SPANISH: (\n        \"https://www.gutenberg.org/files/67248/67248-0.txt\"\n    ),  # El clavo\n}\n\n\n@dataclass(frozen=True)\nclass DisplayConfig:\n  \"\"\"Display configuration.\"\"\"\n\n  separator_width: int = 50\n  subseparator_width: int = 40\n  figure_size_single: tuple[int, int] = (12, 5)\n  figure_size_multi: tuple[int, int] = (14, 10)\n  plot_style: str = \"seaborn-v0_8-darkgrid\"\n\n\n@dataclass(frozen=True)\nclass PathConfig:\n  \"\"\"Path configuration.\"\"\"\n\n  results_dir: Path = Path(\"benchmark_results\")\n\n  def get_result_path(self, timestamp: str, suffix: str = \"\") -> Path:\n    \"\"\"Get result file path.\"\"\"\n    if not self.results_dir.exists():\n      self.results_dir.mkdir(parents=True)\n    filename = 
f\"benchmark{suffix}_{timestamp}\"\n    return self.results_dir / filename\n\n\n# Global config instances\nTOKENIZATION = TokenizationConfig()\nMODELS = ModelConfig()\nDISPLAY = DisplayConfig()\nPATHS = PathConfig()\n"
  },
  {
    "path": "benchmarks/plotting.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Visualization generation for benchmark results.\n\nCreates multi-panel plots showing tokenization performance, extraction metrics,\nand cross-language comparisons.\n\"\"\"\n\nfrom datetime import datetime\nimport json\nfrom pathlib import Path\nfrom typing import Any\n\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom benchmarks import config\n\nmatplotlib.use(\"Agg\")\nplt.style.use(config.DISPLAY.plot_style)\n\n\ndef create_diverse_plots(results: dict[str, Any], filepath: Path) -> bool:\n  \"\"\"Generate comprehensive benchmark visualization.\n\n  Args:\n    results: Benchmark results dictionary with tokenization and extraction data.\n    filepath: Output path for PNG file.\n\n  Returns:\n    True if plot created successfully, False on error.\n  \"\"\"\n  try:\n    fig = plt.figure(figsize=(15, 10))\n\n    # Create 2x3 grid: tokenization metrics (top), extraction metrics (bottom)\n    gs = fig.add_gridspec(2, 3, hspace=0.25, wspace=0.25)\n\n    ax1 = fig.add_subplot(gs[0, 0])  # Tokenization throughput\n    ax2 = fig.add_subplot(gs[0, 1])  # Token density by language\n    ax3 = fig.add_subplot(gs[0, 2])  # Entity extraction counts\n    ax4 = fig.add_subplot(gs[1, 0])  # Processing speed\n    ax5 = fig.add_subplot(gs[1, 1])  # Summary metrics\n    ax6 = fig.add_subplot(gs[1, 2])  # Unused\n\n    fig.suptitle(\n        
f\"LangExtract Benchmark - {results['timestamp']}\", fontsize=14, y=0.98\n    )\n\n    _plot_tokenization_throughput(ax1, results)\n    _plot_tokenization_rate(ax2, results)\n    _plot_extraction_density(ax3, results)\n    _plot_processing_speed(ax4, results)\n    _plot_summary_table(ax5, results)\n    ax6.axis(\"off\")\n\n    plt.tight_layout(rect=[0, 0.02, 1, 0.96])\n\n    plot_path = filepath.with_suffix(\".png\")\n    plt.savefig(plot_path, dpi=100, bbox_inches=\"tight\")\n    plt.close()\n\n    print(f\"Plot saved to: {plot_path}\")\n    return True\n\n  except (IOError, OSError) as e:\n    print(f\"Warning: Could not create benchmark plot: {e}\")\n    return False\n\n\ndef _plot_tokenization_throughput(ax, results):\n  \"\"\"Plot tokenization throughput (tokens per second) on log scale.\"\"\"\n  if (\n      config.TOKENIZATION_KEY not in results\n      or not results[config.TOKENIZATION_KEY]\n  ):\n    ax.text(0.5, 0.5, \"No tokenization data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Tokenization Throughput\")\n    return\n\n  sizes = [r[\"words\"] for r in results[config.TOKENIZATION_KEY]]\n  speeds = [r[\"tokens_per_sec\"] for r in results[config.TOKENIZATION_KEY]]\n\n  ax.semilogx(sizes, speeds, \"b-o\", linewidth=2, markersize=8)\n  ax.set_xlabel(\"Number of Words (log scale)\")\n  ax.set_ylabel(\"Tokens per Second\")\n  ax.set_title(\"Tokenization Throughput\")\n  ax.grid(True, alpha=0.3)\n\n  max_speed = max(speeds)\n  ax.set_ylim(0, max_speed * 1.15)\n\n  y_ticks = [0, 100000, 200000, 300000, 400000]\n  ax.set_yticks(y_ticks)\n  ax.set_yticklabels([f\"{int(y/1000)}K\" if y > 0 else \"0\" for y in y_ticks])\n\n  for x, y in zip(sizes, speeds):\n    label = f\"{y/1000:.0f}K\"\n    ax.annotate(\n        label,\n        xy=(x, y),\n        xytext=(0, 5),\n        textcoords=\"offset points\",\n        ha=\"center\",\n        fontsize=9,\n    )\n\n  ax.set_xticks([100, 1000, 10000])\n  ax.set_xticklabels([\"10²\", \"10³\", \"10⁴\"])\n\n\ndef 
_plot_tokenization_rate(ax, results):\n  \"\"\"Plot tokenization rate by text type.\"\"\"\n  if config.RESULTS_KEY not in results:\n    ax.text(0.5, 0.5, \"No data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Tokenization Rate\")\n    return\n\n  text_types = []\n  tok_per_char = []\n\n  for result in results[config.RESULTS_KEY]:\n    if config.TOKENIZATION_KEY in result and result.get(\"success\", False):\n      text_type = result.get(\"text_type\", \"unknown\")\n      if text_type not in text_types:\n        text_types.append(text_type)\n        tpc = result[config.TOKENIZATION_KEY][\"tokens_per_char\"]\n        tok_per_char.append(tpc)\n\n  if not text_types:\n    ax.text(0.5, 0.5, \"No tokenization data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Tokenization Rate\")\n    return\n\n  x = np.arange(len(text_types))\n  bars = ax.bar(x, tok_per_char, color=\"#2196f3\", alpha=0.7)\n\n  for bar_rect, val in zip(bars, tok_per_char):\n    ax.text(\n        bar_rect.get_x() + bar_rect.get_width() / 2,\n        val + 0.005,\n        f\"{val:.3f}\",\n        ha=\"center\",\n        va=\"bottom\",\n        fontsize=9,\n    )\n\n  ax.set_xlabel(\"Text Type\")\n  ax.set_ylabel(\"Tokens per Character\")\n  ax.set_title(\"Tokenization Rate\")\n  ax.set_xticks(x)\n  ax.set_xticklabels([t.capitalize() for t in text_types])\n  ax.grid(True, alpha=0.3, axis=\"y\")\n  ax.set_ylim(0, max(0.30, max(tok_per_char) * 1.2) if tok_per_char else 0.30)\n\n\ndef _plot_extraction_density(ax, results):\n  \"\"\"Plot entity extraction density.\"\"\"\n  if config.RESULTS_KEY not in results:\n    ax.text(0.5, 0.5, \"No data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Extraction Density\")\n    return\n\n  text_types = []\n  densities = []\n\n  for result in results[config.RESULTS_KEY]:\n    if result.get(\"success\", False):\n      text_type = result.get(\"text_type\", \"unknown\")\n      if text_type not in text_types:\n        text_types.append(text_type)\n\n        
char_count = 1000\n        if config.TOKENIZATION_KEY in result:\n          char_count = result[config.TOKENIZATION_KEY].get(\"num_chars\", 1000)\n\n        entity_count = result.get(\"entity_count\", 0)\n        density = (entity_count * 1000) / char_count\n        densities.append(density)\n\n  if not text_types:\n    ax.text(0.5, 0.5, \"No successful extractions\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Extraction Density\")\n    return\n\n  x = np.arange(len(text_types))\n  bars = ax.bar(x, densities, color=\"#4caf50\", alpha=0.7)\n\n  for bar_rect, val in zip(bars, densities):\n    ax.text(\n        bar_rect.get_x() + bar_rect.get_width() / 2,\n        val,\n        f\"{val:.1f}\",\n        ha=\"center\",\n        va=\"bottom\",\n        fontsize=9,\n    )\n\n  ax.set_xlabel(\"Text Type\")\n  ax.set_ylabel(\"Entities per 1K Characters\")\n  ax.set_title(\"Extraction Density\")\n  ax.set_xticks(x)\n  ax.set_xticklabels([t.capitalize() for t in text_types])\n  ax.grid(True, alpha=0.3, axis=\"y\")\n\n\ndef _plot_processing_speed(ax, results):\n  \"\"\"Plot processing speed normalized by text size.\"\"\"\n  if config.RESULTS_KEY not in results:\n    ax.text(0.5, 0.5, \"No data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Processing Speed\")\n    return\n\n  text_types = []\n  speeds = []\n\n  for result in results[config.RESULTS_KEY]:\n    if result.get(\"success\", False):\n      text_type = result.get(\"text_type\", \"unknown\")\n      if text_type not in text_types:\n        text_types.append(text_type)\n\n        char_count = 1000\n        if config.TOKENIZATION_KEY in result:\n          char_count = result[config.TOKENIZATION_KEY].get(\"num_chars\", 1000)\n\n        time_seconds = result.get(\"time_seconds\", 0)\n        speed = (time_seconds * 1000) / char_count\n        speeds.append(speed)\n\n  if not text_types:\n    ax.text(0.5, 0.5, \"No timing data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Processing Speed\")\n    
return\n\n  x = np.arange(len(text_types))\n  bars = ax.bar(x, speeds, color=\"#ff9800\", alpha=0.7)\n\n  for bar_rect, val in zip(bars, speeds):\n    ax.text(\n        bar_rect.get_x() + bar_rect.get_width() / 2,\n        val,\n        f\"{val:.1f}s\",\n        ha=\"center\",\n        va=\"bottom\",\n        fontsize=9,\n    )\n\n  ax.set_xlabel(\"Text Type\")\n  ax.set_ylabel(\"Seconds per 1K Characters\")\n  ax.set_title(\"Processing Speed\")\n  ax.set_xticks(x)\n  ax.set_xticklabels([t.capitalize() for t in text_types])\n  ax.grid(True, alpha=0.3, axis=\"y\")\n\n\ndef _plot_summary_table(ax, results):\n  \"\"\"Create a summary of key findings.\"\"\"\n  ax.axis(\"off\")\n\n  if config.RESULTS_KEY not in results:\n    ax.text(0.5, 0.5, \"No data\", ha=\"center\", va=\"center\")\n    ax.set_title(\"Key Metrics\")\n    return\n\n  summary_lines = []\n  summary_lines.append(\"Key Metrics\")\n  summary_lines.append(\"-\" * 20)\n  summary_lines.append(\"\")\n\n  success_count = sum(\n      1 for r in results.get(config.RESULTS_KEY, []) if r.get(\"success\")\n  )\n  total_count = len(results.get(config.RESULTS_KEY, []))\n\n  if total_count > 0:\n    summary_lines.append(\"Tests Run:\")\n    summary_lines.append(f\"  {success_count} successful\")\n    summary_lines.append(f\"  {total_count - success_count} failed\")\n    summary_lines.append(\"\")\n\n  if success_count > 0:\n    avg_time = (\n        sum(\n            r.get(\"time_seconds\", 0)\n            for r in results.get(config.RESULTS_KEY, [])\n            if r.get(\"success\")\n        )\n        / success_count\n    )\n    summary_lines.append(f\"Avg Time: {avg_time:.1f}s\")\n\n  summary_text = \"\\n\".join(summary_lines)\n  ax.text(\n      0.5,\n      0.5,\n      summary_text,\n      ha=\"center\",\n      va=\"center\",\n      fontsize=10,\n      family=\"monospace\",\n  )\n\n  ax.set_title(\"Key Metrics\", fontweight=\"bold\", y=0.9)\n\n\ndef create_comparison_plots(json_files: list[Path], output_path: Path) 
-> None:\n  \"\"\"Create comparison plots from multiple benchmark JSON files.\n\n  Args:\n    json_files: List of paths to benchmark JSON files to compare.\n    output_path: Path where the comparison plot should be saved.\n  \"\"\"\n  if len(json_files) < 2:\n    print(\"Need at least 2 JSON files for comparison\")\n    return\n\n  all_results = []\n  for json_file in json_files:\n    try:\n      with open(json_file, \"r\") as f:\n        data = json.load(f)\n        data[\"filename\"] = json_file.stem\n        all_results.append(data)\n    except (IOError, OSError, json.JSONDecodeError) as e:\n      print(f\"Error loading {json_file}: {e}\")\n      continue\n\n  if len(all_results) < 2:\n    print(\"Could not load enough valid JSON files for comparison\")\n    return\n\n  plt.figure(figsize=(18, 12))\n\n  ax1 = plt.subplot(2, 3, (1, 2))\n  _plot_tokenization_comparison(ax1, all_results)\n\n  ax2 = plt.subplot(2, 3, 3)\n  _plot_entity_comparison(ax2, all_results)\n\n  ax3 = plt.subplot(2, 3, 4)\n  _plot_time_comparison(ax3, all_results)\n\n  ax4 = plt.subplot(2, 3, 5)\n  _plot_success_rate_comparison(ax4, all_results)\n\n  ax5 = plt.subplot(2, 3, 6)\n  _plot_timeline(ax5, all_results)\n\n  timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n  plt.suptitle(\n      f\"LangExtract Benchmark Comparison - {timestamp}\",\n      fontsize=14,\n      fontweight=\"bold\",\n  )\n  plt.tight_layout(rect=[0, 0.01, 1, 0.95])\n  plt.subplots_adjust(hspace=0.45, wspace=0.35, top=0.93)\n  plt.savefig(output_path, dpi=100, bbox_inches=\"tight\")\n  plt.close()\n  print(f\"Comparison plot saved to: {output_path}\")\n\n\ndef _plot_entity_comparison(ax, all_results):\n  \"\"\"Plot entity count comparison across runs.\"\"\"\n  runs = []\n  languages = [\"english\", \"french\", \"spanish\", \"japanese\"]\n  language_data = []\n\n  for result in all_results:\n    run_name = result[\"filename\"].replace(\"benchmark_\", \"\")[:10]\n    runs.append(run_name)\n\n    run_counts = {lang: 0 
for lang in languages}\n    if config.RESULTS_KEY in result:\n      for res in result[config.RESULTS_KEY]:\n        lang = res.get(\"text_type\", \"\")\n        if lang in languages and res.get(\"success\"):\n          run_counts[lang] = res.get(\"entity_count\", 0)\n\n    language_data.append(run_counts)\n\n  x = np.arange(len(runs))\n  width = 0.2\n\n  for i, lang in enumerate(languages):\n    counts = [data[lang] for data in language_data]\n    bars = ax.bar(x + i * width, counts, width, label=lang.capitalize())\n\n    for bar_rect, count in zip(bars, counts):\n      if count > 0:\n        ax.text(\n            bar_rect.get_x() + bar_rect.get_width() / 2,\n            bar_rect.get_height() + 0.5,\n            str(count),\n            ha=\"center\",\n            fontsize=7,\n        )\n\n  ax.set_xlabel(\"Run\")\n  ax.set_ylabel(\"Entity Count\")\n  title = \"Entities Extracted by Language\\n\"\n  subtitle = \"Number of unique character names found per language\"\n  ax.set_title(title, fontweight=\"bold\", fontsize=10)\n  ax.text(\n      0.5,\n      1.01,\n      subtitle,\n      transform=ax.transAxes,\n      ha=\"center\",\n      fontsize=7,\n      style=\"italic\",\n      color=\"#666666\",\n      va=\"bottom\",\n  )\n  ax.set_xticks(x + width * 1.5)\n  ax.set_xticklabels(runs, rotation=45, ha=\"right\")\n  ax.legend(loc=\"upper left\", fontsize=8)\n  ax.grid(True, alpha=0.3)\n  ax.set_ylim(0, ax.get_ylim()[1] * 1.1)\n\n\ndef _plot_time_comparison(ax, all_results):\n  \"\"\"Plot processing time comparison.\"\"\"\n  runs = []\n  avg_times = []\n\n  for result in all_results:\n    run_name = result[\"filename\"].replace(\"benchmark_\", \"\")[:10]\n    runs.append(run_name)\n\n    if config.RESULTS_KEY in result:\n      times = [\n          r.get(\"time_seconds\", 0)\n          for r in result[config.RESULTS_KEY]\n          if r.get(\"success\")\n      ]\n      avg_time = sum(times) / len(times) if times else 0\n      avg_times.append(avg_time)\n    else:\n      
avg_times.append(0)\n\n  x_pos = np.arange(len(runs))\n  bars = ax.bar(x_pos, avg_times, color=\"skyblue\", edgecolor=\"navy\", alpha=0.7)\n\n  ax.set_xlabel(\"Run\")\n  ax.set_ylabel(\"Average Time (seconds)\")\n  title = \"Average Processing Time\\n\"\n  subtitle = \"Mean extraction time across all language tests\"\n  ax.set_title(title, fontweight=\"bold\", fontsize=10)\n  ax.text(\n      0.5,\n      1.01,\n      subtitle,\n      transform=ax.transAxes,\n      ha=\"center\",\n      fontsize=7,\n      style=\"italic\",\n      color=\"#666666\",\n      va=\"bottom\",\n  )\n  ax.set_xticks(x_pos)\n  ax.set_xticklabels(runs, rotation=45, ha=\"right\")\n  ax.grid(True, alpha=0.3)\n\n  for bar_rect, time in zip(bars, avg_times):\n    if time > 0:\n      ax.text(\n          bar_rect.get_x() + bar_rect.get_width() / 2,\n          bar_rect.get_height() + 0.1,\n          f\"{time:.1f}s\",\n          ha=\"center\",\n          fontsize=8,\n      )\n\n  if max(avg_times) > 0:\n    ax.set_ylim(0, max(avg_times) * 1.2)\n\n\ndef _plot_tokenization_comparison(ax, all_results):\n  \"\"\"Plot tokenization throughput comparison as line graphs.\"\"\"\n\n  for i, result in enumerate(all_results):\n    run_name = result[\"filename\"].replace(\"benchmark_\", \"\")[:10]\n\n    if config.TOKENIZATION_KEY in result and result[config.TOKENIZATION_KEY]:\n      sizes = [r[\"words\"] for r in result[config.TOKENIZATION_KEY]]\n      speeds = [r[\"tokens_per_sec\"] for r in result[config.TOKENIZATION_KEY]]\n\n      ax.semilogx(\n          sizes,\n          speeds,\n          \"o-\",\n          linewidth=2,\n          markersize=6,\n          label=run_name,\n          alpha=0.8,\n      )\n\n      for x, y in zip(sizes, speeds):\n        if i == 0:  # Only label first run to avoid overlap\n          label = f\"{y/1000:.0f}K\"\n          ax.annotate(\n              label,\n              xy=(x, y),\n              xytext=(0, 5),\n              textcoords=\"offset points\",\n              
ha=\"center\",\n              fontsize=7,\n          )\n\n  ax.set_xlabel(\"Number of Words (log scale)\")\n  ax.set_ylabel(\"Tokens per Second\")\n  title = \"Tokenization Throughput Comparison\\n\"\n  subtitle = \"Speed of text tokenization at different document sizes\"\n  ax.set_title(title, fontweight=\"bold\", fontsize=10)\n  ax.text(\n      0.5,\n      1.01,\n      subtitle,\n      transform=ax.transAxes,\n      ha=\"center\",\n      fontsize=7,\n      style=\"italic\",\n      color=\"#666666\",\n      va=\"bottom\",\n  )\n  ax.grid(True, alpha=0.3)\n  ax.legend(loc=\"best\", fontsize=8)\n\n  ax.set_xticks([100, 1000, 10000])\n  ax.set_xticklabels([\"10²\", \"10³\", \"10⁴\"])\n\n  _, ymax = ax.get_ylim()\n  ax.set_ylim(0, ymax * 1.1)\n\n\ndef _plot_success_rate_comparison(ax, all_results):\n  \"\"\"Plot success rate comparison.\"\"\"\n  runs = []\n  success_rates = []\n\n  for result in all_results:\n    run_name = result[\"filename\"].replace(\"benchmark_\", \"\")[:10]\n    runs.append(run_name)\n\n    if config.RESULTS_KEY in result:\n      total = len(result[config.RESULTS_KEY])\n      success = sum(1 for r in result[config.RESULTS_KEY] if r.get(\"success\"))\n      rate = (success / total * 100) if total > 0 else 0\n      success_rates.append(rate)\n    else:\n      success_rates.append(0)\n\n  x_pos = np.arange(len(runs))\n  colors = [\n      \"green\" if rate == 100 else \"orange\" if rate >= 75 else \"red\"\n      for rate in success_rates\n  ]\n  bars = ax.bar(x_pos, success_rates, color=colors, alpha=0.7)\n\n  ax.set_xlabel(\"Run\")\n  ax.set_ylabel(\"Success Rate (%)\")\n  title = \"Extraction Success Rate\\n\"\n  subtitle = \"Percentage of language tests completed without errors\"\n  ax.set_title(title, fontweight=\"bold\", fontsize=10)\n  ax.text(\n      0.5,\n      1.01,\n      subtitle,\n      transform=ax.transAxes,\n      ha=\"center\",\n      fontsize=7,\n      style=\"italic\",\n      color=\"#666666\",\n      va=\"bottom\",\n  )\n  
ax.set_ylim(0, 105)\n  ax.set_xticks(x_pos)\n  ax.set_xticklabels(runs, rotation=45, ha=\"right\")\n  ax.axhline(y=100, color=\"green\", linestyle=\"--\", alpha=0.3)\n  ax.grid(True, alpha=0.3)\n\n  for bar_rect, rate in zip(bars, success_rates):\n    ax.text(\n        bar_rect.get_x() + bar_rect.get_width() / 2,\n        bar_rect.get_height() + 1,\n        f\"{rate:.0f}%\",\n        ha=\"center\",\n        fontsize=8,\n    )\n\n\ndef _plot_token_rate_by_language(ax, all_results):\n  \"\"\"Plot tokenization rates by language.\"\"\"\n  languages = [\"english\", \"french\", \"spanish\", \"japanese\"]\n  latest_result = all_results[-1]\n\n  token_rates = []\n  colors = []\n\n  if config.RESULTS_KEY in latest_result:\n    for lang in languages:\n      lang_results = [\n          r\n          for r in latest_result[config.RESULTS_KEY]\n          if r.get(\"text_type\") == lang and r.get(\"success\")\n      ]\n      if lang_results and config.TOKENIZATION_KEY in lang_results[0]:\n        rate = lang_results[0][config.TOKENIZATION_KEY].get(\n            \"tokens_per_char\", 0\n        )\n        token_rates.append(rate)\n        colors.append(\n            \"red\" if rate < 0.1 else \"orange\" if rate < 0.2 else \"green\"\n        )\n      else:\n        token_rates.append(0)\n        colors.append(\"gray\")\n\n  ax.bar(languages, token_rates, color=colors, alpha=0.7)\n  ax.set_xlabel(\"Language\")\n  ax.set_ylabel(\"Tokens per Character\")\n  ax.set_title(\"Tokenization Density (Latest Run)\")\n  ax.set_xticks(range(len(languages)))\n  ax.set_xticklabels([l.capitalize() for l in languages])\n  ax.grid(True, alpha=0.3)\n\n  for i, (lang, rate) in enumerate(zip(languages, token_rates)):\n    ax.text(i, rate + 0.01, f\"{rate:.3f}\", ha=\"center\", fontsize=8)\n\n\ndef _plot_timeline(ax, all_results):\n  \"\"\"Plot metrics over time if timestamps available.\"\"\"\n  timestamps = []\n  entity_totals = []\n\n  for result in all_results:\n    filename = result[\"filename\"]\n   
 if \"timestamp\" in result:\n      timestamps.append(result[\"timestamp\"])\n    else:\n      # Try to parse from filename (format: benchmark_YYYYMMDD_HHMMSS)\n      parts = filename.split(\"_\")\n      if len(parts) >= 3:\n        timestamps.append(f\"{parts[-2]}_{parts[-1]}\")\n      else:\n        timestamps.append(filename[:10])\n\n    if config.RESULTS_KEY in result:\n      total_entities = sum(\n          r.get(\"entity_count\", 0)\n          for r in result[config.RESULTS_KEY]\n          if r.get(\"success\")\n      )\n      entity_totals.append(total_entities)\n    else:\n      entity_totals.append(0)\n\n  x_pos = np.arange(len(timestamps))\n  ax.plot(x_pos, entity_totals, \"o-\", color=\"blue\", linewidth=2, markersize=8)\n  ax.set_xlabel(\"Run\")\n  ax.set_ylabel(\"Total Entities\")\n  title = \"Total Entities Over Time\\n\"\n  subtitle = \"Sum of all entities extracted across all languages\"\n  ax.set_title(title, fontweight=\"bold\", fontsize=10)\n  ax.text(\n      0.5,\n      1.01,\n      subtitle,\n      transform=ax.transAxes,\n      ha=\"center\",\n      fontsize=7,\n      style=\"italic\",\n      color=\"#666666\",\n      va=\"bottom\",\n  )\n  ax.set_xticks(x_pos)\n  ax.set_xticklabels([t[-6:] for t in timestamps], rotation=45, ha=\"right\")\n  ax.grid(True, alpha=0.3)\n\n  for i, total in enumerate(entity_totals):\n    ax.text(i, total + 1, str(total), ha=\"center\", fontsize=8)\n\n  if entity_totals:\n    min_val = min(0, min(entity_totals) - 5)\n    max_val = max(entity_totals) + 5\n    ax.set_ylim(min_val, max_val)\n"
  },
  {
    "path": "benchmarks/utils.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Helper functions for benchmark text retrieval and analysis.\"\"\"\n\nimport subprocess\nfrom typing import Any\nimport urllib.error\nimport urllib.request\n\nfrom benchmarks import config\nfrom langextract.core import tokenizer\n\n\ndef download_text(url: str) -> str:\n  \"\"\"Download text from URL.\n\n  Args:\n    url: URL to download from.\n\n  Returns:\n    Downloaded text content, decoded as UTF-8.\n\n  Raises:\n    RuntimeError: If the URL cannot be fetched.\n  \"\"\"\n  try:\n    with urllib.request.urlopen(url) as response:\n      return response.read().decode(\"utf-8\")\n  except urllib.error.URLError as e:  # HTTPError is a subclass of URLError.\n    raise RuntimeError(f\"Could not download from {url}: {e}\") from e\n\n\ndef extract_text_content(full_text: str) -> str:\n  \"\"\"Extract main content from Gutenberg text.\n\n  Strips the Project Gutenberg header and footer using the \"*** START OF\" and\n  \"*** END OF\" markers. If either marker is missing, falls back to the middle\n  60% of the text.\n\n  Args:\n    full_text: Full text including Gutenberg headers.\n\n  Returns:\n    Extracted main content.\n  \"\"\"\n  start_marker = \"*** START OF\"\n  end_marker = \"*** END OF\"\n\n  start_idx = full_text.upper().find(start_marker)\n  end_idx = full_text.upper().find(end_marker)\n\n  if start_idx != -1 and end_idx != -1:\n    content_start = full_text.find(\"\\n\", start_idx) + 1\n\n    # Handle markers with trailing asterisks (e.g., \"*** START ... 
***\").\n    line_end = full_text.find(\"***\", start_idx + 3)\n    if (\n        line_end != -1 and line_end < content_start + 100\n    ):  # Ensure marker is on same line.\n      content_start = full_text.find(\"\\n\", line_end) + 1\n\n    return full_text[content_start:end_idx].strip()\n\n  text_length = len(full_text)\n  start = int(text_length * 0.2)\n  end = int(text_length * 0.8)\n  return full_text[start:end].strip()\n\n\ndef get_text_from_gutenberg(text_type: config.TextTypes) -> str:\n  \"\"\"Get text from Project Gutenberg for given language.\n\n  Args:\n    text_type: Type of text (language).\n\n  Returns:\n    Text sample from Gutenberg.\n  \"\"\"\n  url = config.GUTENBERG_TEXTS[text_type]\n  full_text = download_text(url)\n  content = extract_text_content(full_text)\n\n  mid_point = len(content) // 2\n  start_chunk = max(0, mid_point - 2500)\n  return content[start_chunk : start_chunk + 5000].strip()\n\n\ndef get_optimal_text_size(text: str, model_id: str) -> str:\n  \"\"\"Get optimal text size for model.\n\n  Args:\n    text: Original text.\n    model_id: Model identifier.\n\n  Returns:\n    Text truncated to optimal size.\n  \"\"\"\n  if (\n      \":\" in model_id\n      or \"gemma\" in model_id.lower()\n      or \"llama\" in model_id.lower()\n  ):\n    max_chars = 500  # Smaller context for local models.\n  else:\n    max_chars = 5000\n\n  return text[:max_chars]\n\n\ndef get_extraction_example(text_type: config.TextTypes) -> dict[str, str]:  # pylint: disable=unused-argument\n  \"\"\"Get extraction example configuration.\n\n  Args:\n    text_type: Type of text.\n\n  Returns:\n    Dictionary with prompt configuration.\n  \"\"\"\n  return {\n      \"prompt\": \"Extract all character names from this text\",\n  }\n\n\ndef get_git_info() -> dict[str, str]:\n  \"\"\"Get current git branch and commit info.\n\n  Returns:\n    Dictionary with branch and commit info.\n  \"\"\"\n  try:\n    branch = subprocess.run(\n        [\"git\", \"branch\", 
\"--show-current\"],\n        capture_output=True,\n        text=True,\n        check=True,\n    ).stdout.strip()\n\n    commit = subprocess.run(\n        [\"git\", \"rev-parse\", \"--short\", \"HEAD\"],\n        capture_output=True,\n        text=True,\n        check=True,\n    ).stdout.strip()\n\n    status = subprocess.run(\n        [\"git\", \"status\", \"--porcelain\"],\n        capture_output=True,\n        text=True,\n        check=True,\n    ).stdout.strip()\n\n    if status:\n      commit += \"-dirty\"\n\n    return {\"branch\": branch, \"commit\": commit}\n  except (subprocess.CalledProcessError, FileNotFoundError):\n    # Not inside a git checkout, or the git executable is unavailable.\n    return {\"branch\": \"unknown\", \"commit\": \"unknown\"}\n\n\ndef analyze_tokenization(\n    text: str, tokenizer_inst: tokenizer.Tokenizer | None = None\n) -> dict[str, Any]:\n  \"\"\"Analyze tokenization of given text.\n\n  Args:\n    text: Text to analyze.\n    tokenizer_inst: Tokenizer instance to use (default: RegexTokenizer).\n\n  Returns:\n    Dictionary with tokenization metrics.\n  \"\"\"\n  if tokenizer_inst is not None:\n    tokenized = tokenizer_inst.tokenize(text)\n  else:\n    tokenized = tokenizer.tokenize(text)\n  num_tokens = len(tokenized.tokens)\n  num_chars = len(text)\n  tokens_per_char = num_tokens / num_chars if num_chars > 0 else 0\n\n  return {\n      \"num_tokens\": num_tokens,\n      \"num_chars\": num_chars,\n      \"tokens_per_char\": tokens_per_char,\n  }\n\n\ndef format_tokenization_summary(analysis: dict[str, Any]) -> str:\n  \"\"\"Format tokenization analysis as summary string.\n\n  Args:\n    analysis: Tokenization analysis dict.\n\n  Returns:\n    Formatted summary string.\n  \"\"\"\n  return (\n      f\"{analysis['num_tokens']} tokens, \"\n      f\"{analysis['tokens_per_char']:.3f} tok/char\"\n  )\n"
  },
  {
    "path": "docs/examples/batch_api_example.md",
    "content": "# Vertex AI Batch Processing Guide\n\nThe Vertex AI Batch API offers significant cost savings (~50%) for large, non-time-critical workloads. `langextract` seamlessly integrates this with automatic routing, caching, and fault tolerance.\n\n**[Vertex AI Batch Prediction Documentation →](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini)**\n**[Quotas & Limits →](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/quotas#batch-prediction-quotas)**\n\n## Real-World Example: Processing Shakespeare\n\nThis example demonstrates how to process a large text (the first ~20 pages of *Romeo and Juliet*) using the Batch API. We use a small chunk size (`max_char_buffer=500`) to generate enough chunks to trigger batch processing.\n\n```python\nimport requests\nimport textwrap\nimport langextract as lx\nimport logging\n\n# Configure logging to see progress (both in console and file)\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s',\n    handlers=[\n        logging.FileHandler(\"batch_process.log\"),\n        logging.StreamHandler()\n    ]\n)\n\n# 1. Download Text (Shakespeare's Romeo and Juliet)\nurl = \"https://www.gutenberg.org/files/1513/1513-0.txt\"\nprint(f\"Downloading {url}...\")\ntext = requests.get(url).text\n\n# Process first ~20 pages (approx. 60k characters).\ntext_subset = text[:60000]\nprint(f\"Processing first {len(text_subset)} characters...\")\n\n# 2. Define Prompt & Examples\nprompt = textwrap.dedent(\"\"\"\\\n    Extract characters and emotions from the text.\n    Use exact text from the input for extraction_text.\"\"\")\n\nexamples = [\n    lx.data.ExampleData(\n        text=\"ROMEO. But soft! 
What light through yonder window breaks?\",\n        extractions=[\n            lx.data.Extraction(extraction_class=\"character\", extraction_text=\"ROMEO\"),\n            lx.data.Extraction(extraction_class=\"emotion\", extraction_text=\"But soft!\"),\n        ]\n    )\n]\n\n# 3. Configure Batch Settings\nbatch_config = {\n    \"enabled\": True,\n    \"threshold\": 10,\n    \"poll_interval\": 30,\n    \"timeout\": 3600,\n    # Cache results in GCS. To force a re-run, change the prompt (e.g., append a timestamp).\n    \"enable_caching\": True,\n    # Retention period for objects in the GCS bucket, in days. Use None to keep data indefinitely.\n    \"retention_days\": 30,\n}\n\n# 4. Run Extraction\n# langextract will automatically chunk the text and submit a batch job.\nresults = lx.extract(\n    text_or_documents=text_subset,\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"gemini-2.5-flash\",\n    max_char_buffer=500,\n    batch_length=1000,\n    language_model_params={\n        \"vertexai\": True,\n        \"project\": \"your-gcp-project\",  # TODO: Replace with your Project ID.\n        \"location\": \"us-central1\",\n        \"batch\": batch_config\n    }\n)\n```\n\n## GCS File Structure\n\nThe library automatically creates and manages a GCS bucket for you, named:\n`langextract-{project}-{location}-batch`\n\nInside this bucket, data is organized as follows:\n\n- **Input**: `batch-input/{job_name}.jsonl`\n- **Output**: `batch-input/{job_name}/dest/prediction-model-{timestamp}/predictions.jsonl`\n- **Cache**: `cache/{hash}.json` (individual cached results)\n\n## Cost Optimization & Caching\n\nLangExtract's batch processing is designed to minimize costs:\n\n1.  **Cost Efficiency**: Vertex AI Batch predictions are typically ~50% cheaper than online predictions.\n2.  
**Smart Caching**:\n    -   Results are cached in your GCS bucket (`cache/` directory).\n    -   **Instant Retrieval**: Re-running identical prompts fetches results directly from storage, bypassing model inference.\n    -   **Reduced Inference**: You avoid paying for redundant model calls on previously processed data.\n    -   **Lifecycle Management**: Use `retention_days` (e.g., 30) to automatically clean up old data and manage storage usage.\n\n## Analyze Results\n\nOnce the job completes, inspect the returned extractions:\n\n```python\nprint(f\"Extracted {len(results.extractions)} entities.\")\nprint(\"First 5 extractions:\")\nfor extraction in results.extractions[:5]:\n    print(f\"- {extraction.extraction_class}: {extraction.extraction_text}\")\n```\n\n## Sample Output\n\n```text\nExtracted 767 entities.\nFirst 5 extractions:\n- character: ESCALUS\n- character: MERCUTIO\n- character: PARIS\n- character: Page to Paris\n- character: MONTAGUE\n```\n\n> **Note on `batch_length`**: The `batch_length` parameter controls how many chunks are submitted in a single batch job. For optimal performance with the Batch API, set this to a high value (e.g., `1000`) to process all chunks in a single job rather than multiple sequential jobs.\n\n## Key Features\n\n### 1. Automatic Routing\n`langextract` automatically switches between real-time and batch APIs based on your `threshold`.\n- **< Threshold**: Uses real-time API for immediate results.\n- **>= Threshold**: Uses Batch API for cost savings.\n\n### 2. Fault Tolerance & Caching\nBuilt-in GCS caching (`enable_caching=True`) allows you to resume interrupted jobs without re-processing completed items, saving time and cost.\n\n### 3. Automated Storage\n`langextract` handles all GCS operations automatically using a dedicated bucket (`gs://langextract-{project}-{location}-batch`). 
Note that input/output files are retained for debugging.\n\n## Tracking Job Status\n\nTo monitor progress, you can watch the log file from a separate terminal:\n\n```bash\ntail -f batch_process.log\n```\n\nWhile a batch job runs, `langextract` logs its state along with a direct link to the job in the Google Cloud Console:\n\n```text\nINFO - Batch job created successfully: projects/123456789/locations/us-central1/batchPredictionJobs/987654321\nINFO - Job State: JobState.JOB_STATE_PENDING\nINFO - Job Console URL: https://console.cloud.google.com/vertex-ai/jobs/batch-predictions/987654321?project=123456789\nINFO - Batch job is running... (State: JOB_STATE_PENDING)\nINFO - Batch job is running... (State: JOB_STATE_RUNNING)\n```\n\n- **Completion**: Once the job succeeds, `langextract` automatically downloads, parses, and aligns the results.\n"
  },
  {
    "path": "docs/examples/japanese_extraction.md",
    "content": "# Japanese Information Extraction\n\nThis example demonstrates how to use LangExtract to extract structured information from Japanese text.\n\n> **Note:** For non-spaced languages like Japanese, use `UnicodeTokenizer` to ensure correct character-based segmentation and alignment.\n\n## Full Pipeline Example\n\n```python\nimport langextract as lx\nfrom langextract.core import tokenizer\n\n# Japanese text with entities (Person, Location, Organization)\n# \"Mr. Tanaka from Tokyo works at Google.\"\ninput_text = \"東京出身の田中さんはGoogleで働いています。\"\n\n# Define extraction prompt\nprompt_description = \"Extract named entities including Person, Location, and Organization.\"\n\n# Define example data (few-shot examples help the model understand the task)\nexamples = [\n    lx.data.ExampleData(\n        text=\"大阪の山田さんはソニーに入社しました。\",  # Mr. Yamada from Osaka joined Sony.\n        extractions=[\n            lx.data.Extraction(extraction_class=\"Location\", extraction_text=\"大阪\"),\n            lx.data.Extraction(extraction_class=\"Person\", extraction_text=\"山田\"),\n            lx.data.Extraction(extraction_class=\"Organization\", extraction_text=\"ソニー\"),\n        ]\n    )\n]\n\n# 1. Initialize the UnicodeTokenizer\n# Essential for Japanese to ensure correct grapheme segmentation.\nunicode_tokenizer = tokenizer.UnicodeTokenizer()\n\n# 2. Run Extraction with the Custom Tokenizer\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt_description,\n    examples=examples,\n    model_id=\"gemini-2.5-flash\",\n    tokenizer=unicode_tokenizer,   # <--- Pass the tokenizer here\n    api_key=\"your-api-key-here\"    # Optional if env var is set\n)\n\n# 3. 
Display Results\nprint(f\"Input: {input_text}\\n\")\nprint(\"Extracted Entities:\")\nfor entity in result.extractions:\n    position_info = \"\"\n    if entity.char_interval:\n        start, end = entity.char_interval.start_pos, entity.char_interval.end_pos\n        position_info = f\" (pos: {start}-{end})\"\n    \n    print(f\"• {entity.extraction_class}: {entity.extraction_text}{position_info}\")\n\n# Expected Output:\n# Input: 東京出身の田中さんはGoogleで働いています。\n#\n# Extracted Entities:\n# • Location: 東京 (pos: 0-2)\n# • Person: 田中 (pos: 5-7)\n# • Organization: Google (pos: 10-16)\n```\n"
  },
  {
    "path": "docs/examples/longer_text_example.md",
    "content": "# *Romeo and Juliet* Full Text Extraction\n\nLangExtract can process entire documents directly from URLs, handling large texts through chunking, parallel processing, and multiple extraction passes that improve recall. This example demonstrates extraction from the complete text of *Romeo and Juliet* from Project Gutenberg.\n\n## Example code\n\nThe following code uses a comprehensive prompt and examples optimized for large, complex literary texts. For large, complex inputs, more detailed examples are recommended to make extraction more robust.\n\n> **Warning:** Running this example processes a large document (~44,000 tokens) and will incur costs. For large-scale use, a Tier 2 Gemini quota is recommended to avoid rate-limit issues ([details](https://ai.google.dev/gemini-api/docs/rate-limits#tier-2)). Please review the [Gemini API pricing](https://ai.google.dev/gemini-api/docs/pricing) before proceeding.\n\n```python\nimport langextract as lx\nimport textwrap\nfrom collections import Counter\n\n# Define comprehensive prompt and examples for complex literary text\nprompt = textwrap.dedent(\"\"\"\\\n    Extract characters, emotions, and relationships from the given text.\n\n    Provide meaningful attributes for every entity to add context and depth.\n\n    Important: Use exact text from the input for extraction_text. Do not paraphrase.\n    Extract entities in order of appearance with no overlapping text spans.\n\n    Note: In play scripts, speaker names appear in ALL-CAPS followed by a period.\"\"\")\n\nexamples = [\n    lx.data.ExampleData(\n        text=textwrap.dedent(\"\"\"\\\n            ROMEO. But soft! What light through yonder window breaks?\n            It is the east, and Juliet is the sun.\n            JULIET. O Romeo, Romeo! 
Wherefore art thou Romeo?\"\"\"),\n        extractions=[\n            lx.data.Extraction(\n                extraction_class=\"character\",\n                extraction_text=\"ROMEO\",\n                attributes={\"emotional_state\": \"wonder\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"emotion\",\n                extraction_text=\"But soft!\",\n                attributes={\"feeling\": \"gentle awe\", \"character\": \"Romeo\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"relationship\",\n                extraction_text=\"Juliet is the sun\",\n                attributes={\"type\": \"metaphor\", \"character_1\": \"Romeo\", \"character_2\": \"Juliet\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"character\",\n                extraction_text=\"JULIET\",\n                attributes={\"emotional_state\": \"yearning\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"emotion\",\n                extraction_text=\"Wherefore art thou Romeo?\",\n                attributes={\"feeling\": \"longing question\", \"character\": \"Juliet\"}\n            ),\n        ]\n    )\n]\n\n# Process Romeo & Juliet directly from Project Gutenberg\nprint(\"Downloading and processing Romeo and Juliet from Project Gutenberg...\")\n\nresult = lx.extract(\n    text_or_documents=\"https://www.gutenberg.org/files/1513/1513-0.txt\",\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"gemini-2.5-flash\",\n    extraction_passes=3,      # Multiple passes for improved recall\n    max_workers=20,           # Parallel processing for speed\n    max_char_buffer=1000      # Smaller contexts for better accuracy\n)\n\nprint(f\"Extracted {len(result.extractions)} entities from {len(result.text):,} characters\")\n\n# Save and visualize the results\nlx.io.save_annotated_documents([result], output_name=\"romeo_juliet_extractions.jsonl\", 
output_dir=\".\")\n\n# Generate the interactive visualization\nhtml_content = lx.visualize(\"romeo_juliet_extractions.jsonl\")\nwith open(\"romeo_juliet_visualization.html\", \"w\", encoding=\"utf-8\") as f:\n    if hasattr(html_content, 'data'):\n        f.write(html_content.data)  # For Jupyter/Colab\n    else:\n        f.write(html_content)\n\nprint(\"Interactive visualization saved to romeo_juliet_visualization.html\")\n```\n\nThis creates an interactive HTML visualization for exploring the extracted entities:\n\n![Romeo and Juliet Full Visualization](../_static/romeo_juliet_full.gif)\n\n```python\n# Analyze character mentions\ncharacters = {}\nfor e in result.extractions:\n    if e.extraction_class == \"character\":\n        char_name = e.extraction_text\n        if char_name not in characters:\n            characters[char_name] = {\"count\": 0, \"attributes\": set()}\n        characters[char_name][\"count\"] += 1\n        if e.attributes:\n            for attr_key, attr_val in e.attributes.items():\n                characters[char_name][\"attributes\"].add(f\"{attr_key}: {attr_val}\")\n\n# Print character summary\nprint(f\"\\nCHARACTER SUMMARY ({len(characters)} unique characters)\")\nprint(\"=\" * 60)\n\nsorted_chars = sorted(characters.items(), key=lambda x: x[1][\"count\"], reverse=True)\nfor char_name, char_data in sorted_chars[:10]:  # Top 10 characters\n    attrs_preview = list(char_data[\"attributes\"])[:3]\n    attrs_str = f\" ({', '.join(attrs_preview)})\" if attrs_preview else \"\"\n    print(f\"{char_name}: {char_data['count']} mentions{attrs_str}\")\n\n# Entity type breakdown\nentity_counts = Counter(e.extraction_class for e in result.extractions)\nprint(\"\\nENTITY TYPE BREAKDOWN\")\nprint(\"=\" * 60)\nfor entity_type, count in entity_counts.most_common():\n    percentage = (count / len(result.extractions)) * 100\n    print(f\"{entity_type}: {count} ({percentage:.1f}%)\")\n```\n\n## Sample output\n\n```\nDownloading and processing Romeo and Juliet from Project 
Gutenberg...\nDownloaded 147,843 characters (25,976 words) from 1513-0.txt\nExtracted 4,088 entities from 147,843 characters\nInteractive visualization saved to romeo_juliet_visualization.html\n\nCHARACTER SUMMARY (153 unique characters)\n============================================================\nROMEO: 287 mentions (emotional_state: excitement, emotional_state: eager to please)\nJULIET: 204 mentions (emotional_state: fond, emotional_state: resilient)\nNURSE: 168 mentions (emotional_state: reporting, emotional_state: teasing and evasive)\nMERCUTIO: 107 mentions (emotional_state: approving, emotional_state: responsive)\nBENVOLIO: 82 mentions (emotional_state: cautious, emotional_state: teasing)\n\nENTITY TYPE BREAKDOWN\n============================================================\ncharacter: 1,685 (41.2%)\nemotion: 1,524 (37.3%)\nrelationship: 879 (21.5%)\n```\n\n## Key benefits for long documents\n\n### Sequential extraction passes\n\nMultiple extraction passes improve recall by running independent extractions and merging their non-overlapping results. Each pass uses identical parameters and processing; the passes are independent runs of the same extraction task. The `extraction_passes` parameter controls the number of passes (e.g., `extraction_passes=3`).\n\n**How it works**: Each pass processes the full text independently using the same prompt and examples. The results are then merged with a \"first-pass wins\" strategy: where entities from different passes overlap, the entity from the earlier pass is kept, and non-overlapping entities from later passes are added. Because language model generation is stochastic, this captures entities that any single run might miss.\n\n### Portable and interoperable data with JSONL\n\nLangExtract uses JSONL, a human-readable format well suited to language model data. Each line is a self-contained JSON object, making outputs easy to parse, share, and integrate with other tools. 
You can save results with `lx.io.save_annotated_documents` and reload them for later analysis, so your data stays both portable and persistent.\n\n### Optimal long context management\n\nSingle-inference approaches can be powerful, but their accuracy may degrade when relevant information is spread across distant parts of a long context. LangExtract splits text into chunks that respect delimiters (such as paragraph breaks), keeping the context within each chunk intact and well-formed for the model. You can tune the chunk size (`max_char_buffer`) and the number of parallel workers (`max_workers`) to maintain extraction quality across large documents, and additional extraction passes (`extraction_passes`) further improve recall.\n\n### Enhanced accuracy through chunking\n\nChunked processing can improve extraction quality compared with a single inference pass over a large document, because each chunk gives the model a smaller, more manageable context. This helps the model focus on the most relevant information and reduces interference from distant context. Because chunks are processed in parallel, overall latency can remain similar, while extraction quality can be substantially higher, with better entity coverage and more accurate attribute assignment across the entire document.¹\n\n### Interactive visualization at scale\n\nExplore hundreds or thousands of entities through interactive HTML visualizations generated directly from JSONL output files. The visualizations are designed to handle large result sets, with navigation and detailed entity inspection for analyzing complex documents.\n\n### Schema-guided knowledge extraction\n\nLangExtract combines precise text positioning with world knowledge enrichment, enabling extraction of information not explicitly stated in the text (like character identities and traits). 
Under the hood, the library uses [Controlled Generation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/control-generated-output) with supported models to ensure extracted data adheres to your specified schema while maintaining robust extractions across large inputs.\n\n---\n\n¹ Models like Gemini 1.5 Pro show strong performance on many benchmarks, but [needle-in-a-haystack tests](https://cloud.google.com/blog/products/ai-machine-learning/the-needle-in-the-haystack-test-and-how-gemini-pro-solves-it) across million-token contexts indicate that performance can vary in multi-fact retrieval scenarios. LangExtract's approach of processing smaller context windows aims to maintain consistent quality across entire documents by avoiding the complexity and potential degradation of massive single-context processing.\n"
  },
  {
    "path": "docs/examples/medication_examples.md",
    "content": "# Medication Extraction Examples\n\nLangExtract excels at extracting structured medical information from clinical text, making it particularly useful for healthcare applications. The methodology originated from research in medical information extraction, where early versions of the techniques were demonstrated to accelerate annotation tasks significantly.\n\n> **Disclaimer:** This demonstration is only for illustrative purposes of LangExtract's baseline capability. It does not represent a finished or approved product, is not intended to diagnose or suggest treatment of any disease or condition, and should not be used for medical advice.\n\n---\n\n**Medical Information Extraction Research:**\nThe concepts and methods underlying LangExtract were first demonstrated in:\n\nGoel, A., Lehman, E., Gulati, A., Chen, R., Nori, H., Hager, G. D., & Durr, N. J. (2023).\n\"LLMs Accelerate Annotation for Medical Information Extraction.\"\n*Machine Learning for Health (ML4H), PMLR, 2023*.\n[arXiv:2312.02296](https://arxiv.org/abs/2312.02296)\n\n---\n\n## Basic Named Entity Recognition (NER)\n\nIn this basic medical example, LangExtract extracts structured medication information:\n\n```python\nimport langextract as lx\n\n# Text with a medication mention\ninput_text = \"Patient took 400 mg PO Ibuprofen q4h for two days.\"\n\n# Define extraction prompt\nprompt_description = \"Extract medication information including medication name, dosage, route, frequency, and duration in the order they appear in the text.\"\n\n# Define example data with entities in order of appearance\nexamples = [\n    lx.data.ExampleData(\n        text=\"Patient was given 250 mg IV Cefazolin TID for one week.\",\n        extractions=[\n            lx.data.Extraction(extraction_class=\"dosage\", extraction_text=\"250 mg\"),\n            lx.data.Extraction(extraction_class=\"route\", extraction_text=\"IV\"),\n            lx.data.Extraction(extraction_class=\"medication\", 
extraction_text=\"Cefazolin\"),\n            lx.data.Extraction(extraction_class=\"frequency\", extraction_text=\"TID\"),  # TID = three times a day\n            lx.data.Extraction(extraction_class=\"duration\", extraction_text=\"for one week\")\n        ]\n    )\n]\n\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt_description,\n    examples=examples,\n    model_id=\"gemini-2.5-pro\",\n    api_key=\"your-api-key-here\"  # Optional if LANGEXTRACT_API_KEY environment variable is set\n)\n\n# Display entities with positions\nprint(f\"Input: {input_text}\\n\")\nprint(\"Extracted entities:\")\nfor entity in result.extractions:\n    position_info = \"\"\n    if entity.char_interval:\n        start, end = entity.char_interval.start_pos, entity.char_interval.end_pos\n        position_info = f\" (pos: {start}-{end})\"\n    print(f\"• {entity.extraction_class.capitalize()}: {entity.extraction_text}{position_info}\")\n\n# Save and visualize the results\nlx.io.save_annotated_documents([result], output_name=\"medical_ner_extraction.jsonl\", output_dir=\".\")\n\n# Generate the interactive visualization\nhtml_content = lx.visualize(\"medical_ner_extraction.jsonl\")\nwith open(\"medical_ner_visualization.html\", \"w\") as f:\n    if hasattr(html_content, 'data'):\n        f.write(html_content.data)  # For Jupyter/Colab\n    else:\n        f.write(html_content)\n\nprint(\"Interactive visualization saved to medical_ner_visualization.html\")\n```\n\n![Medical NER Visualization](../_static/medication_entity.gif)\n\nThis will produce an output similar to:\n\n```\nInput: Patient took 400 mg PO Ibuprofen q4h for two days.\n\nExtracted entities:\n• Dosage: 400 mg (pos: 13-19)\n• Route: PO (pos: 20-22)\n• Medication: Ibuprofen (pos: 23-32)\n• Frequency: q4h (pos: 33-36)\n• Duration: for two days (pos: 37-49)\nInteractive visualization saved to medical_ner_visualization.html\n```\n\nThe interactive HTML visualization allows you to explore the extracted 
entities visually, with each entity type color-coded and clickable for detailed inspection.\n\n## Relationship Extraction (RE)\n\nFor more complex extractions that involve relationships between entities, LangExtract can also extract structured relationships. This example shows how to extract medications and their associated attributes:\n\n```python\nimport langextract as lx\n\n# Text with interleaved medication mentions\ninput_text = \"\"\"\nThe patient was prescribed Lisinopril and Metformin last month.\nHe takes the Lisinopril 10mg daily for hypertension, but often misses\nhis Metformin 500mg dose which should be taken twice daily for diabetes.\n\"\"\"\n\n# Define extraction prompt\nprompt_description = \"\"\"\nExtract medications with their details, using attributes to group related information:\n\n1. Extract entities in the order they appear in the text\n2. Each entity must have a 'medication_group' attribute linking it to its medication\n3. All details about a medication should share the same medication_group value\n\"\"\"\n\n# Define example data with medication groups\nexamples = [\n    lx.data.ExampleData(\n        text=\"Patient takes Aspirin 100mg daily for heart health and Simvastatin 20mg at bedtime.\",\n        extractions=[\n            # First medication group\n            lx.data.Extraction(\n                extraction_class=\"medication\",\n                extraction_text=\"Aspirin\",\n                attributes={\"medication_group\": \"Aspirin\"}  # Group identifier\n            ),\n            lx.data.Extraction(\n                extraction_class=\"dosage\",\n                extraction_text=\"100mg\",\n                attributes={\"medication_group\": \"Aspirin\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"frequency\",\n                extraction_text=\"daily\",\n                attributes={\"medication_group\": \"Aspirin\"}\n            ),\n            lx.data.Extraction(\n                
extraction_class=\"condition\",\n                extraction_text=\"heart health\",\n                attributes={\"medication_group\": \"Aspirin\"}\n            ),\n\n            # Second medication group\n            lx.data.Extraction(\n                extraction_class=\"medication\",\n                extraction_text=\"Simvastatin\",\n                attributes={\"medication_group\": \"Simvastatin\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"dosage\",\n                extraction_text=\"20mg\",\n                attributes={\"medication_group\": \"Simvastatin\"}\n            ),\n            lx.data.Extraction(\n                extraction_class=\"frequency\",\n                extraction_text=\"at bedtime\",\n                attributes={\"medication_group\": \"Simvastatin\"}\n            )\n        ]\n    )\n]\n\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt_description,\n    examples=examples,\n    model_id=\"gemini-2.5-pro\",\n    api_key=\"your-api-key-here\"  # Optional if LANGEXTRACT_API_KEY environment variable is set\n)\n\n# Display grouped medications\nprint(f\"Input text: {input_text.strip()}\\n\")\nprint(\"Extracted Medications:\")\n\n# Group by medication\nmedication_groups = {}\nfor extraction in result.extractions:\n    if not extraction.attributes or \"medication_group\" not in extraction.attributes:\n        print(f\"Warning: Missing medication_group for {extraction.extraction_text}\")\n        continue\n\n    group_name = extraction.attributes[\"medication_group\"]\n    medication_groups.setdefault(group_name, []).append(extraction)\n\n# Print each medication group\nfor med_name, extractions in medication_groups.items():\n    print(f\"\\n* {med_name}\")\n    for extraction in extractions:\n        position_info = \"\"\n        if extraction.char_interval:\n            start, end = extraction.char_interval.start_pos, extraction.char_interval.end_pos\n            
position_info = f\" (pos: {start}-{end})\"\n        print(f\"  • {extraction.extraction_class.capitalize()}: {extraction.extraction_text}{position_info}\")\n\n# Save and visualize the results\nlx.io.save_annotated_documents(\n    [result],\n    output_name=\"medical_relationship_extraction.jsonl\",\n    output_dir=\".\"\n)\n\n# Generate the interactive visualization\nhtml_content = lx.visualize(\"medical_relationship_extraction.jsonl\")\nwith open(\"medical_relationship_visualization.html\", \"w\") as f:\n    if hasattr(html_content, 'data'):\n        f.write(html_content.data)  # For Jupyter/Colab\n    else:\n        f.write(html_content)\n\nprint(\"Interactive visualization saved to medical_relationship_visualization.html\")\n```\n\n![Medical Relationship Visualization](../_static/medication_entity_re.gif)\n\nThis will produce output similar to:\n\n```\nInput text: The patient was prescribed Lisinopril and Metformin last month.\nHe takes the Lisinopril 10mg daily for hypertension, but often misses\nhis Metformin 500mg dose which should be taken twice daily for diabetes.\n\nExtracted Medications:\n\n* Lisinopril\n  • Medication: Lisinopril (pos: 28-38)\n  • Dosage: 10mg (pos: 89-93)\n  • Frequency: daily (pos: 94-99)\n  • Condition: hypertension (pos: 104-116)\n\n* Metformin\n  • Medication: Metformin (pos: 43-52)\n  • Dosage: 500mg (pos: 149-154)\n  • Frequency: twice daily (pos: 182-193)\n  • Condition: diabetes (pos: 198-206)\nInteractive visualization saved to medical_relationship_visualization.html\n```\n\nThe visualization highlights how the `medication_group` attributes connect related entities, making it easy to see which dosages, frequencies, and conditions belong to each medication. Each medication group is visually distinguished in the interactive display.\n\n**Understanding Relationship Extraction:**\nThis example demonstrates how attributes enable efficient relationship extraction. 
With the `medication_group` attribute serving as a linking key, related entities are grouped together logically. This approach simplifies extracting connected information: grouping reduces to a lookup on the shared attribute, as the code above does, while the precise alignment between extracted text and its original location in the document is preserved. The interactive visualization makes these relationships immediately apparent, with connected entities sharing visual groupings and color coding.\n\n## Key Features Demonstrated\n\n- **Named Entity Recognition**: Extracts entities with their types (medication, dosage, route, etc.)\n- **Relationship Extraction**: Groups related entities using attributes\n- **Position Tracking**: Records exact positions of extracted entities in the source text\n- **Structured Output**: Organizes information in a format suitable for healthcare applications\n- **Interactive Visualization**: Generates HTML visualizations for exploring complex medical extractions with entity groupings and relationships clearly displayed\n"
  },
  {
    "path": "examples/custom_provider_plugin/README.md",
    "content": "# Custom Provider Plugin Example\n\nThis example demonstrates how to create a custom provider plugin that extends LangExtract with your own model backend.\n\n**Note**: This is an example included in the LangExtract repository for reference. It is not part of the LangExtract package and won't be installed when you `pip install langextract`.\n\n**Automated Creation**: Instead of manually copying this example, use the [provider plugin generator script](../../scripts/create_provider_plugin.py):\n```bash\npython scripts/create_provider_plugin.py MyProvider --with-schema\n```\nThis will create a complete plugin structure with all boilerplate code ready for customization.\n\n## Structure\n\n```\ncustom_provider_plugin/\n├── pyproject.toml                      # Package configuration and metadata\n├── README.md                            # This file\n├── langextract_provider_example/        # Package directory\n│   ├── __init__.py                     # Package initialization\n│   ├── provider.py                     # Custom provider implementation\n│   └── schema.py                       # Custom schema implementation (optional)\n└── test_example_provider.py            # Test script\n```\n\n## Key Components\n\n### Provider Implementation (`provider.py`)\n\n```python\nimport langextract as lx\n\n\n@lx.providers.registry.register(\n    r'^gemini',  # Pattern for model IDs this provider handles\n)\nclass CustomGeminiProvider(lx.inference.BaseLanguageModel):\n    def __init__(self, model_id: str, **kwargs):\n        super().__init__()\n        # Initialize your backend client\n        ...\n\n    def infer(self, batch_prompts, **kwargs):\n        # Call your backend API and yield a list of ScoredOutput per prompt\n        ...\n```\n\n### Package Configuration (`pyproject.toml`)\n\n```toml\n[project.entry-points.\"langextract.providers\"]\ncustom_gemini = \"langextract_provider_example:CustomGeminiProvider\"\n```\n\nThis entry point allows LangExtract to automatically discover your provider.\n\n### Custom Schema Support (`schema.py`)\n\nProviders can optionally 
implement custom schemas for structured output:\n\n**Flow:** Examples → `from_examples()` → `to_provider_config()` → Provider kwargs → Inference\n\n```python\nimport langextract as lx\n\n\nclass CustomProviderSchema(lx.schema.BaseSchema):\n    def __init__(self, schema_dict):\n        self._schema_dict = schema_dict\n\n    @classmethod\n    def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n        # Analyze examples to find patterns, then build a schema\n        # based on the extraction classes and attributes seen\n        schema_dict = {\"type\": \"object\", \"properties\": {}}  # Placeholder\n        return cls(schema_dict)\n\n    def to_provider_config(self):\n        # Convert schema to provider kwargs\n        return {\n            \"response_schema\": self._schema_dict,\n            \"enable_structured_output\": True\n        }\n\n    @property\n    def supports_strict_mode(self):\n        # True = valid JSON output, no markdown fences needed\n        return True\n```\n\nThen in your provider:\n\n```python\nclass CustomProvider(lx.inference.BaseLanguageModel):\n    @classmethod\n    def get_schema_class(cls):\n        return CustomProviderSchema  # Tell LangExtract about your schema\n\n    def __init__(self, **kwargs):\n        super().__init__()\n        # Receive schema config in kwargs when use_schema_constraints=True\n        self.response_schema = kwargs.get('response_schema')\n\n    def infer(self, batch_prompts, **kwargs):\n        config = {}  # Generation parameters for your backend\n        # Use schema during API calls\n        if self.response_schema:\n            config['response_schema'] = self.response_schema\n        # ... call your backend with config and yield results\n```\n\n## Installation\n\n```bash\n# Navigate to this example directory first\ncd examples/custom_provider_plugin\n\n# Install in development mode\npip install -e .\n\n# Test the provider (must be run from this directory)\npython test_example_provider.py\n```\n\n## Usage\n\nSince this example registers the same pattern as the default Gemini provider, you must explicitly specify it:\n\n```python\nimport langextract as lx\n\n# Create a configured model with explicit provider selection\nconfig = lx.factory.ModelConfig(\n    model_id=\"gemini-2.5-flash\",\n    provider=\"CustomGeminiProvider\",\n    
provider_kwargs={\"api_key\": \"your-api-key\"}\n)\nmodel = lx.factory.create_model(config)\n\n# Note: Passing model directly to extract() is coming soon.\n# For now, use the model's infer() method directly or pass parameters individually:\nresult = lx.extract(\n    text_or_documents=\"Your text here\",\n    model_id=\"gemini-2.5-flash\",\n    api_key=\"your-api-key\",\n    prompt_description=\"Extract key information\",\n    examples=[...]\n)\n\n# Coming soon: Direct model passing\n# result = lx.extract(\n#     text_or_documents=\"Your text here\",\n#     model=model,  # Planned feature\n#     prompt_description=\"Extract key information\"\n# )\n```\n\n## Creating Your Own Provider - Step by Step\n\n### 1. Copy and Rename\n```bash\n# Copy this example directory\ncp -r examples/custom_provider_plugin/ ~/langextract-myprovider/\n\n# Rename the package directory\ncd ~/langextract-myprovider/\nmv langextract_provider_example langextract_myprovider\n```\n\n### 2. Update Package Configuration\nEdit `pyproject.toml`:\n- Set `name = \"langextract-myprovider\"`\n- Update description and author information\n- Change the entry point to `myprovider = \"langextract_myprovider:MyProvider\"`\n\n### 3. Modify Provider Implementation\nEdit `provider.py`:\n- Change class name from `CustomGeminiProvider` to `MyProvider`\n- Update `@register()` patterns to match your model IDs\n- Replace Gemini API calls with your backend\n- Add any provider-specific parameters\n- Update imports of `langextract_provider_example` (in `provider.py` and `__init__.py`) to `langextract_myprovider`, and export `MyProvider` from `__init__.py`\n\n### 4. Add Schema Support (Optional)\nEdit `schema.py`:\n- Rename the class to `MyProviderSchema`\n- Customize `from_examples()` for your extraction format\n- Update `to_provider_config()` for your API requirements\n- Set `supports_strict_mode` based on your backend's capabilities\n\n### 5. 
Install and Test\n```bash\n# Install in development mode\npip install -e .\n\n# Test your provider\npython -c \"\nimport langextract as lx\nlx.providers.load_plugins_once()\nprint('Provider registered:', any('myprovider' in str(e) for e in lx.providers.registry.list_entries()))\n\"\n```\n\n### 6. Write Tests\n- Test that your provider loads and handles basic inference\n- Verify schema support works (if implemented)\n- Test error handling for your specific API\n\n### 7. Publish to PyPI and Share with Community\n```bash\n# Build package (requires: pip install build twine)\npython -m build\n\n# Upload to PyPI\ntwine upload dist/*\n```\n\n**Share with the community:**\n- Submit a PR to add your provider to the [Community Providers Registry](../../COMMUNITY_PROVIDERS.md)\n- Open an issue on [LangExtract GitHub](https://github.com/google/langextract/issues) to announce your provider and get feedback\n\n## Common Pitfalls to Avoid\n\n1. **Forgetting to trigger plugin loading** - Plugins load lazily; call `lx.providers.load_plugins_once()` in tests\n2. **Pattern conflicts** - Avoid patterns that conflict with built-in providers\n3. **Missing dependencies** - List all requirements in `pyproject.toml`\n4. **Schema mismatches** - Test schema generation with real examples\n5. **Not handling None schema** - Your provider must clear its schema configuration when `apply_schema(None)` is called (see provider.py for implementation)\n\n## License\n\nApache License 2.0\n"
  },
  {
    "path": "examples/custom_provider_plugin/langextract_provider_example/__init__.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Example custom provider plugin for LangExtract.\"\"\"\n\nfrom langextract_provider_example.provider import CustomGeminiProvider\n\n__all__ = [\"CustomGeminiProvider\"]\n__version__ = \"0.1.0\"\n"
  },
  {
    "path": "examples/custom_provider_plugin/langextract_provider_example/provider.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Minimal example of a custom provider plugin for LangExtract.\"\"\"\n\nfrom __future__ import annotations\n\nimport dataclasses\nfrom typing import Any, Iterator, Sequence\n\nfrom langextract_provider_example import schema as custom_schema\n\nimport langextract as lx\n\n\n@lx.providers.registry.register(\n    r'^gemini',  # Matches Gemini model IDs (same as default provider)\n)\n@dataclasses.dataclass(init=False)\nclass CustomGeminiProvider(lx.inference.BaseLanguageModel):\n  \"\"\"Example custom LangExtract provider implementation.\n\n  This demonstrates how to create a custom provider for LangExtract\n  that can intercept and handle model requests. 
This example wraps\n  the actual Gemini API to show how custom schemas integrate, but you\n  would replace the Gemini calls with your own API or model implementation.\n\n  Note: Since this registers the same pattern as the default Gemini provider,\n  you must explicitly specify this provider when creating a model:\n\n  config = lx.factory.ModelConfig(\n      model_id=\"gemini-2.5-flash\",\n      provider=\"CustomGeminiProvider\"\n  )\n  model = lx.factory.create_model(config)\n  \"\"\"\n\n  model_id: str\n  api_key: str | None\n  temperature: float\n  response_schema: dict[str, Any] | None = None\n  enable_structured_output: bool = False\n  _client: Any = dataclasses.field(repr=False, compare=False)\n\n  def __init__(\n      self,\n      model_id: str = 'gemini-2.5-flash',\n      api_key: str | None = None,\n      temperature: float = 0.0,\n      **kwargs: Any,\n  ) -> None:\n    \"\"\"Initialize the custom provider.\n\n    Args:\n      model_id: The model ID.\n      api_key: API key for the service.\n      temperature: Sampling temperature.\n      **kwargs: Additional parameters.\n    \"\"\"\n    super().__init__()\n\n    # TODO: Replace with your own client initialization\n    try:\n      from google import genai  # pylint: disable=import-outside-toplevel\n    except ImportError as e:\n      raise lx.exceptions.InferenceConfigError(\n          'This example requires google-genai package. 
'\n          'Install with: pip install google-genai'\n      ) from e\n\n    self.model_id = model_id\n    self.api_key = api_key\n    self.temperature = temperature\n\n    # Schema kwargs from CustomProviderSchema.to_provider_config()\n    self.response_schema = kwargs.get('response_schema')\n    self.enable_structured_output = kwargs.get(\n        'enable_structured_output', False\n    )\n\n    # Store any additional kwargs for potential use\n    self._extra_kwargs = kwargs\n\n    if not self.api_key:\n      raise lx.exceptions.InferenceConfigError(\n          'API key required. Set GEMINI_API_KEY or pass api_key parameter.'\n      )\n\n    self._client = genai.Client(api_key=self.api_key)\n\n  @classmethod\n  def get_schema_class(cls) -> type[lx.schema.BaseSchema] | None:\n    \"\"\"Return our custom schema class.\n\n    This allows LangExtract to use our custom schema implementation\n    when use_schema_constraints=True is specified.\n\n    Returns:\n      Our custom schema class that will be used to generate constraints.\n    \"\"\"\n    return custom_schema.CustomProviderSchema\n\n  def apply_schema(self, schema_instance: lx.schema.BaseSchema | None) -> None:\n    \"\"\"Apply or clear schema configuration.\n\n    This method is called by LangExtract to dynamically apply schema\n    constraints after the provider is instantiated. 
It's important to\n    handle both the application of a new schema and clearing (None).\n\n    Args:\n      schema_instance: The schema to apply, or None to clear existing schema.\n    \"\"\"\n    super().apply_schema(schema_instance)\n\n    if schema_instance:\n      # Apply the new schema configuration\n      config = schema_instance.to_provider_config()\n      self.response_schema = config.get('response_schema')\n      self.enable_structured_output = config.get(\n          'enable_structured_output', False\n      )\n    else:\n      # Clear the schema configuration\n      self.response_schema = None\n      self.enable_structured_output = False\n\n  def infer(\n      self, batch_prompts: Sequence[str], **kwargs: Any\n  ) -> Iterator[Sequence[lx.inference.ScoredOutput]]:\n    \"\"\"Run inference on a batch of prompts.\n\n    Args:\n      batch_prompts: Input prompts to process.\n      **kwargs: Additional generation parameters.\n\n    Yields:\n      Lists of ScoredOutputs, one per prompt.\n    \"\"\"\n    config = {\n        'temperature': kwargs.get('temperature', self.temperature),\n    }\n\n    # Add other parameters if provided\n    for key in ['max_output_tokens', 'top_p', 'top_k']:\n      if key in kwargs:\n        config[key] = kwargs[key]\n\n    # Apply schema constraints if configured\n    if self.response_schema and self.enable_structured_output:\n      # For Gemini, this ensures the model outputs JSON matching our schema\n      # Adapt this section based on your actual provider's API requirements\n      config['response_schema'] = self.response_schema\n      config['response_mime_type'] = 'application/json'\n\n    for prompt in batch_prompts:\n      try:\n        # TODO: Replace this with your own API/model calls\n        response = self._client.models.generate_content(\n            model=self.model_id, contents=prompt, config=config\n        )\n        output = response.text.strip()\n        yield [lx.inference.ScoredOutput(score=1.0, 
output=output)]\n\n      except Exception as e:\n        raise lx.exceptions.InferenceRuntimeError(\n            f'API error: {str(e)}', original=e\n        ) from e\n"
  },
  {
    "path": "examples/custom_provider_plugin/langextract_provider_example/schema.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Example custom schema implementation for provider plugins.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any, Sequence\n\nimport langextract as lx\n\n\nclass CustomProviderSchema(lx.schema.BaseSchema):\n  \"\"\"Example custom schema implementation for a provider plugin.\n\n  This demonstrates how plugins can provide their own schema implementations\n  that integrate with LangExtract's schema system. Custom schemas allow\n  providers to:\n\n  1. Generate provider-specific constraints from examples\n  2. Control output formatting and validation\n  3. 
Optimize for their specific model capabilities\n\n  This example generates a JSON schema from the examples and passes it to\n  the Gemini backend (which this example provider wraps) for structured output.\n  \"\"\"\n\n  def __init__(self, schema_dict: dict[str, Any], strict_mode: bool = True):\n    \"\"\"Initialize the custom schema.\n\n    Args:\n      schema_dict: The generated JSON schema dictionary.\n      strict_mode: Whether the provider guarantees valid output.\n    \"\"\"\n    self._schema_dict = schema_dict\n    self._strict_mode = strict_mode\n\n  @classmethod\n  def from_examples(\n      cls,\n      examples_data: Sequence[lx.data.ExampleData],\n      attribute_suffix: str = \"_attributes\",\n  ) -> CustomProviderSchema:\n    \"\"\"Generate schema from example data.\n\n    This method analyzes the provided examples to build a schema that\n    captures the structure of expected extractions. Called automatically\n    by LangExtract when use_schema_constraints=True.\n\n    Args:\n      examples_data: Example extractions to learn from.\n      attribute_suffix: Suffix for attribute fields (unused in this example).\n\n    Returns:\n      A configured CustomProviderSchema instance.\n\n    Example:\n      If examples contain extractions with class \"condition\" and attribute\n      \"severity\", the schema will constrain the model to only output those\n      specific classes and attributes.\n    \"\"\"\n    extraction_classes = set()\n    attribute_keys = set()\n\n    for example in examples_data:\n      for extraction in example.extractions:\n        extraction_classes.add(extraction.extraction_class)\n        if extraction.attributes:\n          attribute_keys.update(extraction.attributes.keys())\n\n    schema_dict = {\n        \"type\": \"object\",\n        \"properties\": {\n            \"extractions\": {\n                \"type\": \"array\",\n                \"items\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n     
                   \"extraction_class\": {\n                            \"type\": \"string\",\n                            \"enum\": (\n                                list(extraction_classes)\n                                if extraction_classes\n                                else None\n                            ),\n                        },\n                        \"extraction_text\": {\"type\": \"string\"},\n                        \"attributes\": {\n                            \"type\": \"object\",\n                            \"properties\": {\n                                key: {\"type\": \"string\"}\n                                for key in attribute_keys\n                            },\n                        },\n                    },\n                    \"required\": [\"extraction_class\", \"extraction_text\"],\n                },\n            },\n        },\n        \"required\": [\"extractions\"],\n    }\n\n    # Remove enum if no classes found\n    if not extraction_classes:\n      del schema_dict[\"properties\"][\"extractions\"][\"items\"][\"properties\"][\n          \"extraction_class\"\n      ][\"enum\"]\n\n    return cls(schema_dict, strict_mode=True)\n\n  def to_provider_config(self) -> dict[str, Any]:\n    \"\"\"Convert schema to provider-specific configuration.\n\n    This is called after from_examples() and returns kwargs that will be\n    passed to the provider's __init__ method. 
The provider can then use\n    these during inference.\n\n    Returns:\n      Dictionary of provider kwargs that will be passed to the model.\n      In this example, we return both the schema and a flag to enable\n      structured output mode.\n\n    Note:\n      These kwargs are merged with user-provided kwargs, with user values\n      taking precedence (caller-wins merge semantics).\n    \"\"\"\n    return {\n        \"response_schema\": self._schema_dict,\n        \"enable_structured_output\": True,\n        \"output_format\": \"json\",\n    }\n\n  @property\n  def supports_strict_mode(self) -> bool:\n    \"\"\"Whether this schema guarantees valid structured output.\n\n    Returns:\n      True if the provider will emit valid JSON without needing\n      Markdown fences for extraction.\n    \"\"\"\n    return self._strict_mode\n\n  @property\n  def schema_dict(self) -> dict[str, Any]:\n    \"\"\"Access the underlying schema dictionary.\n\n    Returns:\n      The JSON schema dictionary.\n    \"\"\"\n    return self._schema_dict\n"
  },
  {
    "path": "examples/custom_provider_plugin/pyproject.toml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"langextract-provider-example\"  # Change to your package name\nversion = \"0.1.0\"  # Update version for releases\ndescription = \"Example custom provider plugin for LangExtract\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = {text = \"Apache-2.0\"}\ndependencies = [\n    # Uncomment when creating a standalone plugin package:\n    # \"langextract\",  # Will install latest version\n    \"google-genai>=0.2.0\",  # Replace with your backend's SDK\n]\n\n# Register the provider with LangExtract's plugin system\n[project.entry-points.\"langextract.providers\"]\ncustom_gemini = \"langextract_provider_example:CustomGeminiProvider\"\n\n[tool.setuptools.packages.find]\nwhere = [\".\"]\ninclude = [\"langextract_provider_example*\"]\n"
  },
  {
    "path": "examples/custom_provider_plugin/test_example_provider.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Simple test for the custom provider plugin.\"\"\"\n\nimport os\n\nimport dotenv\n# Import the provider to trigger registration with LangExtract\n# Note: This manual import is only needed when running without installation.\n# After `pip install -e .`, the entry point system handles this automatically.\nfrom langextract_provider_example import CustomGeminiProvider  # noqa: F401\n\nimport langextract as lx\n\n\ndef main():\n  \"\"\"Test the custom provider.\"\"\"\n  dotenv.load_dotenv(override=True)\n  api_key = os.getenv(\"GEMINI_API_KEY\") or os.getenv(\"LANGEXTRACT_API_KEY\")\n\n  if not api_key:\n    print(\"Set GEMINI_API_KEY or LANGEXTRACT_API_KEY to test\")\n    return\n\n  config = lx.factory.ModelConfig(\n      model_id=\"gemini-2.5-flash\",\n      provider=\"CustomGeminiProvider\",\n      provider_kwargs={\"api_key\": api_key},\n  )\n  model = lx.factory.create_model(config)\n\n  print(f\"✓ Created {model.__class__.__name__}\")\n\n  # Test inference\n  prompts = [\"Say hello\"]\n  results = list(model.infer(prompts))\n\n  if results and results[0]:\n    print(f\"✓ Inference worked: {results[0][0].output[:50]}...\")\n  else:\n    print(\"✗ No response\")\n\n\nif __name__ == \"__main__\":\n  main()\n"
  },
  {
    "path": "examples/notebooks/romeo_juliet_extraction.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"header\"\n      },\n      \"source\": [\n        \"# Romeo and Juliet Text Extraction with LangExtract\\n\",\n        \"\\n\",\n        \"This notebook demonstrates extracting characters, emotions, and relationships from Shakespeare's Romeo and Juliet using LangExtract.\\n\",\n        \"\\n\",\n        \"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/langextract/blob/main/examples/notebooks/romeo_juliet_extraction.ipynb)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"setup_header\"\n      },\n      \"source\": [\n        \"## Setup\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 8,\n      \"metadata\": {\n        \"id\": \"install\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Note: you may need to restart the kernel to use updated packages.\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# Install LangExtract\\n\",\n        \"%pip install -q langextract\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 9,\n      \"metadata\": {\n        \"id\": \"api_key\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Set up your Gemini API key\\n\",\n        \"# Get your key from: https://aistudio.google.com/app/apikey\\n\",\n        \"import os\\n\",\n        \"from getpass import getpass\\n\",\n        \"\\n\",\n        \"if 'GEMINI_API_KEY' not in os.environ:\\n\",\n        \"    os.environ['GEMINI_API_KEY'] = getpass('Enter your Gemini API key: ')\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"define_header\"\n      },\n      \"source\": [\n       
 \"## Define Extraction Task\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 10,\n      \"metadata\": {\n        \"id\": \"setup_extraction\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import langextract as lx\\n\",\n        \"import textwrap\\n\",\n        \"\\n\",\n        \"# Define the extraction task\\n\",\n        \"prompt = textwrap.dedent(\\\"\\\"\\\"\\\\\\n\",\n        \"    Extract characters, emotions, and relationships in order of appearance.\\n\",\n        \"    Use exact text for extractions. Do not paraphrase or overlap entities.\\n\",\n        \"    Provide meaningful attributes for each entity to add context.\\\"\\\"\\\")\\n\",\n        \"\\n\",\n        \"# Provide a high-quality example\\n\",\n        \"examples = [\\n\",\n        \"    lx.data.ExampleData(\\n\",\n        \"        text=\\\"ROMEO. But soft! What light through yonder window breaks? It is the east, and Juliet is the sun.\\\",\\n\",\n        \"        extractions=[\\n\",\n        \"            lx.data.Extraction(\\n\",\n        \"                extraction_class=\\\"character\\\",\\n\",\n        \"                extraction_text=\\\"ROMEO\\\",\\n\",\n        \"                attributes={\\\"emotional_state\\\": \\\"wonder\\\"}\\n\",\n        \"            ),\\n\",\n        \"            lx.data.Extraction(\\n\",\n        \"                extraction_class=\\\"emotion\\\",\\n\",\n        \"                extraction_text=\\\"But soft!\\\",\\n\",\n        \"                attributes={\\\"feeling\\\": \\\"gentle awe\\\"}\\n\",\n        \"            ),\\n\",\n        \"            lx.data.Extraction(\\n\",\n        \"                extraction_class=\\\"relationship\\\",\\n\",\n        \"                extraction_text=\\\"Juliet is the sun\\\",\\n\",\n        \"                attributes={\\\"type\\\": \\\"metaphor\\\"}\\n\",\n        \"            ),\\n\",\n        \"        ]\\n\",\n        \"    )\\n\",\n        \"]\"\n 
     ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"extract_header\"\n      },\n      \"source\": [\n        \"## Extract from Sample Text\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 11,\n      \"metadata\": {\n        \"id\": \"simple_extraction\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[94m\\u001b[1mLangExtract\\u001b[0m: model=\\u001b[92mgemini-2.5-flash\\u001b[0m, current=\\u001b[92m68\\u001b[0m chars, processed=\\u001b[92m68\\u001b[0m chars:  [00:01]\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[92m✓\\u001b[0m Extraction processing complete\\n\",\n            \"\\u001b[92m✓\\u001b[0m Extracted \\u001b[1m3\\u001b[0m entities (\\u001b[1m3\\u001b[0m unique types)\\n\",\n            \"  \\u001b[96m•\\u001b[0m Time: \\u001b[1m1.96s\\u001b[0m\\n\",\n            \"  \\u001b[96m•\\u001b[0m Speed: \\u001b[1m35\\u001b[0m chars/sec\\n\",\n            \"  \\u001b[96m•\\u001b[0m Chunks: \\u001b[1m1\\u001b[0m\\n\",\n            \"Extracted 3 entities:\\n\",\n            \"\\n\",\n            \"• character: 'Lady Juliet'\\n\",\n            \"  - emotional_state: longing\\n\",\n            \"• emotion: 'gazed longingly at the stars, her heart aching'\\n\",\n            \"  - feeling: melancholy longing\\n\",\n            \"• relationship: 'her heart aching for Romeo'\\n\",\n            \"  - type: romantic\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# Simple extraction from a short text\\n\",\n        \"input_text = \\\"Lady Juliet gazed longingly at the stars, her heart aching for 
Romeo\\\"\\n\",\n        \"\\n\",\n        \"result = lx.extract(\\n\",\n        \"    text_or_documents=input_text,\\n\",\n        \"    prompt_description=prompt,\\n\",\n        \"    examples=examples,\\n\",\n        \"    model_id=\\\"gemini-2.5-flash\\\",\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"# Display results\\n\",\n        \"print(f\\\"Extracted {len(result.extractions)} entities:\\\\n\\\")\\n\",\n        \"for extraction in result.extractions:\\n\",\n        \"    print(f\\\"• {extraction.extraction_class}: '{extraction.extraction_text}'\\\")\\n\",\n        \"    if extraction.attributes:\\n\",\n        \"        for key, value in extraction.attributes.items():\\n\",\n        \"            print(f\\\"  - {key}: {value}\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"viz_header\"\n      },\n      \"source\": [\n        \"## Interactive Visualization\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 12,\n      \"metadata\": {\n        \"id\": \"visualization\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[94m\\u001b[1mLangExtract\\u001b[0m: Saving to \\u001b[92mromeo_juliet.jsonl\\u001b[0m: 1 docs [00:00, 995.33 docs/s]\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[92m✓\\u001b[0m Saved \\u001b[1m1\\u001b[0m documents to \\u001b[92mromeo_juliet.jsonl\\u001b[0m\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\n\",\n            \"\\u001b[94m\\u001b[1mLangExtract\\u001b[0m: Loading \\u001b[92mromeo_juliet.jsonl\\u001b[0m: 100%|██████████| 961/961 [00:00<00:00, 2.49MB/s]\"\n          ]\n        },\n        {\n          
\"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[92m✓\\u001b[0m Loaded \\u001b[1m1\\u001b[0m documents from \\u001b[92mromeo_juliet.jsonl\\u001b[0m\\n\",\n            \"Interactive visualization (hover over highlights to see attributes):\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\n\"\n          ]\n        },\n        {\n          \"data\": {\n            \"text/html\": [\n              \"<style>\\n\",\n              \".lx-highlight { position: relative; border-radius:3px; padding:1px 2px;}\\n\",\n              \".lx-highlight .lx-tooltip {\\n\",\n              \"  visibility: hidden;\\n\",\n              \"  opacity: 0;\\n\",\n              \"  transition: opacity 0.2s ease-in-out;\\n\",\n              \"  background: #333;\\n\",\n              \"  color: #fff;\\n\",\n              \"  text-align: left;\\n\",\n              \"  border-radius: 4px;\\n\",\n              \"  padding: 6px 8px;\\n\",\n              \"  position: absolute;\\n\",\n              \"  z-index: 1000;\\n\",\n              \"  bottom: 125%;\\n\",\n              \"  left: 50%;\\n\",\n              \"  transform: translateX(-50%);\\n\",\n              \"  font-size: 12px;\\n\",\n              \"  max-width: 240px;\\n\",\n              \"  white-space: normal;\\n\",\n              \"  box-shadow: 0 2px 6px rgba(0,0,0,0.3);\\n\",\n              \"}\\n\",\n              \".lx-highlight:hover .lx-tooltip { visibility: visible; opacity:1; }\\n\",\n              \".lx-animated-wrapper { max-width: 100%; font-family: Arial, sans-serif; }\\n\",\n              \".lx-controls {\\n\",\n              \"  background: #fafafa; border: 1px solid #90caf9; border-radius: 8px;\\n\",\n              \"  padding: 12px; margin-bottom: 16px;\\n\",\n              \"}\\n\",\n              \".lx-button-row {\\n\",\n              \"  display: flex; 
justify-content: center; gap: 8px; margin-bottom: 12px;\\n\",\n              \"}\\n\",\n              \".lx-control-btn {\\n\",\n              \"  background: #4285f4; color: white; border: none; border-radius: 4px;\\n\",\n              \"  padding: 8px 16px; cursor: pointer; font-size: 13px; font-weight: 500;\\n\",\n              \"  transition: background-color 0.2s;\\n\",\n              \"}\\n\",\n              \".lx-control-btn:hover { background: #3367d6; }\\n\",\n              \".lx-progress-container {\\n\",\n              \"  margin-bottom: 8px;\\n\",\n              \"}\\n\",\n              \".lx-progress-slider {\\n\",\n              \"  width: 100%; margin: 0; appearance: none; height: 6px;\\n\",\n              \"  background: #ddd; border-radius: 3px; outline: none;\\n\",\n              \"}\\n\",\n              \".lx-progress-slider::-webkit-slider-thumb {\\n\",\n              \"  appearance: none; width: 18px; height: 18px; background: #4285f4;\\n\",\n              \"  border-radius: 50%; cursor: pointer;\\n\",\n              \"}\\n\",\n              \".lx-progress-slider::-moz-range-thumb {\\n\",\n              \"  width: 18px; height: 18px; background: #4285f4; border-radius: 50%;\\n\",\n              \"  cursor: pointer; border: none;\\n\",\n              \"}\\n\",\n              \".lx-status-text {\\n\",\n              \"  text-align: center; font-size: 12px; color: #666; margin-top: 4px;\\n\",\n              \"}\\n\",\n              \".lx-text-window {\\n\",\n              \"  font-family: monospace; white-space: pre-wrap; border: 1px solid #90caf9;\\n\",\n              \"  padding: 12px; max-height: 260px; overflow-y: auto; margin-bottom: 12px;\\n\",\n              \"  line-height: 1.6;\\n\",\n              \"}\\n\",\n              \".lx-attributes-panel {\\n\",\n              \"  background: #fafafa; border: 1px solid #90caf9; border-radius: 6px;\\n\",\n              \"  padding: 8px 10px; margin-top: 8px; font-size: 13px;\\n\",\n              
\"}\\n\",\n              \".lx-current-highlight {\\n\",\n              \"  border-bottom: 4px solid #ff4444;\\n\",\n              \"  font-weight: bold;\\n\",\n              \"  animation: lx-pulse 1s ease-in-out;\\n\",\n              \"}\\n\",\n              \"@keyframes lx-pulse {\\n\",\n              \"  0% { text-decoration-color: #ff4444; }\\n\",\n              \"  50% { text-decoration-color: #ff0000; }\\n\",\n              \"  100% { text-decoration-color: #ff4444; }\\n\",\n              \"}\\n\",\n              \".lx-legend {\\n\",\n              \"  font-size: 12px; margin-bottom: 8px;\\n\",\n              \"  padding-bottom: 8px; border-bottom: 1px solid #e0e0e0;\\n\",\n              \"}\\n\",\n              \".lx-label {\\n\",\n              \"  display: inline-block;\\n\",\n              \"  padding: 2px 4px;\\n\",\n              \"  border-radius: 3px;\\n\",\n              \"  margin-right: 4px;\\n\",\n              \"  color: #000;\\n\",\n              \"}\\n\",\n              \".lx-attr-key {\\n\",\n              \"  font-weight: 600;\\n\",\n              \"  color: #1565c0;\\n\",\n              \"  letter-spacing: 0.3px;\\n\",\n              \"}\\n\",\n              \".lx-attr-value {\\n\",\n              \"  font-weight: 400;\\n\",\n              \"  opacity: 0.85;\\n\",\n              \"  letter-spacing: 0.2px;\\n\",\n              \"}\\n\",\n              \"\\n\",\n              \"/* Add optimizations with larger fonts and better readability for GIFs */\\n\",\n              \".lx-gif-optimized .lx-text-window { font-size: 16px; line-height: 1.8; }\\n\",\n              \".lx-gif-optimized .lx-attributes-panel { font-size: 15px; }\\n\",\n              \".lx-gif-optimized .lx-current-highlight { text-decoration-thickness: 4px; }\\n\",\n              \"</style>\\n\",\n              \"<div class=\\\"lx-animated-wrapper lx-gif-optimized\\\">\\n\",\n              \"  <div class=\\\"lx-attributes-panel\\\">\\n\",\n              \"    <div 
class=\\\"lx-legend\\\">Highlights Legend: <span class=\\\"lx-label\\\" style=\\\"background-color:#D2E3FC;\\\">character</span> <span class=\\\"lx-label\\\" style=\\\"background-color:#C8E6C9;\\\">emotion</span> <span class=\\\"lx-label\\\" style=\\\"background-color:#FEF0C3;\\\">relationship</span></div>\\n\",\n              \"    <div id=\\\"attributesContainer\\\"></div>\\n\",\n              \"  </div>\\n\",\n              \"  <div class=\\\"lx-text-window\\\" id=\\\"textWindow\\\">\\n\",\n              \"    <span class=\\\"lx-highlight lx-current-highlight\\\" data-idx=\\\"0\\\" style=\\\"background-color:#D2E3FC;\\\">Lady Juliet</span> <span class=\\\"lx-highlight\\\" data-idx=\\\"1\\\" style=\\\"background-color:#C8E6C9;\\\">gazed longingly at the stars, <span class=\\\"lx-highlight\\\" data-idx=\\\"2\\\" style=\\\"background-color:#FEF0C3;\\\">her heart aching</span> for Romeo</span>\\n\",\n              \"  </div>\\n\",\n              \"  <div class=\\\"lx-controls\\\">\\n\",\n              \"    <div class=\\\"lx-button-row\\\">\\n\",\n              \"      <button class=\\\"lx-control-btn\\\" onclick=\\\"playPause()\\\">▶️ Play</button>\\n\",\n              \"      <button class=\\\"lx-control-btn\\\" onclick=\\\"prevExtraction()\\\">⏮ Previous</button>\\n\",\n              \"      <button class=\\\"lx-control-btn\\\" onclick=\\\"nextExtraction()\\\">⏭ Next</button>\\n\",\n              \"    </div>\\n\",\n              \"    <div class=\\\"lx-progress-container\\\">\\n\",\n              \"      <input type=\\\"range\\\" id=\\\"progressSlider\\\" class=\\\"lx-progress-slider\\\"\\n\",\n              \"             min=\\\"0\\\" max=\\\"2\\\" value=\\\"0\\\"\\n\",\n              \"             onchange=\\\"jumpToExtraction(this.value)\\\">\\n\",\n              \"    </div>\\n\",\n              \"    <div class=\\\"lx-status-text\\\">\\n\",\n              \"      Entity <span id=\\\"entityInfo\\\">1/3</span> |\\n\",\n              \"      Pos <span 
id=\\\"posInfo\\\">[0-11]</span>\\n\",\n              \"    </div>\\n\",\n              \"  </div>\\n\",\n              \"</div>\\n\",\n              \"\\n\",\n              \"<script>\\n\",\n              \"  (function() {\\n\",\n              \"    const extractions = [{\\\"index\\\": 0, \\\"class\\\": \\\"character\\\", \\\"text\\\": \\\"Lady Juliet\\\", \\\"color\\\": \\\"#D2E3FC\\\", \\\"startPos\\\": 0, \\\"endPos\\\": 11, \\\"beforeText\\\": \\\"\\\", \\\"extractionText\\\": \\\"Lady Juliet\\\", \\\"afterText\\\": \\\" gazed longingly at the stars, her heart aching for Romeo\\\", \\\"attributesHtml\\\": \\\"<div><strong>class:</strong> character</div><div><strong>attributes:</strong> {<span class=\\\\\\\"lx-attr-key\\\\\\\">emotional_state</span>: <span class=\\\\\\\"lx-attr-value\\\\\\\">longing</span>}</div>\\\"}, {\\\"index\\\": 1, \\\"class\\\": \\\"emotion\\\", \\\"text\\\": \\\"gazed longingly at the stars, her heart aching\\\", \\\"color\\\": \\\"#C8E6C9\\\", \\\"startPos\\\": 12, \\\"endPos\\\": 58, \\\"beforeText\\\": \\\"Lady Juliet \\\", \\\"extractionText\\\": \\\"gazed longingly at the stars, her heart aching\\\", \\\"afterText\\\": \\\" for Romeo\\\", \\\"attributesHtml\\\": \\\"<div><strong>class:</strong> emotion</div><div><strong>attributes:</strong> {<span class=\\\\\\\"lx-attr-key\\\\\\\">feeling</span>: <span class=\\\\\\\"lx-attr-value\\\\\\\">melancholy longing</span>}</div>\\\"}, {\\\"index\\\": 2, \\\"class\\\": \\\"relationship\\\", \\\"text\\\": \\\"her heart aching for Romeo\\\", \\\"color\\\": \\\"#FEF0C3\\\", \\\"startPos\\\": 42, \\\"endPos\\\": 68, \\\"beforeText\\\": \\\"Lady Juliet gazed longingly at the stars, \\\", \\\"extractionText\\\": \\\"her heart aching for Romeo\\\", \\\"afterText\\\": \\\"\\\", \\\"attributesHtml\\\": \\\"<div><strong>class:</strong> relationship</div><div><strong>attributes:</strong> {<span class=\\\\\\\"lx-attr-key\\\\\\\">type</span>: <span 
class=\\\\\\\"lx-attr-value\\\\\\\">romantic</span>}</div>\\\"}];\\n\",\n              \"    let currentIndex = 0;\\n\",\n              \"    let isPlaying = false;\\n\",\n              \"    let animationInterval = null;\\n\",\n              \"    let animationSpeed = 1.0;\\n\",\n              \"\\n\",\n              \"    function updateDisplay() {\\n\",\n              \"      const extraction = extractions[currentIndex];\\n\",\n              \"      if (!extraction) return;\\n\",\n              \"\\n\",\n              \"      document.getElementById('attributesContainer').innerHTML = extraction.attributesHtml;\\n\",\n              \"      document.getElementById('entityInfo').textContent = (currentIndex + 1) + '/' + extractions.length;\\n\",\n              \"      document.getElementById('posInfo').textContent = '[' + extraction.startPos + '-' + extraction.endPos + ']';\\n\",\n              \"      document.getElementById('progressSlider').value = currentIndex;\\n\",\n              \"\\n\",\n              \"      const playBtn = document.querySelector('.lx-control-btn');\\n\",\n              \"      if (playBtn) playBtn.textContent = isPlaying ? 
'⏸ Pause' : '▶️ Play';\\n\",\n              \"\\n\",\n              \"      const prevHighlight = document.querySelector('.lx-text-window .lx-current-highlight');\\n\",\n              \"      if (prevHighlight) prevHighlight.classList.remove('lx-current-highlight');\\n\",\n              \"      const currentSpan = document.querySelector('.lx-text-window span[data-idx=\\\"' + currentIndex + '\\\"]');\\n\",\n              \"      if (currentSpan) {\\n\",\n              \"        currentSpan.classList.add('lx-current-highlight');\\n\",\n              \"        currentSpan.scrollIntoView({block: 'center', behavior: 'smooth'});\\n\",\n              \"      }\\n\",\n              \"    }\\n\",\n              \"\\n\",\n              \"    function nextExtraction() {\\n\",\n              \"      currentIndex = (currentIndex + 1) % extractions.length;\\n\",\n              \"      updateDisplay();\\n\",\n              \"    }\\n\",\n              \"\\n\",\n              \"    function prevExtraction() {\\n\",\n              \"      currentIndex = (currentIndex - 1 + extractions.length) % extractions.length;\\n\",\n              \"      updateDisplay();\\n\",\n              \"    }\\n\",\n              \"\\n\",\n              \"    function jumpToExtraction(index) {\\n\",\n              \"      currentIndex = parseInt(index);\\n\",\n              \"      updateDisplay();\\n\",\n              \"    }\\n\",\n              \"\\n\",\n              \"    function playPause() {\\n\",\n              \"      if (isPlaying) {\\n\",\n              \"        clearInterval(animationInterval);\\n\",\n              \"        isPlaying = false;\\n\",\n              \"      } else {\\n\",\n              \"        animationInterval = setInterval(nextExtraction, animationSpeed * 1000);\\n\",\n              \"        isPlaying = true;\\n\",\n              \"      }\\n\",\n              \"      updateDisplay();\\n\",\n              \"    }\\n\",\n              \"\\n\",\n              \"    
window.playPause = playPause;\\n\",\n              \"    window.nextExtraction = nextExtraction;\\n\",\n              \"    window.prevExtraction = prevExtraction;\\n\",\n              \"    window.jumpToExtraction = jumpToExtraction;\\n\",\n              \"\\n\",\n              \"    updateDisplay();\\n\",\n              \"  })();\\n\",\n              \"</script>\"\n            ],\n            \"text/plain\": [\n              \"<IPython.core.display.HTML object>\"\n            ]\n          },\n          \"execution_count\": 12,\n          \"metadata\": {},\n          \"output_type\": \"execute_result\"\n        }\n      ],\n      \"source\": [\n        \"# Save results to JSONL\\n\",\n        \"lx.io.save_annotated_documents([result], output_name=\\\"romeo_juliet.jsonl\\\", output_dir=\\\".\\\")\\n\",\n        \"\\n\",\n        \"# Generate interactive visualization\\n\",\n        \"html_content = lx.visualize(\\\"romeo_juliet.jsonl\\\")\\n\",\n        \"\\n\",\n        \"# Display in notebook\\n\",\n        \"print(\\\"Interactive visualization (hover over highlights to see attributes):\\\")\\n\",\n        \"html_content\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 13,\n      \"metadata\": {\n        \"id\": \"save_viz\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"✓ Visualization saved to romeo_juliet_visualization.html\\n\",\n            \"You can download this file from the Files panel on the left.\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# Save visualization to file (for downloading)\\n\",\n        \"with open(\\\"romeo_juliet_visualization.html\\\", \\\"w\\\") as f:\\n\",\n        \"    # Handle both Jupyter (HTML object) and non-Jupyter (string) environments\\n\",\n        \"    if hasattr(html_content, 'data'):\\n\",\n        \"        f.write(html_content.data)\\n\",\n        \" 
   else:\\n\",\n        \"        f.write(html_content)\\n\",\n        \"\\n\",\n        \"print(\\\"✓ Visualization saved to romeo_juliet_visualization.html\\\")\\n\",\n        \"print(\\\"You can download this file from the Files panel on the left.\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"experiment_header\"\n      },\n      \"source\": [\n        \"## Try Your Own Text\\n\",\n        \"\\n\",\n        \"Experiment with your own Shakespeare quotes or any literary text!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 14,\n      \"metadata\": {\n        \"id\": \"experiment\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[94m\\u001b[1mLangExtract\\u001b[0m: model=\\u001b[92mgemini-2.5-flash\\u001b[0m, current=\\u001b[92m163\\u001b[0m chars, processed=\\u001b[92m163\\u001b[0m chars:  [00:05]\"\n          ]\n        },\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\u001b[92m✓\\u001b[0m Extraction processing complete\\n\",\n            \"\\u001b[92m✓\\u001b[0m Extracted \\u001b[1m6\\u001b[0m entities (\\u001b[1m3\\u001b[0m unique types)\\n\",\n            \"  \\u001b[96m•\\u001b[0m Time: \\u001b[1m5.84s\\u001b[0m\\n\",\n            \"  \\u001b[96m•\\u001b[0m Speed: \\u001b[1m28\\u001b[0m chars/sec\\n\",\n            \"  \\u001b[96m•\\u001b[0m Chunks: \\u001b[1m1\\u001b[0m\\n\",\n            \"Extractions from your text:\\n\",\n            \"\\n\",\n            \"• character: 'JULIET'\\n\",\n            \"  - emotional_state: longing\\n\",\n            \"• emotion: 'O Romeo, Romeo! 
wherefore art thou Romeo?'\\n\",\n            \"  - feeling: desperate questioning\\n\",\n            \"• relationship: 'thy father'\\n\",\n            \"  - type: familial\\n\",\n            \"• relationship: 'thy name'\\n\",\n            \"  - type: lineage\\n\",\n            \"• relationship: 'my love'\\n\",\n            \"  - type: romantic bond\\n\",\n            \"• relationship: 'Capulet'\\n\",\n            \"  - type: family affiliation\\n\"\n          ]\n        },\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# Try your own text\\n\",\n        \"your_text = \\\"\\\"\\\"\\n\",\n        \"JULIET: O Romeo, Romeo! wherefore art thou Romeo?\\n\",\n        \"Deny thy father and refuse thy name;\\n\",\n        \"Or, if thou wilt not, be but sworn my love,\\n\",\n        \"And I'll no longer be a Capulet.\\n\",\n        \"\\\"\\\"\\\"\\n\",\n        \"\\n\",\n        \"custom_result = lx.extract(\\n\",\n        \"    text_or_documents=your_text,\\n\",\n        \"    prompt_description=prompt,\\n\",\n        \"    examples=examples,\\n\",\n        \"    model_id=\\\"gemini-2.5-flash\\\",\\n\",\n        \")\\n\",\n        \"\\n\",\n        \"print(\\\"Extractions from your text:\\\\n\\\")\\n\",\n        \"for e in custom_result.extractions:\\n\",\n        \"    print(f\\\"• {e.extraction_class}: '{e.extraction_text}'\\\")\\n\",\n        \"    if e.attributes:\\n\",\n        \"        for key, value in e.attributes.items():\\n\",\n        \"            print(f\\\"  - {key}: {value}\\\")\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"Romeo and Juliet Text Extraction with LangExtract\",\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"venv\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      
\"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.13.5\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/ollama/.dockerignore",
    "content": "# Ignore Python cache\n__pycache__/\n*.pyc\n*.pyo\n*.pyd\n.Python\n\n# Ignore version control\n.git/\n.gitignore\n\n# Ignore OS files\n.DS_Store\nThumbs.db\n\n# Ignore virtual environments\nvenv/\nenv/\n.venv/\n\n# Ignore IDE files\n.vscode/\n.idea/\n*.swp\n*.swo\n\n# Ignore test artifacts\n.pytest_cache/\n.coverage\nhtmlcov/\n\n# Ignore build artifacts\nbuild/\ndist/\n*.egg-info/\n"
  },
  {
    "path": "examples/ollama/Dockerfile",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nFROM python:3.11-slim-bookworm\n\nWORKDIR /app\n\nRUN pip install langextract\n\nCOPY demo_ollama.py .\n\nCMD [\"python\", \"demo_ollama.py\"]\n"
  },
  {
    "path": "examples/ollama/README.md",
    "content": "# Ollama Examples\n\nThis directory contains examples for using LangExtract with Ollama for local LLM inference.\n\nFor setup instructions and documentation, see the [main README's Ollama section](../../README.md#using-local-llms-with-ollama).\n\n## Quick Reference\n\n**Option 1: Run locally**\n```bash\n# Install Ollama (https://ollama.com/), then start the server.\n# Keep this running in a separate terminal.\nollama serve\n\n# In another terminal, pull the default model\nollama pull gemma2:2b\n\n# Run the demo\npython demo_ollama.py\n```\n\n**Option 2: Run with Docker**\n```bash\n# Runs both Ollama and the demo in containers\ndocker-compose up\n```\n\n## Files\n\n- `demo_ollama.py` - Extraction examples from the main README, run against a local Ollama model\n- `docker-compose.yml` - Docker Compose setup that starts the demo only after the Ollama health check passes\n- `Dockerfile` - Container definition for LangExtract\n\n## Configuration Options\n\n### Timeout Settings\n\nFor slower models or large prompts, you may need to increase the timeout (default: 120 seconds):\n\n```python\nimport langextract as lx\n\nresult = lx.extract(\n    text_or_documents=input_text,\n    prompt_description=prompt,\n    examples=examples,\n    model_id=\"llama3.1:70b\",  # Larger models may need more time\n    timeout=300,  # 5 minutes\n    model_url=\"http://localhost:11434\",\n)\n```\n\nOr set the same options through `ModelConfig`:\n\n```python\nconfig = lx.factory.ModelConfig(\n    model_id=\"llama3.1:70b\",\n    provider_kwargs={\n        \"model_url\": \"http://localhost:11434\",\n        \"timeout\": 300,  # 5 minutes\n    }\n)\n```\n\n## Model License\n\nOllama models come with their own licenses. For example:\n- Gemma models: [Gemma Terms of Use](https://ai.google.dev/gemma/terms)\n- Llama models: [Meta Llama License](https://llama.meta.com/llama-downloads/)\n\nPlease review the license for any model you use.\n"
  },
  {
    "path": "examples/ollama/demo_ollama.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Comprehensive demo of Ollama integration with FormatHandler.\n\nThis example demonstrates:\n- Using the pre-configured OLLAMA_FORMAT_HANDLER for consistent configuration\n- Running multiple extraction examples with progress bars\n- Generating interactive HTML visualizations\n- Handling various extraction patterns (NER, relationships, dialogue extraction)\n\nPrerequisites:\n1. Install Ollama: https://ollama.com/\n2. Pull the model: ollama pull gemma2:2b\n3. 
Start Ollama: ollama serve\n\nUsage:\n    python demo_ollama.py [--model MODEL_NAME]\n\nExamples:\n    # Use default model (gemma2:2b)\n    python demo_ollama.py\n\n    # Use a different model\n    python demo_ollama.py --model llama3.2:3b\n\nOutput:\n    Results are saved to test_output/ directory (gitignored)\n    - JSONL files with extraction data\n    - Interactive HTML visualizations\n\"\"\"\n\nimport argparse\nimport os\nfrom pathlib import Path\nimport sys\nimport textwrap\nimport time\nimport traceback\nimport urllib.error\nimport urllib.request\n\nimport dotenv\n\nimport langextract as lx\nfrom langextract.providers import ollama\n\ndotenv.load_dotenv(override=True)\n\nDEFAULT_MODEL = \"gemma2:2b\"\nDEFAULT_OLLAMA_URL = os.environ.get(\"OLLAMA_HOST\", \"http://localhost:11434\")\nOUTPUT_DIR = \"test_output\"\n\n\ndef check_ollama_available(url: str = DEFAULT_OLLAMA_URL) -> bool:\n  \"\"\"Check if Ollama is available at the specified URL.\"\"\"\n  try:\n    with urllib.request.urlopen(f\"{url}/api/tags\", timeout=2) as response:\n      return response.status == 200\n  except (urllib.error.URLError, TimeoutError):\n    return False\n\n\ndef ensure_output_directory() -> Path:\n  \"\"\"Create output directory if it doesn't exist.\"\"\"\n  output_path = Path(OUTPUT_DIR)\n  output_path.mkdir(exist_ok=True)\n  return output_path\n\n\ndef print_header(title: str, width: int = 80) -> None:\n  \"\"\"Print a formatted header.\"\"\"\n  print(\"\\n\" + \"=\" * width)\n  print(f\"  {title}\")\n  print(\"=\" * width)\n\n\ndef print_section(title: str, width: int = 60) -> None:\n  \"\"\"Print a formatted section.\"\"\"\n  print(f\"\\n▶ {title}\")\n  print(\"-\" * width)\n\n\ndef print_results_summary(extractions: list[lx.data.Extraction]) -> None:\n  \"\"\"Print a summary of extraction results.\"\"\"\n  if not extractions:\n    print(\"  No extractions found\")\n    return\n\n  class_counts = {}\n  for ext in extractions:\n    class_counts[ext.extraction_class] = (\n      
  class_counts.get(ext.extraction_class, 0) + 1\n    )\n\n  print(f\"  Total extractions: {len(extractions)}\")\n  print(\"  By type:\")\n  for cls, count in sorted(class_counts.items()):\n    print(f\"    • {cls}: {count}\")\n\n\ndef example_romeo_juliet(\n    model_id: str, model_url: str\n) -> lx.data.AnnotatedDocument | None:\n  \"\"\"Romeo & Juliet character and emotion extraction example.\"\"\"\n  print_section(\"Example 1: Romeo & Juliet - Characters and Emotions\")\n\n  prompt = textwrap.dedent(\"\"\"\\\n      Extract characters, emotions, and relationships in order of appearance.\n      Use exact text for extractions. Do not paraphrase or overlap entities.\n      Provide meaningful attributes for each entity to add context.\"\"\")\n\n  examples = [\n      lx.data.ExampleData(\n          text=(\n              \"ROMEO. But soft! What light through yonder window breaks? It is\"\n              \" the east, and Juliet is the sun.\"\n          ),\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"character\",\n                  extraction_text=\"ROMEO\",\n                  attributes={\"emotional_state\": \"wonder\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"emotion\",\n                  extraction_text=\"But soft!\",\n                  attributes={\"feeling\": \"gentle awe\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"relationship\",\n                  extraction_text=\"Juliet is the sun\",\n                  attributes={\"type\": \"metaphor\"},\n              ),\n          ],\n      )\n  ]\n\n  input_text = (\n      \"Lady Juliet gazed longingly at the stars, her heart aching for Romeo\"\n  )\n\n  print(f\"  Input: {input_text}\")\n  print(f\"  Model: {model_id}\")\n  print(\"\\n  Extracting...\")\n\n  result = lx.extract(\n      text_or_documents=input_text,\n      prompt_description=prompt,\n      
examples=examples,\n      model_id=model_id,\n      model_url=model_url,\n      resolver_params={\"format_handler\": ollama.OLLAMA_FORMAT_HANDLER},\n      show_progress=True,\n  )\n\n  print(\"\\n  Results:\")\n  print_results_summary(result.extractions)\n\n  return result\n\n\ndef example_medication_ner(\n    model_id: str, model_url: str\n) -> lx.data.AnnotatedDocument | None:\n  \"\"\"Medical named entity recognition example.\"\"\"\n  print_section(\"Example 2: Medication Named Entity Recognition\")\n\n  input_text = \"Patient took 400 mg PO Ibuprofen q4h for two days.\"\n\n  prompt_description = (\n      \"Extract medication information including medication name, dosage, route,\"\n      \" frequency, and duration in the order they appear in the text.\"\n  )\n\n  examples = [\n      lx.data.ExampleData(\n          text=\"Patient was given 250 mg IV Cefazolin TID for one week.\",\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"dosage\", extraction_text=\"250 mg\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"route\", extraction_text=\"IV\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"medication\", extraction_text=\"Cefazolin\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"frequency\", extraction_text=\"TID\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"duration\", extraction_text=\"for one week\"\n              ),\n          ],\n      )\n  ]\n\n  print(f\"  Input: {input_text}\")\n  print(f\"  Model: {model_id}\")\n  print(\"\\n  Extracting...\")\n\n  result = lx.extract(\n      text_or_documents=input_text,\n      prompt_description=prompt_description,\n      examples=examples,\n      model_id=model_id,\n      model_url=model_url,\n      resolver_params={\"format_handler\": ollama.OLLAMA_FORMAT_HANDLER},\n      show_progress=True,\n  
)\n\n  print(\"\\n  Results:\")\n  print_results_summary(result.extractions)\n\n  return result\n\n\ndef example_medication_relationships(\n    model_id: str, model_url: str\n) -> lx.data.AnnotatedDocument | None:\n  \"\"\"Medication relationship extraction with grouped attributes.\"\"\"\n  print_section(\"Example 3: Medication Relationship Extraction\")\n\n  input_text = textwrap.dedent(\"\"\"\n      The patient was prescribed Lisinopril and Metformin last month.\n      He takes the Lisinopril 10mg daily for hypertension, but often misses\n      his Metformin 500mg dose which should be taken twice daily for diabetes.\n  \"\"\").strip()\n\n  prompt_description = textwrap.dedent(\"\"\"\n      Extract medications with their details, using attributes to group related information:\n\n      1. Extract entities in the order they appear in the text\n      2. Each entity must have a 'medication_group' attribute linking it to its medication\n      3. All details about a medication should share the same medication_group value\n  \"\"\").strip()\n\n  examples = [\n      lx.data.ExampleData(\n          text=(\n              \"Patient takes Aspirin 100mg daily for heart health and\"\n              \" Simvastatin 20mg at bedtime.\"\n          ),\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Aspirin\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"100mg\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"frequency\",\n                  extraction_text=\"daily\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  
extraction_class=\"condition\",\n                  extraction_text=\"heart health\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Simvastatin\",\n                  attributes={\"medication_group\": \"Simvastatin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"20mg\",\n                  attributes={\"medication_group\": \"Simvastatin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"frequency\",\n                  extraction_text=\"at bedtime\",\n                  attributes={\"medication_group\": \"Simvastatin\"},\n              ),\n          ],\n      )\n  ]\n\n  print(f\"  Input: {input_text[:80]}...\")\n  print(f\"  Model: {model_id}\")\n  print(\"\\n  Extracting...\")\n\n  result = lx.extract(\n      text_or_documents=input_text,\n      prompt_description=prompt_description,\n      examples=examples,\n      model_id=model_id,\n      model_url=model_url,\n      resolver_params={\"format_handler\": ollama.OLLAMA_FORMAT_HANDLER},\n      show_progress=True,\n  )\n\n  print(\"\\n  Results:\")\n  print_results_summary(result.extractions)\n\n  medication_groups = {}\n  for ext in result.extractions:\n    if ext.attributes and \"medication_group\" in ext.attributes:\n      group_name = ext.attributes[\"medication_group\"]\n      medication_groups.setdefault(group_name, []).append(ext)\n\n  if medication_groups:\n    print(\"\\n  Grouped by medication:\")\n    for med_name in sorted(medication_groups.keys()):\n      print(f\"    {med_name}: {len(medication_groups[med_name])} attributes\")\n\n  return result\n\n\ndef example_shakespeare_dialogue(\n    model_id: str, model_url: str\n) -> lx.data.AnnotatedDocument | None:\n  \"\"\"Extract character dialogue from Shakespeare 
play excerpt.\"\"\"\n  print_section(\"Example 4: Shakespeare Dialogue Extraction\")\n\n  long_text = textwrap.dedent(\"\"\"\n      Act I, Scene I. Verona. A public place.\n\n      Enter SAMPSON and GREGORY, armed with swords and bucklers.\n\n      SAMPSON: Gregory, on my word, we'll not carry coals.\n      GREGORY: No, for then we should be colliers.\n      SAMPSON: I mean, an we be in choler, we'll draw.\n      GREGORY: Ay, while you live, draw your neck out of collar.\n\n      Enter ABRAHAM and BALTHASAR.\n\n      ABRAHAM: Do you bite your thumb at us, sir?\n      SAMPSON: I do bite my thumb, sir.\n      ABRAHAM: Do you bite your thumb at us, sir?\n      SAMPSON: No, sir, I do not bite my thumb at you, sir, but I bite my thumb, sir.\n      GREGORY: Do you quarrel, sir?\n      ABRAHAM: Quarrel, sir? No, sir.\n\n      Enter BENVOLIO.\n\n      BENVOLIO: Part, fools! Put up your swords. You know not what you do.\n\n      Enter TYBALT.\n\n      TYBALT: What, art thou drawn among these heartless hinds?\n      Turn thee, Benvolio; look upon thy death.\n      BENVOLIO: I do but keep the peace. Put up thy sword,\n      Or manage it to part these men with me.\n      TYBALT: What, drawn, and talk of peace? I hate the word,\n      As I hate hell, all Montagues, and thee.\n      Have at thee, coward!\n  \"\"\").strip()\n\n  prompt = (\n      \"Extract all character names and their dialogue in order of appearance.\"\n  )\n\n  examples = [\n      lx.data.ExampleData(\n          text=\"JULIET: O Romeo, Romeo! Wherefore art thou Romeo?\",\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"character\", extraction_text=\"JULIET\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=\"dialogue\",\n                  extraction_text=\"O Romeo, Romeo! 
Wherefore art thou Romeo?\",\n                  attributes={\"speaker\": \"JULIET\"},\n              ),\n          ],\n      )\n  ]\n\n  print(f\"  Input: Romeo and Juliet Act I, Scene I ({len(long_text)} chars)\")\n  print(f\"  Model: {model_id}\")\n  print(\"  Note: Automatically chunked for longer text processing\")\n  print(\"\\n  Extracting...\")\n\n  result = lx.extract(\n      text_or_documents=long_text,\n      prompt_description=prompt,\n      examples=examples,\n      model_id=model_id,\n      model_url=model_url,\n      resolver_params={\"format_handler\": ollama.OLLAMA_FORMAT_HANDLER},\n      max_char_buffer=500,\n      show_progress=True,\n  )\n\n  print(\"\\n  Results:\")\n  print_results_summary(result.extractions)\n\n  characters = set(\n      ext.extraction_text\n      for ext in result.extractions\n      if ext.extraction_class == \"character\"\n  )\n  if characters:\n    print(\"\\n  Characters found: \" + \", \".join(sorted(characters)))\n\n  return result\n\n\ndef save_results(\n    results: list[tuple[str, lx.data.AnnotatedDocument | None]],\n    output_dir: Path,\n) -> None:\n  \"\"\"Save all results to JSONL and generate HTML visualizations.\"\"\"\n  print_header(\"Saving Results and Generating Visualizations\")\n\n  saved_files = []\n\n  for name, result in results:\n    if result is None:\n      print(f\"  ✗ Skipping {name} (no result)\")\n      continue\n\n    jsonl_file = f\"{name}.jsonl\"\n    jsonl_path = output_dir / jsonl_file\n\n    lx.io.save_annotated_documents(\n        [result], output_name=jsonl_file, output_dir=str(output_dir)\n    )\n    print(f\"  ✓ Saved {jsonl_path}\")\n\n    html_file = f\"{name}.html\"\n    html_path = output_dir / html_file\n\n    try:\n      html_content = lx.visualize(str(jsonl_path))\n      with open(html_path, \"w\") as f:\n        if hasattr(html_content, \"data\"):\n          f.write(html_content.data)\n        else:\n          f.write(html_content)\n      print(f\"  ✓ Generated {html_path}\")\n   
   saved_files.append((jsonl_path, html_path))\n    except Exception as e:\n      print(f\"  ✗ Failed to generate {html_path}: {e}\")\n\n  return saved_files\n\n\ndef main():\n  \"\"\"Run all examples and generate outputs.\"\"\"\n  parser = argparse.ArgumentParser(\n      description=\"Ollama + FormatHandler Demo\",\n      formatter_class=argparse.RawDescriptionHelpFormatter,\n      epilog=__doc__,\n  )\n  parser.add_argument(\n      \"--model\",\n      default=DEFAULT_MODEL,\n      help=f\"Ollama model to use (default: {DEFAULT_MODEL})\",\n  )\n  parser.add_argument(\n      \"--url\",\n      default=DEFAULT_OLLAMA_URL,\n      help=f\"Ollama server URL (default: {DEFAULT_OLLAMA_URL})\",\n  )\n  parser.add_argument(\n      \"--skip-examples\",\n      nargs=\"+\",\n      choices=[\"1\", \"2\", \"3\", \"4\"],\n      help=\"Skip specific examples (e.g., --skip-examples 3 4)\",\n  )\n\n  args = parser.parse_args()\n  skip_examples = set(args.skip_examples or [])\n\n  print_header(\"Ollama + FormatHandler Demo\")\n  print(\"\\nConfiguration:\")\n  print(f\"  Model: {args.model}\")\n  print(f\"  Server: {args.url}\")\n  print(f\"  Output: {OUTPUT_DIR}/\")\n  print(f\"  Format Handler: {ollama.OLLAMA_FORMAT_HANDLER}\")\n\n  print(\"\\nChecking Ollama server...\")\n  if not check_ollama_available(args.url):\n    print(f\"\\n⚠️  ERROR: Ollama not available at {args.url}\")\n    print(\"\\nTroubleshooting:\")\n    print(\"  1. Install Ollama: https://ollama.com/\")\n    print(\"  2. Start server: ollama serve\")\n    print(f\"  3. 
Pull model: ollama pull {args.model}\")\n    print(\"\\nFor Docker setup, see examples/ollama/docker-compose.yml\")\n    sys.exit(1)\n\n  print(\"✓ Ollama server is available\")\n\n  output_dir = ensure_output_directory()\n  print(\"✓ Output directory ready: \" + str(output_dir) + \"/\")\n\n  print_header(\"Running Examples\")\n  results = []\n\n  try:\n    if \"1\" not in skip_examples:\n      result = example_romeo_juliet(args.model, args.url)\n      results.append((\"romeo_juliet\", result))\n      time.sleep(0.5)\n\n    if \"2\" not in skip_examples:\n      result = example_medication_ner(args.model, args.url)\n      results.append((\"medication_ner\", result))\n      time.sleep(0.5)\n\n    if \"3\" not in skip_examples:\n      result = example_medication_relationships(args.model, args.url)\n      results.append((\"medication_relationships\", result))\n      time.sleep(0.5)\n\n    if \"4\" not in skip_examples:\n      result = example_shakespeare_dialogue(args.model, args.url)\n      results.append((\"shakespeare_dialogue\", result))\n\n  except KeyboardInterrupt:\n    print(\"\\n\\n⚠️  Interrupted by user\")\n    print(\"Saving completed results...\")\n  except Exception as e:\n    print(f\"\\n\\n✗ Error during execution: {e}\")\n    traceback.print_exc()\n    print(\"\\nSaving completed results...\")\n\n  if results:\n    save_results(results, output_dir)\n\n  print_header(\"Summary\")\n\n  successful = sum(1 for _, r in results if r is not None)\n  print(f\"\\n✓ Successfully ran {successful}/{len(results)} examples\")\n\n  if results:\n    print(f\"\\nOutput files in {output_dir}/:\")\n    for name, result in results:\n      if result is not None:\n        print(f\"  • {name}.jsonl - Extraction data\")\n        print(f\"  • {name}.html  - Interactive visualization\")\n\n    print(\"\\nTo view results:\")\n    print(\"  open \" + str(output_dir) + \"/romeo_juliet.html\")\n    print(\"\\nOr serve locally:\")\n    print(\"  python -m http.server 8000 
--directory \" + str(output_dir))\n    print(\"  Then visit http://localhost:8000\")\n\n\nif __name__ == \"__main__\":\n  main()\n"
  },
  {
    "path": "examples/ollama/docker-compose.yml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nservices:\n  ollama:\n    image: ollama/ollama:0.5.4\n    ports:\n      - \"127.0.0.1:11434:11434\"  # Bind only to localhost for security\n    volumes:\n      - ollama-data:/root/.ollama  # Named volume (portable across host OSes) keeps pulled models\n    command: serve\n    healthcheck:\n      test: [\"CMD\", \"curl\", \"-f\", \"http://localhost:11434/api/version\"]\n      interval: 5s\n      timeout: 3s\n      retries: 5\n      start_period: 10s\n\n  langextract:\n    build: .\n    depends_on:\n      ollama:\n        condition: service_healthy\n    environment:\n      - OLLAMA_HOST=http://ollama:11434\n    volumes:\n      - .:/app\n    command: python demo_ollama.py\n\nvolumes:\n  ollama-data:\n"
  },
  {
    "path": "langextract/__init__.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"LangExtract: Extract structured information from text with LLMs.\n\nThis package provides the main extract and visualize functions,\nwith lazy loading for other submodules accessed via attribute access.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport importlib\nimport sys\nfrom typing import Any, Dict\n\nfrom langextract import visualization\nfrom langextract.extraction import extract as extract_func\n\n__all__ = [\n    # Public convenience functions (thin wrappers)\n    \"extract\",\n    \"visualize\",\n    # Submodules exposed lazily on attribute access for ergonomics:\n    \"annotation\",\n    \"data\",\n    \"providers\",\n    \"schema\",\n    \"inference\",\n    \"factory\",\n    \"resolver\",\n    \"prompting\",\n    \"io\",\n    \"visualization\",\n    \"exceptions\",\n    \"core\",\n    \"plugins\",\n]\n\n_CACHE: Dict[str, Any] = {}\n\n\ndef extract(*args: Any, **kwargs: Any):\n  \"\"\"Top-level API: lx.extract(...).\"\"\"\n  return extract_func(*args, **kwargs)\n\n\ndef visualize(*args: Any, **kwargs: Any):\n  \"\"\"Top-level API: lx.visualize(...).\"\"\"\n  return visualization.visualize(*args, **kwargs)\n\n\n# PEP 562 lazy loading\n_LAZY_MODULES = {\n    \"annotation\": \"langextract.annotation\",\n    \"chunking\": \"langextract.chunking\",\n    \"data\": \"langextract.data\",\n    \"data_lib\": \"langextract.data_lib\",\n    
\"debug_utils\": \"langextract.core.debug_utils\",\n    \"exceptions\": \"langextract.exceptions\",\n    \"factory\": \"langextract.factory\",\n    \"inference\": \"langextract.inference\",\n    \"io\": \"langextract.io\",\n    \"progress\": \"langextract.progress\",\n    \"prompting\": \"langextract.prompting\",\n    \"providers\": \"langextract.providers\",\n    \"resolver\": \"langextract.resolver\",\n    \"schema\": \"langextract.schema\",\n    \"tokenizer\": \"langextract.tokenizer\",\n    \"visualization\": \"langextract.visualization\",\n    \"core\": \"langextract.core\",\n    \"plugins\": \"langextract.plugins\",\n    \"registry\": \"langextract.registry\",  # Backward compat - will emit warning\n}\n\n\ndef __getattr__(name: str) -> Any:\n  if name in _CACHE:\n    return _CACHE[name]\n  modpath = _LAZY_MODULES.get(name)\n  if modpath is None:\n    raise AttributeError(f\"module {__name__!r} has no attribute {name!r}\")\n  module = importlib.import_module(modpath)\n  # ensure future 'import langextract.<name>' returns the same module\n  sys.modules[f\"{__name__}.{name}\"] = module\n  setattr(sys.modules[__name__], name, module)\n  _CACHE[name] = module\n  return module\n\n\ndef __dir__():\n  return sorted(__all__)\n"
  },
  {
    "path": "langextract/_compat/README.md",
    "content": "# Backward Compatibility Layer\n\nThis directory contains backward compatibility shims for deprecated imports.\n\n## Deprecation Timeline\n\nAll code in this directory will be removed in LangExtract v2.0.0.\n\n## Migration Guide\n\nThe following imports are deprecated and should be updated:\n\n### Inference Module\n- `from langextract.inference import BaseLanguageModel` → `from langextract.core.base_model import BaseLanguageModel`\n- `from langextract.inference import ScoredOutput` → `from langextract.core.types import ScoredOutput`\n- `from langextract.inference import InferenceOutputError` → `from langextract.core.exceptions import InferenceOutputError`\n- `from langextract.inference import GeminiLanguageModel` → `from langextract.providers.gemini import GeminiLanguageModel`\n- `from langextract.inference import OpenAILanguageModel` → `from langextract.providers.openai import OpenAILanguageModel`\n- `from langextract.inference import OllamaLanguageModel` → `from langextract.providers.ollama import OllamaLanguageModel`\n\n### Schema Module\n- `from langextract.schema import BaseSchema` → `from langextract.core.schema import BaseSchema`\n- `from langextract.schema import Constraint` → `from langextract.core.schema import Constraint`\n- `from langextract.schema import ConstraintType` → `from langextract.core.schema import ConstraintType`\n- `from langextract.schema import EXTRACTIONS_KEY` → `from langextract.core.schema import EXTRACTIONS_KEY`\n- `from langextract.schema import GeminiSchema` → `from langextract.providers.schemas.gemini import GeminiSchema`\n\n### Exceptions Module\n- All exceptions: `from langextract.exceptions import *` → `from langextract.core.exceptions import *`\n\n### Registry Module\n- `from langextract.registry import *` → `from langextract.plugins import *`\n- `from langextract.providers.registry import *` → `from langextract.providers.router import *`\n\n## For Contributors\n\nDo not add new code to this directory. 
All new development should use the canonical imports from `core/` and `providers/`.\n"
  },
  {
    "path": "langextract/_compat/__init__.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Backward compatibility layer for LangExtract.\n\nThis package contains compatibility shims for deprecated imports. All code\nin this directory will be removed in v2.0.0.\n\"\"\"\n\nfrom __future__ import annotations\n\n__all__ = [\"inference\", \"schema\", \"exceptions\", \"registry\"]\n"
  },
  {
    "path": "langextract/_compat/exceptions.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.exceptions imports.\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport warnings\n\nfrom langextract.core import exceptions\n\n\n# Re-export exceptions from core.exceptions with a warning-on-first-access\ndef __getattr__(name: str):\n  allowed = {\n      \"LangExtractError\",\n      \"InferenceError\",\n      \"InferenceConfigError\",\n      \"InferenceRuntimeError\",\n      \"InferenceOutputError\",\n      \"ProviderError\",\n      \"SchemaError\",\n  }\n  if name in allowed:\n    warnings.warn(\n        \"`langextract.exceptions` is deprecated; import from\"\n        \" `langextract.core.exceptions`.\",\n        FutureWarning,\n        stacklevel=2,\n    )\n    return getattr(exceptions, name)\n  raise AttributeError(name)\n"
  },
  {
    "path": "langextract/_compat/inference.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.inference imports.\"\"\"\n\nfrom __future__ import annotations\n\nimport enum\nimport warnings\n\n\nclass InferenceType(enum.Enum):\n  \"\"\"Enum for inference types - kept for backward compatibility.\"\"\"\n\n  ITERATIVE = \"iterative\"\n  MULTIPROCESS = \"multiprocess\"\n\n\ndef __getattr__(name: str):\n  moved = {\n      \"BaseLanguageModel\": (\"langextract.core.base_model\", \"BaseLanguageModel\"),\n      \"ScoredOutput\": (\"langextract.core.types\", \"ScoredOutput\"),\n      \"InferenceOutputError\": (\n          \"langextract.core.exceptions\",\n          \"InferenceOutputError\",\n      ),\n      \"GeminiLanguageModel\": (\n          \"langextract.providers.gemini\",\n          \"GeminiLanguageModel\",\n      ),\n      \"OpenAILanguageModel\": (\n          \"langextract.providers.openai\",\n          \"OpenAILanguageModel\",\n      ),\n      \"OllamaLanguageModel\": (\n          \"langextract.providers.ollama\",\n          \"OllamaLanguageModel\",\n      ),\n  }\n  if name in moved:\n    mod, attr = moved[name]\n    warnings.warn(\n        f\"`langextract.inference.{name}` is deprecated and will be removed in\"\n        f\" v2.0.0; use `{mod}.{attr}` instead.\",\n        FutureWarning,\n        stacklevel=2,\n    )\n    module = __import__(mod, fromlist=[attr])\n    return getattr(module, attr)\n  raise 
AttributeError(name)\n"
  },
  {
    "path": "langextract/_compat/registry.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.registry imports.\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport warnings\n\nfrom langextract import plugins\n\n\ndef __getattr__(name: str):\n  \"\"\"Forward to plugins module with deprecation warning.\"\"\"\n  warnings.warn(\n      \"`langextract.registry` is deprecated and will be removed in v2.0.0; \"\n      \"use `langextract.plugins` instead.\",\n      FutureWarning,\n      stacklevel=2,\n  )\n  return getattr(plugins, name)\n"
  },
  {
    "path": "langextract/_compat/schema.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.schema imports.\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport warnings\n\n\ndef __getattr__(name: str):\n  moved = {\n      \"BaseSchema\": (\"langextract.core.schema\", \"BaseSchema\"),\n      \"Constraint\": (\"langextract.core.schema\", \"Constraint\"),\n      \"ConstraintType\": (\"langextract.core.schema\", \"ConstraintType\"),\n      \"EXTRACTIONS_KEY\": (\"langextract.core.schema\", \"EXTRACTIONS_KEY\"),\n      \"GeminiSchema\": (\"langextract.providers.schemas.gemini\", \"GeminiSchema\"),\n  }\n  if name in moved:\n    mod, attr = moved[name]\n    warnings.warn(\n        f\"`langextract.schema.{name}` is deprecated and will be removed in\"\n        f\" v2.0.0; use `{mod}.{attr}` instead.\",\n        FutureWarning,\n        stacklevel=2,\n    )\n    module = __import__(mod, fromlist=[attr])\n    return getattr(module, attr)\n  raise AttributeError(name)\n"
  },
  {
    "path": "langextract/annotation.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Provides functionality for annotating medical text using a language model.\n\nThe annotation process involves tokenizing the input text, generating prompts\nfor the language model, and resolving the language model's output into\nstructured annotations.\n\nUsage example:\n    annotator = Annotator(language_model, prompt_template)\n    annotated_documents = annotator.annotate_documents(documents, resolver)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport collections\nfrom collections.abc import Iterable, Iterator\nimport time\nfrom typing import DefaultDict\n\nfrom absl import logging\n\nfrom langextract import chunking\nfrom langextract import progress\nfrom langextract import prompting\nfrom langextract import resolver as resolver_lib\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import format_handler as fh\nfrom langextract.core import tokenizer as tokenizer_lib\n\n\ndef _merge_non_overlapping_extractions(\n    all_extractions: list[Iterable[data.Extraction]],\n) -> list[data.Extraction]:\n  \"\"\"Merges extractions from multiple extraction passes.\n\n  When extractions from different passes overlap in their character positions,\n  the extraction from the earlier pass is kept (first-pass wins strategy).\n  Only non-overlapping extractions from later passes are 
added to the result.\n\n  Args:\n    all_extractions: List of extraction iterables from different sequential\n      extraction passes, ordered by pass number.\n\n  Returns:\n    List of merged extractions with overlaps resolved in favor of earlier\n    passes.\n  \"\"\"\n  if not all_extractions:\n    return []\n\n  if len(all_extractions) == 1:\n    return list(all_extractions[0])\n\n  merged_extractions = list(all_extractions[0])\n\n  for pass_extractions in all_extractions[1:]:\n    for extraction in pass_extractions:\n      overlaps = False\n      if extraction.char_interval is not None:\n        for existing_extraction in merged_extractions:\n          if existing_extraction.char_interval is not None:\n            if _extractions_overlap(extraction, existing_extraction):\n              overlaps = True\n              break\n\n      if not overlaps:\n        merged_extractions.append(extraction)\n\n  return merged_extractions\n\n\ndef _extractions_overlap(\n    extraction1: data.Extraction, extraction2: data.Extraction\n) -> bool:\n  \"\"\"Checks if two extractions overlap based on their character intervals.\n\n  Args:\n    extraction1: First extraction to compare.\n    extraction2: Second extraction to compare.\n\n  Returns:\n    True if the extractions overlap, False otherwise.\n  \"\"\"\n  if extraction1.char_interval is None or extraction2.char_interval is None:\n    return False\n\n  start1, end1 = (\n      extraction1.char_interval.start_pos,\n      extraction1.char_interval.end_pos,\n  )\n  start2, end2 = (\n      extraction2.char_interval.start_pos,\n      extraction2.char_interval.end_pos,\n  )\n\n  if start1 is None or end1 is None or start2 is None or end2 is None:\n    return False\n\n  # Two intervals overlap if one starts before the other ends\n  return start1 < end2 and start2 < end1\n\n\ndef _document_chunk_iterator(\n    documents: Iterable[data.Document],\n    max_char_buffer: int,\n    restrict_repeats: bool = True,\n    tokenizer: 
tokenizer_lib.Tokenizer | None = None,\n) -> Iterator[chunking.TextChunk]:\n  \"\"\"Iterates over documents to yield text chunks along with the document ID.\n\n  Args:\n    documents: A sequence of Document objects.\n    max_char_buffer: The maximum character buffer size for the ChunkIterator.\n    restrict_repeats: Whether to restrict the same document id from being\n      visited more than once.\n    tokenizer: Optional tokenizer instance.\n\n  Yields:\n    TextChunk containing document ID for a corresponding document.\n\n  Raises:\n    InvalidDocumentError: If restrict_repeats is True and the same document ID\n      is visited more than once. Valid documents prior to the error will be\n      returned.\n  \"\"\"\n  visited_ids = set()\n  for document in documents:\n    if tokenizer:\n      tokenized_text = tokenizer.tokenize(document.text or \"\")\n    else:\n      tokenized_text = document.tokenized_text\n    document_id = document.document_id\n    if restrict_repeats and document_id in visited_ids:\n      raise exceptions.InvalidDocumentError(\n          f\"Document id {document_id} is already visited.\"\n      )\n    chunk_iter = chunking.ChunkIterator(\n        text=tokenized_text,\n        max_char_buffer=max_char_buffer,\n        document=document,\n        tokenizer_impl=tokenizer or tokenizer_lib.RegexTokenizer(),\n    )\n    visited_ids.add(document_id)\n\n    yield from chunk_iter\n\n\nclass Annotator:\n  \"\"\"Annotates documents with extractions using a language model.\"\"\"\n\n  def __init__(\n      self,\n      language_model: base_model.BaseLanguageModel,\n      prompt_template: prompting.PromptTemplateStructured,\n      format_type: data.FormatType = data.FormatType.YAML,\n      attribute_suffix: str = data.ATTRIBUTE_SUFFIX,\n      fence_output: bool = False,\n      format_handler: fh.FormatHandler | None = None,\n  ):\n    \"\"\"Initializes Annotator.\n\n    Args:\n      language_model: Model which performs language model inference.\n      
prompt_template: Structured prompt template where the answer is expected\n        to be formatted text (YAML or JSON).\n      format_type: The format type for the output (YAML or JSON).\n      attribute_suffix: Suffix to append to attribute keys in the output.\n      fence_output: Whether to expect/generate fenced output (```json or\n        ```yaml). When True, the model is prompted to generate fenced output and\n        the resolver expects it. When False, raw JSON/YAML is expected.\n        Defaults to False. If format_handler is provided, it takes precedence.\n      format_handler: Optional FormatHandler for managing format-specific logic.\n    \"\"\"\n    self._language_model = language_model\n\n    if format_handler is None:\n      format_handler = fh.FormatHandler(\n          format_type=format_type,\n          use_wrapper=True,\n          wrapper_key=data.EXTRACTIONS_KEY,\n          use_fences=fence_output,\n          attribute_suffix=attribute_suffix,\n      )\n\n    self._prompt_generator = prompting.QAPromptGenerator(\n        template=prompt_template,\n        format_handler=format_handler,\n    )\n\n    logging.debug(\n        \"Annotator initialized with format_handler: %s\", format_handler\n    )\n\n  def annotate_documents(\n      self,\n      documents: Iterable[data.Document],\n      resolver: resolver_lib.AbstractResolver | None = None,\n      max_char_buffer: int = 200,\n      batch_length: int = 1,\n      debug: bool = True,\n      extraction_passes: int = 1,\n      context_window_chars: int | None = None,\n      show_progress: bool = True,\n      tokenizer: tokenizer_lib.Tokenizer | None = None,\n      **kwargs,\n  ) -> Iterator[data.AnnotatedDocument]:\n    \"\"\"Annotates a sequence of documents with NLP extractions.\n\n      Breaks documents into chunks, processes them into prompts and performs\n      batched inference, mapping annotated extractions back to the original\n      document. 
Batch processing is determined by batch_length, and can operate\n      across documents for optimized throughput.\n\n    Args:\n      documents: Documents to annotate. Each document is expected to have a\n        unique document_id.\n      resolver: Resolver to use for extracting information from text.\n      max_char_buffer: Max number of characters that we can run inference on.\n        The text will be broken into chunks up to this length.\n      batch_length: Number of chunks to process in a single batch.\n      debug: Whether to populate debug fields.\n      extraction_passes: Number of sequential extraction attempts to improve\n        recall by finding additional entities. Defaults to 1, which performs\n        standard single extraction.\n        Values > 1 reprocess tokens multiple times, potentially increasing\n        costs with the potential for a more thorough extraction.\n      context_window_chars: Number of characters from the previous chunk to\n        include as context for the current chunk. Helps with coreference\n        resolution across chunk boundaries. Defaults to None (disabled).\n      show_progress: Whether to show progress bar. Defaults to True.\n      tokenizer: Optional tokenizer to use. 
If None, uses default tokenizer.\n      **kwargs: Additional arguments passed to LanguageModel.infer and Resolver.\n\n    Yields:\n      Resolved annotations from input documents.\n\n    Raises:\n      ValueError: If there are no scored outputs during inference.\n    \"\"\"\n    if resolver is None:\n      resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML)\n\n    if extraction_passes == 1:\n      yield from self._annotate_documents_single_pass(\n          documents,\n          resolver,\n          max_char_buffer,\n          batch_length,\n          debug,\n          show_progress,\n          context_window_chars=context_window_chars,\n          tokenizer=tokenizer,\n          **kwargs,\n      )\n    else:\n      yield from self._annotate_documents_sequential_passes(\n          documents,\n          resolver,\n          max_char_buffer,\n          batch_length,\n          debug,\n          extraction_passes,\n          show_progress,\n          context_window_chars=context_window_chars,\n          tokenizer=tokenizer,\n          **kwargs,\n      )\n\n  def _annotate_documents_single_pass(\n      self,\n      documents: Iterable[data.Document],\n      resolver: resolver_lib.AbstractResolver,\n      max_char_buffer: int,\n      batch_length: int,\n      debug: bool,\n      show_progress: bool = True,\n      context_window_chars: int | None = None,\n      tokenizer: tokenizer_lib.Tokenizer | None = None,\n      **kwargs,\n  ) -> Iterator[data.AnnotatedDocument]:\n    \"\"\"Single-pass annotation with stable ordering and streaming emission.\n\n    Streams input without full materialization, maintains correct attribution\n    across batches, and emits completed documents immediately to minimize\n    peak memory usage. 
Handles generators from both infer() and align().\n\n    When context_window_chars is set, includes text from the previous chunk as\n    context for coreference resolution across chunk boundaries.\n    \"\"\"\n    doc_order: list[str] = []\n    doc_text_by_id: dict[str, str] = {}\n    per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(\n        list\n    )\n    next_emit_idx = 0\n\n    def _capture_docs(src: Iterable[data.Document]) -> Iterator[data.Document]:\n      \"\"\"Captures document order and text lazily as chunks are produced.\"\"\"\n      for document in src:\n        document_id = document.document_id\n        if document_id in doc_text_by_id:\n          raise exceptions.InvalidDocumentError(\n              f\"Duplicate document_id: {document_id}\"\n          )\n        doc_order.append(document_id)\n        doc_text_by_id[document_id] = document.text or \"\"\n        yield document\n\n    def _emit_docs_iter(\n        keep_last_doc: bool,\n    ) -> Iterator[data.AnnotatedDocument]:\n      \"\"\"Yields documents that are guaranteed complete.\n\n      Args:\n        keep_last_doc: If True, retains the most recently started document\n          for additional extractions. 
If False, emits all remaining documents.\n      \"\"\"\n      nonlocal next_emit_idx\n      limit = max(0, len(doc_order) - 1) if keep_last_doc else len(doc_order)\n      while next_emit_idx < limit:\n        document_id = doc_order[next_emit_idx]\n        yield data.AnnotatedDocument(\n            document_id=document_id,\n            extractions=per_doc.get(document_id, []),\n            text=doc_text_by_id.get(document_id, \"\"),\n        )\n        per_doc.pop(document_id, None)\n        doc_text_by_id.pop(document_id, None)\n        next_emit_idx += 1\n\n    chunk_iter = _document_chunk_iterator(\n        _capture_docs(documents), max_char_buffer, tokenizer=tokenizer\n    )\n    batches = chunking.make_batches_of_textchunk(chunk_iter, batch_length)\n\n    model_info = progress.get_model_info(self._language_model)\n    batch_iter = progress.create_extraction_progress_bar(\n        batches, model_info=model_info, disable=not show_progress\n    )\n\n    chars_processed = 0\n\n    prompt_builder = prompting.ContextAwarePromptBuilder(\n        generator=self._prompt_generator,\n        context_window_chars=context_window_chars,\n    )\n\n    try:\n      for batch in batch_iter:\n        if not batch:\n          continue\n\n        prompts = [\n            prompt_builder.build_prompt(\n                chunk.chunk_text, chunk.document_id, chunk.additional_context\n            )\n            for chunk in batch\n        ]\n\n        if show_progress:\n          current_chars = sum(\n              len(text_chunk.chunk_text) for text_chunk in batch\n          )\n          try:\n            batch_iter.set_description(\n                progress.format_extraction_progress(\n                    model_info,\n                    current_chars=current_chars,\n                    processed_chars=chars_processed,\n                )\n            )\n          except AttributeError:\n            pass\n\n        outputs = self._language_model.infer(batch_prompts=prompts, **kwargs)\n  
      if not isinstance(outputs, list):\n          outputs = list(outputs)\n\n        for text_chunk, scored_outputs in zip(batch, outputs):\n          if not isinstance(scored_outputs, list):\n            scored_outputs = list(scored_outputs)\n          if not scored_outputs:\n            raise exceptions.InferenceOutputError(\n                \"No scored outputs from language model.\"\n            )\n\n          resolved_extractions = resolver.resolve(\n              scored_outputs[0].output, debug=debug, **kwargs\n          )\n\n          token_offset = (\n              text_chunk.token_interval.start_index\n              if text_chunk.token_interval\n              else 0\n          )\n          char_offset = (\n              text_chunk.char_interval.start_pos\n              if text_chunk.char_interval\n              else 0\n          )\n\n          aligned_extractions = resolver.align(\n              resolved_extractions,\n              text_chunk.chunk_text,\n              token_offset,\n              char_offset,\n              tokenizer_inst=tokenizer,\n              **kwargs,\n          )\n\n          for extraction in aligned_extractions:\n            per_doc[text_chunk.document_id].append(extraction)\n\n          if show_progress and text_chunk.char_interval is not None:\n            chars_processed += (\n                text_chunk.char_interval.end_pos\n                - text_chunk.char_interval.start_pos\n            )\n\n        yield from _emit_docs_iter(keep_last_doc=True)\n\n    finally:\n      batch_iter.close()\n\n    yield from _emit_docs_iter(keep_last_doc=False)\n\n  def _annotate_documents_sequential_passes(\n      self,\n      documents: Iterable[data.Document],\n      resolver: resolver_lib.AbstractResolver,\n      max_char_buffer: int,\n      batch_length: int,\n      debug: bool,\n      extraction_passes: int,\n      show_progress: bool = True,\n      context_window_chars: int | None = None,\n      tokenizer: tokenizer_lib.Tokenizer | None 
= None,\n      **kwargs,\n  ) -> Iterator[data.AnnotatedDocument]:\n    \"\"\"Sequential extraction passes logic for improved recall.\"\"\"\n\n    logging.info(\n        \"Starting sequential extraction passes for improved recall with %d\"\n        \" passes.\",\n        extraction_passes,\n    )\n\n    document_list = list(documents)\n\n    document_extractions_by_pass: dict[str, list[list[data.Extraction]]] = {}\n    document_texts: dict[str, str] = {}\n    # Preserve text up-front so we can emit documents even if later passes\n    # produce no extractions.\n    for _doc in document_list:\n      document_texts[_doc.document_id] = _doc.text or \"\"\n\n    for pass_num in range(extraction_passes):\n      logging.info(\n          \"Starting extraction pass %d of %d\", pass_num + 1, extraction_passes\n      )\n\n      for annotated_doc in self._annotate_documents_single_pass(\n          document_list,\n          resolver,\n          max_char_buffer,\n          batch_length,\n          debug=(debug and pass_num == 0),\n          show_progress=show_progress if pass_num == 0 else False,\n          context_window_chars=context_window_chars,\n          tokenizer=tokenizer,\n          **kwargs,\n      ):\n        doc_id = annotated_doc.document_id\n\n        if doc_id not in document_extractions_by_pass:\n          document_extractions_by_pass[doc_id] = []\n          # Keep first-seen text (already pre-filled above).\n\n        document_extractions_by_pass[doc_id].append(\n            annotated_doc.extractions or []\n        )\n\n    # Emit results strictly in original input order.\n    for doc in document_list:\n      doc_id = doc.document_id\n      all_pass_extractions = document_extractions_by_pass.get(doc_id, [])\n      merged_extractions = _merge_non_overlapping_extractions(\n          all_pass_extractions\n      )\n\n      if debug:\n        total_extractions = sum(\n            len(extractions) for extractions in all_pass_extractions\n        )\n        
logging.info(\n            \"Document %s: Merged %d extractions from %d passes into \"\n            \"%d non-overlapping extractions.\",\n            doc_id,\n            total_extractions,\n            extraction_passes,\n            len(merged_extractions),\n        )\n\n      yield data.AnnotatedDocument(\n          document_id=doc_id,\n          extractions=merged_extractions,\n          text=document_texts.get(doc_id, doc.text or \"\"),\n      )\n\n    logging.info(\"Sequential extraction passes completed.\")\n\n  def annotate_text(\n      self,\n      text: str,\n      resolver: resolver_lib.AbstractResolver | None = None,\n      max_char_buffer: int = 200,\n      batch_length: int = 1,\n      additional_context: str | None = None,\n      debug: bool = True,\n      extraction_passes: int = 1,\n      context_window_chars: int | None = None,\n      show_progress: bool = True,\n      tokenizer: tokenizer_lib.Tokenizer | None = None,\n      **kwargs,\n  ) -> data.AnnotatedDocument:\n    \"\"\"Annotates text with NLP extractions for text input.\n\n    Args:\n      text: Source text to annotate.\n      resolver: Resolver to use for extracting information from text.\n      max_char_buffer: Max number of characters that we can run inference on.\n        The text will be broken into chunks up to this length.\n      batch_length: Number of chunks to process in a single batch.\n      additional_context: Additional context to supplement prompt instructions.\n      debug: Whether to populate debug fields.\n      extraction_passes: Number of sequential extraction passes to improve\n        recall by finding additional entities. Defaults to 1, which performs\n        standard single extraction. Values > 1 reprocess tokens multiple times,\n        potentially increasing costs.\n      context_window_chars: Number of characters from the previous chunk to\n        include as context for coreference resolution. 
Defaults to None\n        (disabled).\n      show_progress: Whether to show progress bar. Defaults to True.\n      tokenizer: Optional tokenizer instance.\n      **kwargs: Additional arguments for inference and resolver_lib.\n\n    Returns:\n      Resolved annotations from text for document.\n    \"\"\"\n    if resolver is None:\n      resolver = resolver_lib.Resolver(\n          format_type=data.FormatType.YAML,\n      )\n\n    start_time = time.time() if debug else None\n\n    documents = [\n        data.Document(\n            text=text,\n            document_id=None,\n            additional_context=additional_context,\n        )\n    ]\n\n    annotations = list(\n        self.annotate_documents(\n            documents=documents,\n            resolver=resolver,\n            max_char_buffer=max_char_buffer,\n            batch_length=batch_length,\n            debug=debug,\n            extraction_passes=extraction_passes,\n            context_window_chars=context_window_chars,\n            show_progress=show_progress,\n            tokenizer=tokenizer,\n            **kwargs,\n        )\n    )\n    assert (\n        len(annotations) == 1\n    ), f\"Expected 1 annotation but got {len(annotations)} annotations.\"\n\n    if debug and annotations[0].extractions:\n      elapsed_time = time.time() - start_time if start_time else None\n      num_extractions = len(annotations[0].extractions)\n      unique_classes = len(\n          set(e.extraction_class for e in annotations[0].extractions)\n      )\n      num_chunks = len(text) // max_char_buffer + (\n          1 if len(text) % max_char_buffer else 0\n      )\n\n      progress.print_extraction_summary(\n          num_extractions,\n          unique_classes,\n          elapsed_time=elapsed_time,\n          chars_processed=len(text),\n          num_chunks=num_chunks,\n      )\n\n    return data.AnnotatedDocument(\n        document_id=annotations[0].document_id,\n        extractions=annotations[0].extractions,\n        
text=annotations[0].text,\n    )\n"
  },
  {
    "path": "langextract/chunking.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Library for breaking documents into chunks of sentences.\n\nWhen a text-to-text model (e.g. a large language model with a fixed context\nsize) cannot accommodate a large document, this library can help us break the\ndocument into chunks of a required maximum length that we can perform\ninference on.\n\"\"\"\n\nfrom collections.abc import Iterable, Iterator, Sequence\nimport dataclasses\nimport re\n\nfrom absl import logging\nimport more_itertools\n\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import tokenizer as tokenizer_lib\n\n\nclass TokenUtilError(exceptions.LangExtractError):\n  \"\"\"Error raised when token_util returns unexpected values.\"\"\"\n\n\n@dataclasses.dataclass\nclass TextChunk:\n  \"\"\"Stores a text chunk along with a reference to its source document.\n\n  Attributes:\n    token_interval: The token interval of the chunk in the source document.\n    document: The source document.\n  \"\"\"\n\n  token_interval: tokenizer_lib.TokenInterval\n  document: data.Document | None = None\n  _chunk_text: str | None = dataclasses.field(\n      default=None, init=False, repr=False\n  )\n  _sanitized_chunk_text: str | None = dataclasses.field(\n      default=None, init=False, repr=False\n  )\n  _char_interval: data.CharInterval | None = dataclasses.field(\n      default=None, init=False, repr=False\n  )\n\n  
def __str__(self):\n    interval_repr = (\n        f\"start_index: {self.token_interval.start_index}, end_index:\"\n        f\" {self.token_interval.end_index}\"\n    )\n\n    doc_id_repr = (\n        f\"Document ID: {self.document_id}\"\n        if self.document_id\n        else \"Document ID: None\"\n    )\n\n    try:\n      chunk_text_repr = f\"'{self.chunk_text}'\"\n    except ValueError:\n      chunk_text_repr = \"<unavailable: document_text not set>\"\n\n    return (\n        \"TextChunk(\\n\"\n        f\"  interval=[{interval_repr}],\\n\"\n        f\"  {doc_id_repr},\\n\"\n        f\"  Chunk Text: {chunk_text_repr}\\n\"\n        \")\"\n    )\n\n  @property\n  def document_id(self) -> str | None:\n    \"\"\"Gets the document ID from the source document.\"\"\"\n    if self.document is not None:\n      return self.document.document_id\n    return None\n\n  @property\n  def document_text(self) -> tokenizer_lib.TokenizedText | None:\n    \"\"\"Gets the tokenized text from the source document.\"\"\"\n    if self.document is not None:\n      return self.document.tokenized_text\n    return None\n\n  @property\n  def chunk_text(self) -> str:\n    \"\"\"Gets the chunk text. 
Raises an error if `document_text` is not set.\"\"\"\n    if self.document_text is None:\n      raise ValueError(\"document_text must be set to access chunk_text.\")\n    if self._chunk_text is None:\n      self._chunk_text = get_token_interval_text(\n          self.document_text, self.token_interval\n      )\n    return self._chunk_text\n\n  @property\n  def sanitized_chunk_text(self) -> str:\n    \"\"\"Gets the sanitized chunk text.\"\"\"\n    if self._sanitized_chunk_text is None:\n      self._sanitized_chunk_text = _sanitize(self.chunk_text)\n    return self._sanitized_chunk_text\n\n  @property\n  def additional_context(self) -> str | None:\n    \"\"\"Gets the additional context for prompting from the source document.\"\"\"\n    if self.document is not None:\n      return self.document.additional_context\n    return None\n\n  @property\n  def char_interval(self) -> data.CharInterval:\n    \"\"\"Gets the character interval corresponding to the token interval.\n\n    Returns:\n      data.CharInterval: The character interval for this chunk.\n\n    Raises:\n      ValueError: If document_text is not set.\n    \"\"\"\n    if self._char_interval is None:\n      if self.document_text is None:\n        raise ValueError(\"document_text must be set to compute char_interval.\")\n      self._char_interval = get_char_interval(\n          self.document_text, self.token_interval\n      )\n    return self._char_interval\n\n\ndef create_token_interval(\n    start_index: int, end_index: int\n) -> tokenizer_lib.TokenInterval:\n  \"\"\"Creates a token interval.\n\n  Args:\n    start_index: first token's index (inclusive).\n    end_index: last token's index + 1 (exclusive).\n\n  Returns:\n    Token interval.\n\n  Raises:\n    ValueError: If the token indices are invalid.\n  \"\"\"\n  if start_index < 0:\n    raise ValueError(f\"Start index {start_index} must be non-negative.\")\n  if start_index >= end_index:\n    raise ValueError(\n        f\"Start index {start_index} must be < end 
index {end_index}.\"\n    )\n  return tokenizer_lib.TokenInterval(\n      start_index=start_index, end_index=end_index\n  )\n\n\ndef get_token_interval_text(\n    tokenized_text: tokenizer_lib.TokenizedText,\n    token_interval: tokenizer_lib.TokenInterval,\n) -> str:\n  \"\"\"Get the text within an interval of tokens.\n\n  Args:\n    tokenized_text: Tokenized documents.\n    token_interval: An interval specifying the start (inclusive) and end\n      (exclusive) indices of the tokens to extract. These indices refer to the\n      positions in the list of tokens within `tokenized_text.tokens`, not the\n      value of the field `index` of `token_pb2.Token`. If the tokens are\n      [(index:0, text:A), (index:5, text:B), (index:10, text:C)], we should use\n      token_interval=[0, 2] to represent taking A and B, not [0, 6]. Please see\n      details from the implementation of tokenizer_lib.tokens_text\n\n  Returns:\n    Text within the token interval.\n\n  Raises:\n    ValueError: If the token indices are invalid.\n    TokenUtilError: If tokenizer_lib.tokens_text returns an empty\n    string.\n  \"\"\"\n  if token_interval.start_index >= token_interval.end_index:\n    raise ValueError(\n        f\"Start index {token_interval.start_index} must be < end index \"\n        f\"{token_interval.end_index}.\"\n    )\n  return_string = tokenizer_lib.tokens_text(tokenized_text, token_interval)\n  logging.debug(\n      \"Token util returns string: %s for tokenized_text: %s, token_interval:\"\n      \" %s\",\n      return_string,\n      tokenized_text,\n      token_interval,\n  )\n  if tokenized_text.text and not return_string:\n    raise TokenUtilError(\n        \"Token util returns an empty string unexpectedly. 
Number of tokens in\"\n        f\" tokenized_text: {len(tokenized_text.tokens)}, token_interval is\"\n        f\" {token_interval.start_index} to {token_interval.end_index}, which\"\n        \" should not lead to empty string.\"\n    )\n  return return_string\n\n\ndef get_char_interval(\n    tokenized_text: tokenizer_lib.TokenizedText,\n    token_interval: tokenizer_lib.TokenInterval,\n) -> data.CharInterval:\n  \"\"\"Returns the char interval corresponding to the token interval.\n\n  Args:\n    tokenized_text: Document.\n    token_interval: Token interval.\n\n  Returns:\n    Char interval of the token interval of interest.\n\n  Raises:\n    ValueError: If the token_interval is invalid.\n  \"\"\"\n  if token_interval.start_index >= token_interval.end_index:\n    raise ValueError(\n        f\"Start index {token_interval.start_index} must be < end index \"\n        f\"{token_interval.end_index}.\"\n    )\n  start_token = tokenized_text.tokens[token_interval.start_index]\n  # Last token inside the interval (end_index is exclusive).\n  final_token = tokenized_text.tokens[token_interval.end_index - 1]\n  return data.CharInterval(\n      start_pos=start_token.char_interval.start_pos,\n      end_pos=final_token.char_interval.end_pos,\n  )\n\n\ndef _sanitize(text: str) -> str:\n  \"\"\"Collapses each run of whitespace in the input text into a single space.\n\n  Args:\n    text: Input to sanitize.\n\n  Returns:\n    Sanitized text with newlines and excess spaces removed.\n\n  Raises:\n    ValueError: If the sanitized text is empty.\n  \"\"\"\n\n  sanitized_text = re.sub(r\"\\s+\", \" \", text.strip())\n  if not sanitized_text:\n    raise ValueError(\"Sanitized text is empty.\")\n  return sanitized_text\n\n\ndef make_batches_of_textchunk(\n    chunk_iter: Iterator[TextChunk],\n    batch_length: int,\n) -> Iterable[Sequence[TextChunk]]:\n  \"\"\"Groups chunks into inference batches using more_itertools.batched.\n\n  Args:\n    chunk_iter: Iterator of TextChunks.\n    
batch_length: Number of chunks to include in each batch.\n\n  Yields:\n    Batches of TextChunks.\n  \"\"\"\n  for batch in more_itertools.batched(chunk_iter, batch_length):\n    yield list(batch)\n\n\nclass SentenceIterator:\n  \"\"\"Iterate through sentences of a tokenized text.\"\"\"\n\n  def __init__(\n      self,\n      tokenized_text: tokenizer_lib.TokenizedText,\n      curr_token_pos: int = 0,\n  ):\n    \"\"\"Constructor.\n\n    Args:\n      tokenized_text: Document to iterate through.\n      curr_token_pos: Iterate through sentences from this token position.\n\n    Raises:\n      IndexError: If curr_token_pos is not within the document.\n    \"\"\"\n    self.tokenized_text = tokenized_text\n    self.token_len = len(tokenized_text.tokens)\n    if curr_token_pos < 0:\n      raise IndexError(\n          f\"Current token position {curr_token_pos} cannot be negative.\"\n      )\n    elif curr_token_pos > self.token_len:\n      raise IndexError(\n          f\"Current token position {curr_token_pos} is past the length of the \"\n          f\"document {self.token_len}.\"\n      )\n    self.curr_token_pos = curr_token_pos\n\n  def __iter__(self) -> Iterator[tokenizer_lib.TokenInterval]:\n    return self\n\n  def __next__(self) -> tokenizer_lib.TokenInterval:\n    \"\"\"Returns the interval of the next sentence from the current position.\n\n    Returns:\n      Next sentence token interval starting from current token position.\n\n    Raises:\n      StopIteration: If end of text is reached.\n    \"\"\"\n    assert self.curr_token_pos <= self.token_len\n    if self.curr_token_pos == self.token_len:\n      raise StopIteration\n    # This locates the sentence which contains the current token position.\n    sentence_range = tokenizer_lib.find_sentence_range(\n        self.tokenized_text.text,\n        self.tokenized_text.tokens,\n        self.curr_token_pos,\n    )\n    assert sentence_range\n    # Start the sentence from the current token position.\n    # If we are in 
the middle of a sentence, we should start from there.\n    sentence_range = create_token_interval(\n        self.curr_token_pos, sentence_range.end_index\n    )\n    self.curr_token_pos = sentence_range.end_index\n    return sentence_range\n\n\nclass ChunkIterator:\n  r\"\"\"Iterate through chunks of a tokenized text.\n\n  Chunks may consist of sentences or sentence fragments that can fit into the\n  maximum character buffer that we can run inference on.\n\n  A)\n  If a sentence's length exceeds the max char buffer, then it needs to be broken\n  into chunks that can fit within the max char buffer. We do this in a way that\n  maximizes the chunk length while respecting newlines (if present) and token\n  boundaries.\n  Consider this sentence from a poem by John Donne:\n  ```\n  No man is an island,\n  Entire of itself,\n  Every man is a piece of the continent,\n  A part of the main.\n  ```\n  With max_char_buffer=40, the chunks are:\n  * \"No man is an island,\\nEntire of itself,\" len=38\n  * \"Every man is a piece of the continent,\" len=38\n  * \"A part of the main.\" len=19\n\n  B)\n  If a single token exceeds the max char buffer, it forms a chunk on its own.\n  Consider the sentence:\n  \"This is antidisestablishmentarianism.\"\n  With max_char_buffer=20, the chunks are:\n  * \"This is\" len=7\n  * \"antidisestablishmentarianism\" len=28\n  * \".\" len=1\n\n  C)\n  If multiple *whole* sentences can fit within the max char buffer, then they\n  are used to form the chunk.\n  Consider the sentences:\n  \"Roses are red. Violets are blue. Flowers are nice. And so are you.\"\n  With max_char_buffer=60, the chunks are:\n  * \"Roses are red. Violets are blue. 
Flowers are nice.\" len=50\n  * \"And so are you.\" len=15\n  \"\"\"\n\n  def __init__(\n      self,\n      text: str | tokenizer_lib.TokenizedText | None,\n      max_char_buffer: int,\n      tokenizer_impl: tokenizer_lib.Tokenizer,\n      document: data.Document | None = None,\n  ):\n    \"\"\"Constructor.\n\n    Args:\n      text: Document to chunk, as a string or a tokenized text. May be None\n        if `document` is provided, in which case `document.text` is used.\n      max_char_buffer: Maximum number of characters per chunk.\n      tokenizer_impl: Tokenizer used when `text` is a string or a tokenized\n        text without tokens.\n      document: Optional source document.\n\n    Raises:\n      ValueError: If both `text` and `document` are None.\n    \"\"\"\n    if text is None:\n      if document is None:\n        raise ValueError(\"Either text or document must be provided.\")\n      text = document.text or \"\"\n\n    if isinstance(text, str):\n      text = tokenizer_impl.tokenize(text)\n    elif isinstance(text, tokenizer_lib.TokenizedText) and not text.tokens:\n      text_to_tokenize = text.text or (document.text if document else \"\")\n      text = tokenizer_impl.tokenize(text_to_tokenize)\n    self.tokenized_text = text\n    self.max_char_buffer = max_char_buffer\n    self.sentence_iter = SentenceIterator(self.tokenized_text)\n    self.broken_sentence = False\n\n    # TODO: Refactor redundancy between document and text.\n    if document is None:\n      self.document = data.Document(text=text.text)\n    else:\n      self.document = document\n    self.document.tokenized_text = self.tokenized_text\n\n  def __iter__(self) -> Iterator[TextChunk]:\n    return self\n\n  def _tokens_exceed_buffer(\n      self, token_interval: tokenizer_lib.TokenInterval\n  ) -> bool:\n    \"\"\"Check if the token interval exceeds the maximum buffer size.\n\n    Args:\n      token_interval: Token interval to check.\n\n    Returns:\n      True if the token interval exceeds the maximum buffer size.\n    \"\"\"\n    char_interval = get_char_interval(self.tokenized_text, token_interval)\n    return (\n        char_interval.end_pos - char_interval.start_pos\n    ) > 
self.max_char_buffer\n\n  def __next__(self) -> TextChunk:\n    sentence = next(self.sentence_iter)\n    # If the sentence's first token alone exceeds max_char_buffer, it forms\n    # the entire chunk.\n    curr_chunk = create_token_interval(\n        sentence.start_index, sentence.start_index + 1\n    )\n    if self._tokens_exceed_buffer(curr_chunk):\n      self.sentence_iter = SentenceIterator(\n          self.tokenized_text, curr_token_pos=sentence.start_index + 1\n      )\n      self.broken_sentence = curr_chunk.end_index < sentence.end_index\n      return TextChunk(\n          token_interval=curr_chunk,\n          document=self.document,\n      )\n\n    # Append tokens to the chunk up to the max_char_buffer.\n    start_of_new_line = -1\n    for token_index in range(curr_chunk.start_index, sentence.end_index):\n      if self.tokenized_text.tokens[token_index].first_token_after_newline:\n        start_of_new_line = token_index\n      test_chunk = create_token_interval(\n          curr_chunk.start_index, token_index + 1\n      )\n      if self._tokens_exceed_buffer(test_chunk):\n        # Only break at newline if: 1) newline exists (> 0) and\n        # 2) it's after chunk start (prevents empty intervals)\n        if start_of_new_line > 0 and start_of_new_line > curr_chunk.start_index:\n          # Terminate the curr_chunk at the start of the most recent newline.\n          curr_chunk = create_token_interval(\n              curr_chunk.start_index, start_of_new_line\n          )\n        self.sentence_iter = SentenceIterator(\n            self.tokenized_text, curr_token_pos=curr_chunk.end_index\n        )\n        self.broken_sentence = True\n        return TextChunk(\n            token_interval=curr_chunk,\n            document=self.document,\n        )\n      else:\n        curr_chunk = test_chunk\n\n    # The remainder of a broken sentence is emitted on its own; otherwise,\n    # greedily append whole sentences while they fit within max_char_buffer.\n    if self.broken_sentence:\n      self.broken_sentence = False\n    else:\n      for sentence in self.sentence_iter:\n        test_chunk = create_token_interval(\n      
      curr_chunk.start_index, sentence.end_index\n        )\n        if self._tokens_exceed_buffer(test_chunk):\n          self.sentence_iter = SentenceIterator(\n              self.tokenized_text, curr_token_pos=curr_chunk.end_index\n          )\n          return TextChunk(\n              token_interval=curr_chunk,\n              document=self.document,\n          )\n        else:\n          curr_chunk = test_chunk\n\n    return TextChunk(\n        token_interval=curr_chunk,\n        document=self.document,\n    )\n"
  },
  {
    "path": "langextract/core/__init__.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Core abstractions for LangExtract.\n\nThis package contains the foundational base models and types used throughout\nLangExtract. Each module can be imported independently for fine-grained\ndependency management in build systems.\n\"\"\"\n\nfrom __future__ import annotations\n\n__all__ = [\n    \"base_model\",\n    \"types\",\n    \"exceptions\",\n    \"schema\",\n    \"data\",\n    \"tokenizer\",\n]\n"
  },
  {
    "path": "langextract/core/base_model.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Base interfaces for language models.\"\"\"\nfrom __future__ import annotations\n\nimport abc\nfrom collections.abc import Iterator, Sequence\nimport json\nfrom typing import Any, Mapping\n\nimport yaml\n\nfrom langextract.core import schema\nfrom langextract.core import types\n\n__all__ = ['BaseLanguageModel']\n\n\nclass BaseLanguageModel(abc.ABC):\n  \"\"\"An abstract inference class for managing LLM inference.\n\n  Attributes:\n    _constraint: A `Constraint` object specifying constraints for model output.\n  \"\"\"\n\n  def __init__(self, constraint: types.Constraint | None = None, **kwargs: Any):\n    \"\"\"Initializes the BaseLanguageModel with an optional constraint.\n\n    Args:\n      constraint: Applies constraints when decoding the output. 
Defaults to no\n        constraint.\n      **kwargs: Additional keyword arguments passed to the model.\n    \"\"\"\n    self._constraint = constraint or types.Constraint()\n    self._schema: schema.BaseSchema | None = None\n    self._fence_output_override: bool | None = None\n    self._extra_kwargs: dict[str, Any] = kwargs.copy()\n\n  @classmethod\n  def get_schema_class(cls) -> type[Any] | None:\n    \"\"\"Return the schema class this provider supports.\"\"\"\n    return None\n\n  def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None:\n    \"\"\"Apply a schema instance to this provider.\n\n    Optional method that providers can override to store the schema instance\n    for runtime use. The default implementation stores it as _schema.\n\n    Args:\n      schema_instance: The schema instance to apply, or None to clear.\n    \"\"\"\n    self._schema = schema_instance\n\n  @property\n  def schema(self) -> schema.BaseSchema | None:\n    \"\"\"The current schema instance if one is configured.\n\n    Returns:\n      The schema instance or None if no schema is applied.\n    \"\"\"\n    return self._schema\n\n  def set_fence_output(self, fence_output: bool | None) -> None:\n    \"\"\"Set explicit fence output preference.\n\n    Args:\n      fence_output: True to force fences, False to disable, None for auto.\n    \"\"\"\n    if not hasattr(self, '_fence_output_override'):\n      self._fence_output_override = None\n    self._fence_output_override = fence_output\n\n  @property\n  def requires_fence_output(self) -> bool:\n    \"\"\"Whether this model requires fence output for parsing.\n\n    Uses explicit override if set, otherwise computes from schema.\n    Returns True if no schema or schema doesn't require raw output.\n    \"\"\"\n    if (\n        hasattr(self, '_fence_output_override')\n        and self._fence_output_override is not None\n    ):\n      return self._fence_output_override\n\n    schema_obj = self.schema\n    if schema_obj is None:\n    
  return True\n    return not schema_obj.requires_raw_output\n\n  def merge_kwargs(\n      self, runtime_kwargs: Mapping[str, Any] | None = None\n  ) -> dict[str, Any]:\n    \"\"\"Merge stored extra kwargs with runtime kwargs.\n\n    Runtime kwargs take precedence over stored kwargs.\n\n    Args:\n      runtime_kwargs: Kwargs provided at inference time, or None.\n\n    Returns:\n      Merged kwargs dictionary.\n    \"\"\"\n    base = getattr(self, '_extra_kwargs', {}) or {}\n    incoming = dict(runtime_kwargs or {})\n    return {**base, **incoming}\n\n  @abc.abstractmethod\n  def infer(\n      self, batch_prompts: Sequence[str], **kwargs\n  ) -> Iterator[Sequence[types.ScoredOutput]]:\n    \"\"\"Implements language model inference.\n\n    Args:\n      batch_prompts: Batch of inputs for inference. A single-element list can\n        be used for a single input.\n      **kwargs: Additional arguments for inference, like temperature and\n        max_decode_steps.\n\n    Returns:\n      An iterator yielding, for each prompt, a sequence of scored outputs\n      sorted by descending score.\n    \"\"\"\n\n  def infer_batch(\n      self, prompts: Sequence[str], batch_size: int = 32  # pylint: disable=unused-argument\n  ) -> list[list[types.ScoredOutput]]:\n    \"\"\"Batch inference with configurable batch size.\n\n    This is a convenience method that collects all results from infer().\n\n    Args:\n      prompts: List of prompts to process.\n      batch_size: Batch size (currently unused, for future optimization).\n\n    Returns:\n      List of lists of ScoredOutput objects.\n    \"\"\"\n    results = []\n    for output in self.infer(prompts):\n      results.append(list(output))\n    return results\n\n  def parse_output(self, output: str) -> Any:\n    \"\"\"Parses model output as JSON or YAML.\n\n    Note: This expects raw JSON/YAML without code fences.\n    Code fence extraction is handled by resolver.py.\n\n    Args:\n      output: Raw output string from the model.\n\n    Returns:\n      Parsed 
Python object (dict or list).\n\n    Raises:\n      ValueError: If output cannot be parsed as JSON or YAML.\n    \"\"\"\n    # Check if we have a format_type attribute (providers should set this)\n    format_type = getattr(self, 'format_type', types.FormatType.JSON)\n\n    try:\n      if format_type == types.FormatType.JSON:\n        return json.loads(output)\n      else:\n        return yaml.safe_load(output)\n    except Exception as e:\n      raise ValueError(\n          f'Failed to parse output as {format_type.name}: {str(e)}'\n      ) from e\n"
  },
  {
    "path": "langextract/core/data.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Classes used to represent core data types of the annotation pipeline.\"\"\"\nfrom __future__ import annotations\n\nimport dataclasses\nimport enum\nimport uuid\n\nfrom langextract.core import tokenizer\nfrom langextract.core import types\n\nFormatType = types.FormatType  # Backward compat\n\nEXTRACTIONS_KEY = \"extractions\"\nATTRIBUTE_SUFFIX = \"_attributes\"\n\n__all__ = [\n    \"AlignmentStatus\",\n    \"CharInterval\",\n    \"Extraction\",\n    \"Document\",\n    \"AnnotatedDocument\",\n    \"ExampleData\",\n    \"FormatType\",\n    \"EXTRACTIONS_KEY\",\n    \"ATTRIBUTE_SUFFIX\",\n]\n\n\nclass AlignmentStatus(enum.Enum):\n  MATCH_EXACT = \"match_exact\"\n  MATCH_GREATER = \"match_greater\"\n  MATCH_LESSER = \"match_lesser\"\n  MATCH_FUZZY = \"match_fuzzy\"\n\n\n@dataclasses.dataclass\nclass CharInterval:\n  \"\"\"Class for representing a character interval.\n\n  Attributes:\n    start_pos: The starting position of the interval (inclusive).\n    end_pos: The ending position of the interval (exclusive).\n  \"\"\"\n\n  start_pos: int | None = None\n  end_pos: int | None = None\n\n\n@dataclasses.dataclass(init=False)\nclass Extraction:\n  \"\"\"Represents a single extraction from text.\n\n  This class encapsulates an extraction's characteristics and its position\n  within the source text. 
It can represent a diverse range of information for\n  NLP information extraction tasks.\n\n  Attributes:\n    extraction_class: The class of the extraction.\n    extraction_text: The text of the extraction.\n    char_interval: The character interval of the extraction in the original\n      text.\n    alignment_status: The alignment status of the extraction.\n    extraction_index: The index of the extraction in the list of extractions.\n    group_index: The index of the group the extraction belongs to.\n    description: A description of the extraction.\n    attributes: A mapping from attribute names to values (a string or a list\n      of strings).\n    token_interval: The token interval of the extraction.\n  \"\"\"\n\n  extraction_class: str\n  extraction_text: str\n  char_interval: CharInterval | None = None\n  alignment_status: AlignmentStatus | None = None\n  extraction_index: int | None = None\n  group_index: int | None = None\n  description: str | None = None\n  attributes: dict[str, str | list[str]] | None = None\n  _token_interval: tokenizer.TokenInterval | None = dataclasses.field(\n      default=None, repr=False, compare=False\n  )\n\n  def __init__(\n      self,\n      extraction_class: str,\n      extraction_text: str,\n      *,\n      token_interval: tokenizer.TokenInterval | None = None,\n      char_interval: CharInterval | None = None,\n      alignment_status: AlignmentStatus | None = None,\n      extraction_index: int | None = None,\n      group_index: int | None = None,\n      description: str | None = None,\n      attributes: dict[str, str | list[str]] | None = None,\n  ):\n    self.extraction_class = extraction_class\n    self.extraction_text = extraction_text\n    self.char_interval = char_interval\n    self._token_interval = token_interval\n    self.alignment_status = alignment_status\n    self.extraction_index = extraction_index\n    self.group_index = group_index\n    self.description = description\n    self.attributes = attributes\n\n  @property\n  def token_interval(self) -> 
tokenizer.TokenInterval | None:\n    return self._token_interval\n\n  @token_interval.setter\n  def token_interval(self, value: tokenizer.TokenInterval | None) -> None:\n    self._token_interval = value\n\n\n@dataclasses.dataclass\nclass Document:\n  \"\"\"A document to be annotated.\n\n  Attributes:\n    text: Raw text representation for the document.\n    document_id: Unique identifier for the document; auto-generated if not\n      set.\n    additional_context: Additional context to supplement prompt instructions.\n    tokenized_text: Tokenized text for the document, computed lazily from\n      `text`.\n  \"\"\"\n\n  text: str\n  additional_context: str | None = None\n  _document_id: str | None = dataclasses.field(\n      default=None, init=False, repr=False, compare=False\n  )\n  _tokenized_text: tokenizer.TokenizedText | None = dataclasses.field(\n      init=False, default=None, repr=False, compare=False\n  )\n\n  def __init__(\n      self,\n      text: str,\n      *,\n      document_id: str | None = None,\n      additional_context: str | None = None,\n  ):\n    self.text = text\n    self.additional_context = additional_context\n    self._document_id = document_id\n\n  @property\n  def document_id(self) -> str:\n    \"\"\"Returns the document ID, generating a unique one if not set.\"\"\"\n    if self._document_id is None:\n      self._document_id = f\"doc_{uuid.uuid4().hex[:8]}\"\n    return self._document_id\n\n  @document_id.setter\n  def document_id(self, value: str | None) -> None:\n    \"\"\"Sets the document ID.\"\"\"\n    self._document_id = value\n\n  @property\n  def tokenized_text(self) -> tokenizer.TokenizedText:\n    if self._tokenized_text is None:\n      self._tokenized_text = tokenizer.tokenize(self.text)\n    return self._tokenized_text\n\n  @tokenized_text.setter\n  def tokenized_text(self, value: tokenizer.TokenizedText) -> None:\n    self._tokenized_text = value\n\n\n@dataclasses.dataclass\nclass AnnotatedDocument:\n  \"\"\"Class for 
representing annotated documents.\n\n  Attributes:\n    document_id: Unique identifier for the document; auto-generated if not\n      set.\n    extractions: List of extractions in the document.\n    text: Raw text representation of the document.\n    tokenized_text: Tokenized text of the document, computed lazily from\n      `text` (None if `text` is None).\n  \"\"\"\n\n  extractions: list[Extraction] | None = None\n  text: str | None = None\n  _document_id: str | None = dataclasses.field(\n      default=None, init=False, repr=False, compare=False\n  )\n  _tokenized_text: tokenizer.TokenizedText | None = dataclasses.field(\n      init=False, default=None, repr=False, compare=False\n  )\n\n  def __init__(\n      self,\n      *,\n      document_id: str | None = None,\n      extractions: list[Extraction] | None = None,\n      text: str | None = None,\n  ):\n    self.extractions = extractions\n    self.text = text\n    self._document_id = document_id\n\n  @property\n  def document_id(self) -> str:\n    \"\"\"Returns the document ID, generating a unique one if not set.\"\"\"\n    if self._document_id is None:\n      self._document_id = f\"doc_{uuid.uuid4().hex[:8]}\"\n    return self._document_id\n\n  @document_id.setter\n  def document_id(self, value: str | None) -> None:\n    \"\"\"Sets the document ID.\"\"\"\n    self._document_id = value\n\n  @property\n  def tokenized_text(self) -> tokenizer.TokenizedText | None:\n    if self._tokenized_text is None and self.text is not None:\n      self._tokenized_text = tokenizer.tokenize(self.text)\n    return self._tokenized_text\n\n  @tokenized_text.setter\n  def tokenized_text(self, value: tokenizer.TokenizedText) -> None:\n    self._tokenized_text = value\n\n\n@dataclasses.dataclass\nclass ExampleData:\n  \"\"\"A single training/example data instance for structured prompting.\n\n  Attributes:\n    text: The raw input text (sentence, paragraph, etc.).\n    extractions: A list of Extraction objects extracted from the text.\n  \"\"\"\n\n  text: str\n  
extractions: list[Extraction] = dataclasses.field(default_factory=list)\n"
  },
  {
    "path": "langextract/core/debug_utils.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Debug utilities for LangExtract.\"\"\"\nfrom __future__ import annotations\n\nimport functools\nimport inspect\nimport logging\nimport reprlib\nimport time\nfrom typing import Any, Callable, Mapping\n\nfrom absl import logging as absl_logging\n\n_LOG = logging.getLogger(\"langextract.debug\")\n\n# Add NullHandler to prevent \"No handler found\" warnings\n_langextract_logger = logging.getLogger(\"langextract\")\nif not _langextract_logger.handlers:\n  _langextract_logger.addHandler(logging.NullHandler())\n\n# Sensitive keys to redact\n_REDACT_KEYS = {\n    \"api_key\",\n    \"apikey\",\n    \"token\",\n    \"secret\",\n    \"password\",\n    \"authorization\",\n    \"bearer\",\n    \"jwt\",\n}\n_MAX_STR = 500\n_MAX_SEQ = 20\n\n\ndef _safe_repr(obj: Any) -> str:\n  \"\"\"Truncate object repr for safe logging.\"\"\"\n  r = reprlib.Repr()\n  r.maxstring = _MAX_STR\n  r.maxlist = r.maxtuple = r.maxset = r.maxdict = _MAX_SEQ\n  return r.repr(obj)\n\n\ndef _redact_value(name: str, value: Any) -> str:\n  \"\"\"Redact sensitive values based on parameter name.\"\"\"\n  if isinstance(name, str) and name.lower() in _REDACT_KEYS:\n    return \"<REDACTED>\"\n  # If a nested mapping, redact its sensitive keys too\n  if isinstance(value, Mapping):\n    redacted = {}\n    for k, v in value.items():\n      if isinstance(k, str) and k.lower() in _REDACT_KEYS:\n        redacted[k] 
= \"<REDACTED>\"\n      else:\n        redacted[k] = _safe_repr(v)\n    return _safe_repr(redacted)\n  return _safe_repr(value)\n\n\ndef _redact_mapping(mapping: Mapping[str, Any]) -> dict[str, str]:\n  \"\"\"Replace sensitive values with <REDACTED>.\"\"\"\n  out = {}\n  for k, v in mapping.items():\n    out[k] = _redact_value(k, v)\n  return out\n\n\ndef _format_bound_args(\n    fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]\n) -> str:\n  \"\"\"Format function arguments using signature inspection.\"\"\"\n  try:\n    sig = inspect.signature(fn)\n    bound = sig.bind_partial(*args, **kwargs)\n    bound.apply_defaults()\n  except Exception:\n    # Fallback (no names) if binding fails\n    parts = [_safe_repr(a) for a in args]\n    if kwargs:\n      red = _redact_mapping(kwargs)\n      parts += [f\"{k}={v}\" for k, v in sorted(red.items())]\n    return \", \".join(parts)\n\n  parts: list[str] = []\n  for name, value in bound.arguments.items():\n    if name in (\"self\", \"cls\"):\n      parts.append(f\"{name}=<{type(value).__name__}>\")\n    else:\n      parts.append(f\"{name}={_redact_value(name, value)}\")\n  return \", \".join(parts)\n\n\ndef debug_log_calls(fn: Callable) -> Callable:\n  \"\"\"Log function calls with redacted sensitive data and timing.\n\n  Automatically redacts api_key, token, etc. 
and truncates large outputs.\n  \"\"\"\n\n  @functools.wraps(fn)\n  def wrapper(*args, **kwargs):\n    logger = _LOG\n    if not logger.isEnabledFor(logging.DEBUG):\n      return fn(*args, **kwargs)\n\n    fn_qual = getattr(fn, \"__qualname__\", fn.__name__)\n    mod = getattr(fn, \"__module__\", \"\")\n\n    # Format arguments using signature inspection\n    arg_str = _format_bound_args(fn, args, kwargs)\n\n    logger.debug(\"[%s] CALL: %s(%s)\", mod, fn_qual, arg_str, stacklevel=2)\n\n    start = time.perf_counter()\n    try:\n      result = fn(*args, **kwargs)\n    except Exception:\n      dur_ms = (time.perf_counter() - start) * 1000\n      logger.exception(\n          \"[%s] EXCEPTION: %s (%.1f ms)\", mod, fn_qual, dur_ms, stacklevel=2\n      )\n      raise\n\n    dur_ms = (time.perf_counter() - start) * 1000\n    result_repr = _safe_repr(result)\n    logger.debug(\n        \"[%s] RETURN: %s -> %s (%.1f ms)\",\n        mod,\n        fn_qual,\n        result_repr,\n        dur_ms,\n        stacklevel=2,\n    )\n    return result\n\n  return wrapper\n\n\ndef configure_debug_logging() -> None:\n  \"\"\"Enable debug logging for the 'langextract' namespace only.\"\"\"\n  logger = logging.getLogger(\"langextract\")\n\n  # Skip if we already added our handler\n  our_handler_exists = any(\n      isinstance(h, logging.StreamHandler)\n      and getattr(h, \"langextract_debug\", False)\n      for h in logger.handlers\n  )\n  if our_handler_exists:\n    return\n\n  # Respect host handlers - only set level if they exist\n  non_null_handlers = [\n      h for h in logger.handlers if not isinstance(h, logging.NullHandler)\n  ]\n\n  if non_null_handlers:\n    logger.setLevel(logging.DEBUG)\n  else:\n    logger.setLevel(logging.DEBUG)\n    handler = logging.StreamHandler()\n    handler.setLevel(logging.DEBUG)\n    fmt = \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n    handler.setFormatter(logging.Formatter(fmt))\n    handler.langextract_debug = True\n    
logger.addHandler(handler)\n    logger.propagate = False\n\n  # Best-effort absl configuration\n  try:\n    absl_logging.set_verbosity(absl_logging.DEBUG)\n  except Exception:\n    pass\n"
  },
  {
    "path": "langextract/core/exceptions.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Core error types for LangExtract.\n\nThis module defines all base exceptions for LangExtract. These are the\nfoundational error types that are used throughout the codebase.\n\"\"\"\n\nfrom __future__ import annotations\n\n__all__ = [\n    \"LangExtractError\",\n    \"InferenceError\",\n    \"InferenceConfigError\",\n    \"InferenceRuntimeError\",\n    \"InferenceOutputError\",\n    \"InternalError\",\n    \"InvalidDocumentError\",\n    \"ProviderError\",\n    \"SchemaError\",\n    \"FormatError\",\n    \"FormatParseError\",\n]\n\n\nclass LangExtractError(Exception):\n  \"\"\"Base exception for all LangExtract errors.\n\n  All exceptions raised by LangExtract should inherit from this class.\n  This allows users to catch all LangExtract-specific errors with a single\n  except clause.\n  \"\"\"\n\n\nclass InferenceError(LangExtractError):\n  \"\"\"Base exception for inference-related errors.\"\"\"\n\n\nclass InferenceConfigError(InferenceError):\n  \"\"\"Exception raised for configuration errors.\n\n  This includes missing API keys, invalid model IDs, or other\n  configuration-related issues that prevent model instantiation.\n  \"\"\"\n\n\nclass InferenceRuntimeError(InferenceError):\n  \"\"\"Exception raised for runtime inference errors.\n\n  This includes API call failures, network errors, or other issues\n  that occur during inference execution.\n  \"\"\"\n\n  
def __init__(\n      self,\n      message: str,\n      *,\n      original: BaseException | None = None,\n      provider: str | None = None,\n  ) -> None:\n    \"\"\"Initialize the runtime error.\n\n    Args:\n      message: Error message.\n      original: Original exception from the provider SDK.\n      provider: Name of the provider that raised the error.\n    \"\"\"\n    super().__init__(message)\n    self.original = original\n    self.provider = provider\n\n\nclass InferenceOutputError(LangExtractError):\n  \"\"\"Exception raised when no scored outputs are available from the language model.\"\"\"\n\n  def __init__(self, message: str):\n    self.message = message\n    super().__init__(self.message)\n\n\nclass InvalidDocumentError(LangExtractError):\n  \"\"\"Exception raised when document input is invalid.\n\n  This includes cases like duplicate document IDs or malformed documents.\n  \"\"\"\n\n\nclass InternalError(LangExtractError):\n  \"\"\"Exception raised for internal invariant violations.\n\n  This indicates a bug in LangExtract itself rather than user error.\n  \"\"\"\n\n\nclass ProviderError(LangExtractError):\n  \"\"\"Provider/backend specific error.\"\"\"\n\n\nclass SchemaError(LangExtractError):\n  \"\"\"Schema validation/serialization error.\"\"\"\n\n\nclass FormatError(LangExtractError):\n  \"\"\"Base exception for format handling errors.\"\"\"\n\n\nclass FormatParseError(FormatError):\n  \"\"\"Raised when format parsing fails.\n\n  This consolidates all parsing errors including:\n  - Missing fence markers when required\n  - Multiple fenced blocks\n  - JSON/YAML decode errors\n  - Missing wrapper keys\n  - Invalid structure\n  \"\"\"\n"
  },
  {
    "path": "langextract/core/format_handler.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Centralized format handler for prompts and parsing.\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport re\nfrom typing import Mapping, Sequence\nimport warnings\n\nimport yaml\n\nfrom langextract.core import data\nfrom langextract.core import exceptions\n\nExtractionValueType = str | int | float | dict | list | None\n\n_JSON_FORMAT = \"json\"\n_YAML_FORMAT = \"yaml\"\n_YML_FORMAT = \"yml\"\n\n_FENCE_START = r\"```\"\n_LANGUAGE_TAG = r\"(?P<lang>[A-Za-z0-9_+-]+)?\"\n_FENCE_NEWLINE = r\"(?:\\s*\\n)?\"\n_FENCE_BODY = r\"(?P<body>[\\s\\S]*?)\"\n_FENCE_END = r\"```\"\n\n_FENCE_RE = re.compile(\n    _FENCE_START + _LANGUAGE_TAG + _FENCE_NEWLINE + _FENCE_BODY + _FENCE_END,\n    re.MULTILINE,\n)\n\n_THINK_TAG_RE = re.compile(r\"<think>[\\s\\S]*?</think>\\s*\", re.IGNORECASE)\n\n\nclass FormatHandler:\n  \"\"\"Handles all format-specific logic for prompts and parsing.\n\n  This class centralizes format handling for JSON and YAML outputs,\n  including fence detection, wrapper management, and parsing.\n\n  Attributes:\n    format_type: The output format ('json' or 'yaml').\n    use_wrapper: Whether to wrap extractions in a container dictionary.\n    wrapper_key: The key name for the container dictionary (e.g., creates\n      {\"extractions\": [...]} instead of just [...]).\n    use_fences: Whether to use code fences in formatted output.\n    
attribute_suffix: Suffix for attribute fields in extractions.\n    strict_fences: Whether to enforce strict fence validation.\n    allow_top_level_list: Whether to allow top-level lists in parsing.\n  \"\"\"\n\n  def __init__(\n      self,\n      format_type: data.FormatType = data.FormatType.JSON,\n      use_wrapper: bool = True,\n      wrapper_key: str | None = None,\n      use_fences: bool = True,\n      attribute_suffix: str = data.ATTRIBUTE_SUFFIX,\n      strict_fences: bool = False,\n      allow_top_level_list: bool = True,\n  ) -> None:\n    \"\"\"Initialize format handler.\n\n    Args:\n      format_type: Output format type enum.\n      use_wrapper: Whether to wrap extractions in a container dictionary.\n        True: {\"extractions\": [...]}, False: [...]\n      wrapper_key: Key name for the container dictionary. When use_wrapper=True:\n        - If None: defaults to EXTRACTIONS_KEY (\"extractions\")\n        - If provided: uses the specified key as container\n        When use_wrapper=False, this parameter is ignored.\n      use_fences: Whether to use ```json or ```yaml fences.\n      attribute_suffix: Suffix for attribute fields.\n      strict_fences: If True, require exact fence format. 
If False, be lenient\n        with model output variations.\n      allow_top_level_list: Allow top-level list when not strict and\n        wrapper not required.\n    \"\"\"\n    self.format_type = format_type\n    self.use_wrapper = use_wrapper\n    if use_wrapper:\n      self.wrapper_key = (\n          wrapper_key if wrapper_key is not None else data.EXTRACTIONS_KEY\n      )\n    else:\n      self.wrapper_key = None\n    self.use_fences = use_fences\n    self.attribute_suffix = attribute_suffix\n    self.strict_fences = strict_fences\n    self.allow_top_level_list = allow_top_level_list\n\n  def __repr__(self) -> str:\n    return (\n        \"FormatHandler(\"\n        f\"format_type={self.format_type!r}, use_wrapper={self.use_wrapper}, \"\n        f\"wrapper_key={self.wrapper_key!r}, use_fences={self.use_fences}, \"\n        f\"attribute_suffix={self.attribute_suffix!r}, \"\n        f\"strict_fences={self.strict_fences}, \"\n        f\"allow_top_level_list={self.allow_top_level_list})\"\n    )\n\n  def format_extraction_example(\n      self, extractions: list[data.Extraction]\n  ) -> str:\n    \"\"\"Format extractions for a prompt example.\n\n    Args:\n      extractions: List of extractions to format\n\n    Returns:\n      Formatted string for the prompt\n    \"\"\"\n    items = [\n        {\n            ext.extraction_class: ext.extraction_text,\n            f\"{ext.extraction_class}{self.attribute_suffix}\": (\n                ext.attributes or {}\n            ),\n        }\n        for ext in extractions\n    ]\n\n    if self.use_wrapper and self.wrapper_key:\n      payload = {self.wrapper_key: items}\n    else:\n      payload = items\n\n    if self.format_type == data.FormatType.YAML:\n      formatted = yaml.safe_dump(\n          payload, default_flow_style=False, sort_keys=False\n      )\n    else:\n      formatted = json.dumps(payload, indent=2, ensure_ascii=False)\n\n    return self._add_fences(formatted) if self.use_fences else formatted\n\n  def 
parse_output(\n      self, text: str, *, strict: bool | None = None\n  ) -> Sequence[Mapping[str, ExtractionValueType]]:\n    \"\"\"Parse model output to extract data.\n\n    Args:\n      text: Raw model output.\n      strict: If True, parse strictly: a wrapper object is required whenever\n        wrapper_key is configured (so top-level lists are rejected even if\n        allow_top_level_list is True), and the fallback that strips <think>\n        tags before re-parsing is disabled.\n\n    Returns:\n      List of extraction dictionaries.\n\n    Raises:\n      FormatParseError: If the input is empty, cannot be decoded, or does not\n        have the expected structure.\n    \"\"\"\n    if not text:\n      raise exceptions.FormatParseError(\"Empty or invalid input string.\")\n\n    content = self._extract_content(text)\n\n    try:\n      parsed = self._parse_with_fallback(content, strict)\n    except (yaml.YAMLError, json.JSONDecodeError) as e:\n      msg = (\n          f\"Failed to parse {self.format_type.value.upper()} content:\"\n          f\" {str(e)[:200]}\"\n      )\n      raise exceptions.FormatParseError(msg) from e\n\n    if parsed is None:\n      if self.use_wrapper:\n        raise exceptions.FormatParseError(\n            f\"Content must be a mapping with an '{self.wrapper_key}' key.\"\n        )\n      else:\n        raise exceptions.FormatParseError(\n            \"Content must be a list of extractions or a dict.\"\n        )\n\n    require_wrapper = self.wrapper_key is not None and (\n        self.use_wrapper or bool(strict)\n    )\n\n    if isinstance(parsed, dict):\n      if require_wrapper:\n        if self.wrapper_key not in parsed:\n          raise exceptions.FormatParseError(\n              f\"Content must contain an '{self.wrapper_key}' key.\"\n          )\n        items = parsed[self.wrapper_key]\n      else:\n        if data.EXTRACTIONS_KEY in parsed:\n          items = parsed[data.EXTRACTIONS_KEY]\n        elif self.wrapper_key and self.wrapper_key in parsed:\n          items = 
parsed[self.wrapper_key]\n        else:\n          items = [parsed]\n    elif isinstance(parsed, list):\n      if require_wrapper and (strict or not self.allow_top_level_list):\n        raise exceptions.FormatParseError(\n            f\"Content must be a mapping with an '{self.wrapper_key}' key.\"\n        )\n      if strict and self.use_wrapper:\n        raise exceptions.FormatParseError(\n            \"Strict mode requires a wrapper object.\"\n        )\n      if not self.allow_top_level_list:\n        raise exceptions.FormatParseError(\"Top-level list is not allowed.\")\n      # Some models return [...] instead of {\"extractions\": [...]}.\n      items = parsed\n    else:\n      raise exceptions.FormatParseError(\n          f\"Expected list or dict, got {type(parsed)}\"\n      )\n\n    if not isinstance(items, list):\n      raise exceptions.FormatParseError(\n          \"The extractions must be a sequence (list) of mappings.\"\n      )\n\n    for item in items:\n      if not isinstance(item, dict):\n        raise exceptions.FormatParseError(\n            \"Each item in the sequence must be a mapping.\"\n        )\n      for k in item.keys():\n        if not isinstance(k, str):\n          raise exceptions.FormatParseError(\n              \"All extraction keys must be strings (got a non-string key).\"\n          )\n\n    return items\n\n  def _add_fences(self, content: str) -> str:\n    \"\"\"Add code fences around content.\"\"\"\n    fence_type = self.format_type.value\n    return f\"```{fence_type}\\n{content.strip()}\\n```\"\n\n  def _is_valid_language_tag(\n      self, lang: str | None, valid_tags: dict[data.FormatType, set[str]]\n  ) -> bool:\n    \"\"\"Check if language tag is valid for the format type.\"\"\"\n    if lang is None:\n      return True\n    tag = lang.strip().lower()\n    return tag in valid_tags.get(self.format_type, set())\n\n  def _parse_with_fallback(self, content: str, strict: bool | None):\n    \"\"\"Parse content, retrying without <think> tags 
on failure.\"\"\"\n    try:\n      if self.format_type == data.FormatType.YAML:\n        return yaml.safe_load(content)\n      return json.loads(content)\n    except (yaml.YAMLError, json.JSONDecodeError):\n      if strict:\n        raise\n      # Reasoning models (DeepSeek-R1, QwQ) emit <think> tags before JSON.\n      if _THINK_TAG_RE.search(content):\n        stripped = _THINK_TAG_RE.sub(\"\", content).strip()\n        if self.format_type == data.FormatType.YAML:\n          return yaml.safe_load(stripped)\n        return json.loads(stripped)\n      raise\n\n  def _extract_content(self, text: str) -> str:\n    \"\"\"Extract content from text, handling fences if configured.\n\n    Args:\n      text: Input text that may contain fenced blocks\n\n    Returns:\n      Extracted content\n\n    Raises:\n      FormatParseError: When fences required but not found or multiple\n        blocks found.\n    \"\"\"\n    if not self.use_fences:\n      return text.strip()\n\n    matches = list(_FENCE_RE.finditer(text))\n\n    valid_tags = {\n        data.FormatType.YAML: {_YAML_FORMAT, _YML_FORMAT},\n        data.FormatType.JSON: {_JSON_FORMAT},\n    }\n\n    candidates = [\n        m\n        for m in matches\n        if self._is_valid_language_tag(m.group(\"lang\"), valid_tags)\n    ]\n\n    if self.strict_fences:\n      if len(candidates) != 1:\n        if len(candidates) == 0:\n          raise exceptions.FormatParseError(\n              \"Input string does not contain valid fence markers.\"\n          )\n        else:\n          raise exceptions.FormatParseError(\n              \"Multiple fenced blocks found. Expected exactly one.\"\n          )\n      return candidates[0].group(\"body\").strip()\n\n    if len(candidates) == 1:\n      return candidates[0].group(\"body\").strip()\n    elif len(candidates) > 1:\n      raise exceptions.FormatParseError(\n          \"Multiple fenced blocks found. 
Expected exactly one.\"\n      )\n\n    if matches:\n      if not self.strict_fences and len(matches) == 1:\n        return matches[0].group(\"body\").strip()\n      raise exceptions.FormatParseError(\n          f\"No {self.format_type.value} code block found.\"\n      )\n\n    return text.strip()\n\n  # ---- Backward compatibility methods (to be removed in v2.0.0) ----\n\n  _LEGACY_FORMAT_KEYS = frozenset({\n      \"fence_output\",\n      \"format_type\",\n      \"strict_fences\",\n      \"require_extractions_key\",\n      \"extraction_attributes_suffix\",\n      \"attribute_suffix\",\n      \"format_handler\",\n  })\n\n  @classmethod\n  def from_resolver_params(\n      cls,\n      *,\n      resolver_params: dict | None,\n      base_format_type: data.FormatType,\n      base_use_fences: bool,\n      base_attribute_suffix: str = data.ATTRIBUTE_SUFFIX,\n      base_use_wrapper: bool = True,\n      base_wrapper_key: str | None = data.EXTRACTIONS_KEY,\n      warn_on_legacy: bool = True,\n  ) -> tuple[FormatHandler, dict]:\n    \"\"\"Create FormatHandler from resolver_params with legacy support.\n\n    This method handles backward compatibility for legacy resolver parameters\n    and will be removed in v2.0.0.\n\n    Args:\n      resolver_params: May contain legacy keys or a 'format_handler'.\n      base_format_type: Default format when not overridden.\n      base_use_fences: Default fence usage from the model.\n      base_attribute_suffix: Default attribute suffix.\n      base_use_wrapper: Default wrapper behavior.\n      base_wrapper_key: Default wrapper key.\n      warn_on_legacy: If True, emit DeprecationWarnings.\n\n    Returns:\n      (format_handler, remaining_resolver_params)\n    \"\"\"\n    rp = dict(resolver_params or {})\n\n    if rp.get(\"format_handler\") is not None:\n      handler = rp.pop(\"format_handler\")\n      for k in list(rp.keys()):\n        if k in cls._LEGACY_FORMAT_KEYS:\n          rp.pop(k, None)\n      return handler, rp\n\n    kwargs = {\n  
      \"format_type\": base_format_type,\n        \"use_fences\": base_use_fences,\n        \"attribute_suffix\": base_attribute_suffix,\n        \"use_wrapper\": base_use_wrapper,\n        \"wrapper_key\": base_wrapper_key if base_use_wrapper else None,\n    }\n\n    mapping = {\n        \"fence_output\": \"use_fences\",\n        \"format_type\": \"format_type\",\n        \"strict_fences\": \"strict_fences\",\n        \"require_extractions_key\": \"use_wrapper\",\n        \"extraction_attributes_suffix\": \"attribute_suffix\",\n        \"attribute_suffix\": \"attribute_suffix\",\n    }\n\n    used_legacy = []\n    for legacy_key, fh_key in mapping.items():\n      if legacy_key in rp and rp[legacy_key] is not None:\n        val = rp.pop(legacy_key)\n        if fh_key == \"format_type\" and hasattr(val, \"value\"):\n          val = val.value\n        kwargs[fh_key] = val\n        used_legacy.append(legacy_key)\n\n    if warn_on_legacy and used_legacy:\n      warnings.warn(\n          \"Resolver legacy params are deprecated and will be removed in\"\n          f\" v2.0.0: {used_legacy}. 
Pass a FormatHandler explicitly via\"\n          \" `resolver_params={'format_handler': FormatHandler(...)}` or rely\"\n          \" on defaults configured by the model.\",\n          DeprecationWarning,\n          stacklevel=3,\n      )\n\n    handler = cls(**kwargs)\n    return handler, rp\n\n  @classmethod\n  def from_kwargs(cls, **kwargs) -> FormatHandler:\n    \"\"\"Create FormatHandler from legacy resolver keyword arguments.\n\n    This method will be removed in v2.0.0.\n\n    Args:\n      **kwargs: Legacy parameters like fence_output, format_type, etc.\n\n    Returns:\n      FormatHandler configured with legacy parameters.\n    \"\"\"\n    legacy_params = {\n        \"fence_output\",\n        \"format_type\",\n        \"strict_fences\",\n        \"require_extractions_key\",\n    }\n    used_legacy = legacy_params.intersection(kwargs.keys())\n\n    if used_legacy:\n      warnings.warn(\n          f\"Using legacy Resolver parameters {used_legacy} is deprecated. \"\n          \"Please use FormatHandler directly. 
\"\n          \"This compatibility layer will be removed in v2.0.0.\",\n          DeprecationWarning,\n          stacklevel=3,\n      )\n\n    fence_output = kwargs.pop(\"fence_output\", True)\n    format_type = kwargs.pop(\"format_type\", None)\n    strict_fences = kwargs.pop(\"strict_fences\", False)\n    require_extractions_key = kwargs.pop(\"require_extractions_key\", True)\n    attribute_suffix = kwargs.pop(\"attribute_suffix\", data.ATTRIBUTE_SUFFIX)\n\n    if format_type is None:\n      format_type = data.FormatType.JSON\n    elif hasattr(format_type, \"value\"):\n      pass\n    else:\n      format_type = (\n          data.FormatType.JSON\n          if str(format_type).lower() == \"json\"\n          else data.FormatType.YAML\n      )\n\n    return cls(\n        format_type=format_type,\n        use_wrapper=require_extractions_key,\n        wrapper_key=data.EXTRACTIONS_KEY if require_extractions_key else None,\n        use_fences=fence_output,\n        strict_fences=strict_fences,\n        attribute_suffix=attribute_suffix,\n    )\n"
  },
  {
    "path": "langextract/core/schema.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Core schema abstractions for LangExtract.\"\"\"\nfrom __future__ import annotations\n\nimport abc\nfrom collections.abc import Sequence\nfrom typing import Any\n\nfrom langextract.core import data\nfrom langextract.core import format_handler as fh\nfrom langextract.core import types\n\n__all__ = [\n    \"ConstraintType\",\n    \"Constraint\",\n    \"BaseSchema\",\n    \"FormatModeSchema\",\n]\n\n# Backward compat re-exports\nConstraintType = types.ConstraintType\nConstraint = types.Constraint\n\n\nclass BaseSchema(abc.ABC):\n  \"\"\"Abstract base class for generating structured constraints from examples.\"\"\"\n\n  @classmethod\n  @abc.abstractmethod\n  def from_examples(\n      cls,\n      examples_data: Sequence[data.ExampleData],\n      attribute_suffix: str = data.ATTRIBUTE_SUFFIX,\n  ) -> BaseSchema:\n    \"\"\"Factory method to build a schema instance from example data.\"\"\"\n\n  @abc.abstractmethod\n  def to_provider_config(self) -> dict[str, Any]:\n    \"\"\"Convert schema to provider-specific configuration.\n\n    Returns:\n      Dictionary of provider kwargs (e.g., response_schema for Gemini).\n      Should be a pure data mapping with no side effects.\n    \"\"\"\n\n  @property\n  @abc.abstractmethod\n  def requires_raw_output(self) -> bool:\n    \"\"\"Whether this schema outputs raw JSON/YAML without fence markers.\n\n    When True, the provider 
emits syntactically valid JSON directly.\n    When False, the provider needs fence markers for structure.\n    \"\"\"\n\n  def validate_format(self, format_handler: fh.FormatHandler) -> None:\n    \"\"\"Validate format compatibility and warn about issues.\n\n    Override in subclasses to check format settings.\n    Default implementation does nothing (no validation needed).\n\n    Args:\n      format_handler: The format configuration to validate.\n    \"\"\"\n\n  def sync_with_provider_kwargs(self, kwargs: dict[str, Any]) -> None:\n    \"\"\"Hook to update schema state based on provider kwargs.\n\n    This allows schemas to adjust their behavior based on caller overrides.\n    For example, FormatModeSchema uses this to sync its format when the caller\n    overrides it, ensuring requires_raw_output stays accurate.\n\n    Default implementation does nothing. Override if your schema needs to\n    respond to provider kwargs.\n\n    Args:\n      kwargs: The effective provider kwargs after merging.\n    \"\"\"\n\n\nclass FormatModeSchema(BaseSchema):\n  \"\"\"Generic schema for providers that support format modes (JSON/YAML).\n\n  This schema doesn't enforce structure, only output format. 
Useful for\n  providers that can guarantee syntactically valid JSON or YAML but don't\n  support field-level constraints.\n  \"\"\"\n\n  def __init__(self, format_type: types.FormatType = types.FormatType.JSON):\n    \"\"\"Initialize with a format type.\"\"\"\n    self.format_type = format_type\n    # Keep _format for backward compatibility with tests\n    self._format = \"json\" if format_type == types.FormatType.JSON else \"yaml\"\n\n  @classmethod\n  def from_examples(\n      cls,\n      examples_data: Sequence[data.ExampleData],\n      attribute_suffix: str = data.ATTRIBUTE_SUFFIX,\n  ) -> FormatModeSchema:\n    \"\"\"Factory method to build a schema instance from example data.\"\"\"\n    # Default to JSON format\n    return cls(format_type=types.FormatType.JSON)\n\n  def to_provider_config(self) -> dict[str, Any]:\n    \"\"\"Convert schema to provider-specific configuration.\"\"\"\n    return {\"format\": self._format}\n\n  @property\n  def requires_raw_output(self) -> bool:\n    \"\"\"JSON format schemas output raw JSON without fences, YAML does not.\"\"\"\n    return self._format == \"json\"\n\n  def sync_with_provider_kwargs(self, kwargs: dict[str, Any]) -> None:\n    \"\"\"Sync format type with provider kwargs.\"\"\"\n    if \"format_type\" in kwargs:\n      self.format_type = kwargs[\"format_type\"]\n      self._format = (\n          \"json\" if self.format_type == types.FormatType.JSON else \"yaml\"\n      )\n    if \"format\" in kwargs:\n      self._format = kwargs[\"format\"]\n      self.format_type = (\n          types.FormatType.JSON\n          if self._format == \"json\"\n          else types.FormatType.YAML\n      )\n"
  },
  {
    "path": "langextract/core/tokenizer.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tokenization utilities for text.\n\nProvides methods to split text into regex-based or Unicode-aware tokens.\nTokenization is used for alignment in `resolver.py` and for determining\nsentence boundaries for smaller context use cases. This module is not used\nfor tokenization within the language model during inference.\n\"\"\"\n\nimport abc\nfrom collections.abc import Sequence, Set\nimport dataclasses\nimport enum\nimport functools\nimport unicodedata\n\nimport regex\n\nfrom langextract.core import debug_utils\nfrom langextract.core import exceptions\n\n__all__ = [\n    \"BaseTokenizerError\",\n    \"InvalidTokenIntervalError\",\n    \"SentenceRangeError\",\n    \"CharInterval\",\n    \"TokenInterval\",\n    \"TokenType\",\n    \"Token\",\n    \"TokenizedText\",\n    \"Tokenizer\",\n    \"RegexTokenizer\",\n    \"UnicodeTokenizer\",\n    \"tokenize\",\n    \"tokens_text\",\n    \"find_sentence_range\",\n]\n\n\nclass BaseTokenizerError(exceptions.LangExtractError):\n  \"\"\"Base class for all tokenizer-related errors.\"\"\"\n\n\nclass InvalidTokenIntervalError(BaseTokenizerError):\n  \"\"\"Error raised when a token interval is invalid or out of range.\"\"\"\n\n\nclass SentenceRangeError(BaseTokenizerError):\n  \"\"\"Error raised when the start token index for a sentence is out of range.\"\"\"\n\n\n@dataclasses.dataclass(slots=True)\nclass CharInterval:\n  
\"\"\"Represents a range of character positions in the original text.\n\n  Attributes:\n    start_pos: The starting character index (inclusive).\n    end_pos: The ending character index (exclusive).\n  \"\"\"\n\n  start_pos: int\n  end_pos: int\n\n\n@dataclasses.dataclass(slots=True)\nclass TokenInterval:\n  \"\"\"Represents an interval over tokens in tokenized text.\n\n  The interval is defined by a start index (inclusive) and an end index\n  (exclusive).\n\n  Attributes:\n    start_index: The index of the first token in the interval.\n    end_index: The index one past the last token in the interval.\n  \"\"\"\n\n  start_index: int = 0\n  end_index: int = 0\n\n\nclass TokenType(enum.IntEnum):\n  \"\"\"Enumeration of token types produced during tokenization.\n\n  Attributes:\n    WORD: Represents an alphabetical word token.\n    NUMBER: Represents a numeric token.\n    PUNCTUATION: Represents punctuation characters.\n  \"\"\"\n\n  WORD = 0\n  NUMBER = 1\n  PUNCTUATION = 2\n\n\n@dataclasses.dataclass(slots=True)\nclass Token:\n  \"\"\"Represents a token extracted from text.\n\n  Each token is assigned an index and classified into a type (word, number,\n  or punctuation). 
The token also records the range of characters\n  (its CharInterval) that correspond to the substring from the original text.\n  Additionally, it tracks whether it follows a newline.\n\n  Attributes:\n    index: The position of the token in the sequence of tokens.\n    token_type: The type of the token, as defined by TokenType.\n    char_interval: The character interval within the original text that this\n      token spans.\n    first_token_after_newline: True if the token immediately follows a newline\n      or carriage return.\n  \"\"\"\n\n  index: int\n  token_type: TokenType\n  char_interval: CharInterval = dataclasses.field(\n      default_factory=lambda: CharInterval(0, 0)\n  )\n  first_token_after_newline: bool = False\n\n\n@dataclasses.dataclass\nclass TokenizedText:\n  \"\"\"Holds the result of tokenizing a text string.\n\n  Attributes:\n    text: The text that was tokenized. For UnicodeTokenizer, this is\n      NOT normalized to NFC (to preserve indices).\n    tokens: A list of Token objects extracted from the text.\n  \"\"\"\n\n  text: str\n  tokens: list[Token] = dataclasses.field(default_factory=list)\n\n\n_LETTERS_PATTERN = r\"[^\\W\\d_]+\"\n_DIGITS_PATTERN = r\"\\d+\"\n# Group identical symbols (e.g. 
\"!!\") but split mixed ones.\n_SYMBOLS_PATTERN = r\"([^\\w\\s]|_)\\1*\"\n_END_OF_SENTENCE_PATTERN = regex.compile(r\"[.?!。！？\\u0964][\\\"'”’»)\\]}]*$\")\n\n_TOKEN_PATTERN = regex.compile(\n    rf\"{_LETTERS_PATTERN}|{_DIGITS_PATTERN}|{_SYMBOLS_PATTERN}\"\n)\n_WORD_PATTERN = regex.compile(rf\"(?:{_LETTERS_PATTERN}|{_DIGITS_PATTERN})\\Z\")\n\n# Abbreviations that do not end sentences.\n# TODO: Evaluate removal for large-context use cases.\n_KNOWN_ABBREVIATIONS = frozenset({\"Mr.\", \"Mrs.\", \"Ms.\", \"Dr.\", \"Prof.\", \"St.\"})\n_CLOSING_PUNCTUATION = frozenset({'\"', \"'\", \"”\", \"’\", \"»\", \")\", \"]\", \"}\"})\n\n\nclass Tokenizer(abc.ABC):\n  \"\"\"Abstract base class for tokenizers.\"\"\"\n\n  @abc.abstractmethod\n  def tokenize(self, text: str) -> TokenizedText:\n    \"\"\"Splits text into tokens.\n\n    Args:\n      text: The text to tokenize.\n\n    Returns:\n      A TokenizedText object.\n    \"\"\"\n\n\nclass RegexTokenizer(Tokenizer):\n  \"\"\"Regex-based tokenizer (default).\n\n  RegexTokenizer is faster than UnicodeTokenizer on English text because it\n  skips grapheme clustering and script-based token grouping.\n  \"\"\"\n\n  @debug_utils.debug_log_calls\n  def tokenize(self, text: str) -> TokenizedText:\n    \"\"\"Splits text into tokens (words, digits, or punctuation).\n\n    Each token is annotated with its character position and type. 
Tokens\n    following a newline or carriage return have `first_token_after_newline`\n    set to True.\n\n    Args:\n      text: The text to tokenize.\n\n    Returns:\n      A TokenizedText object containing all extracted tokens.\n    \"\"\"\n    tokenized = TokenizedText(text=text)\n    previous_end = 0\n    for token_index, match in enumerate(_TOKEN_PATTERN.finditer(text)):\n      start_pos, end_pos = match.span()\n      matched_text = match.group()\n      token = Token(\n          index=token_index,\n          char_interval=CharInterval(start_pos=start_pos, end_pos=end_pos),\n          token_type=TokenType.WORD,\n          first_token_after_newline=False,\n      )\n      if token_index > 0:\n        # Optimization: Check gap without slicing.\n        has_newline = text.find(\"\\n\", previous_end, start_pos) != -1\n        if not has_newline:\n          has_newline = text.find(\"\\r\", previous_end, start_pos) != -1\n        if has_newline:\n          token.first_token_after_newline = True\n      if regex.fullmatch(_DIGITS_PATTERN, matched_text):\n        token.token_type = TokenType.NUMBER\n      elif _WORD_PATTERN.fullmatch(matched_text):\n        token.token_type = TokenType.WORD\n      else:\n        token.token_type = TokenType.PUNCTUATION\n      tokenized.tokens.append(token)\n      previous_end = end_pos\n    return tokenized\n\n\n# Default tokenizer instance for backward compatibility\n_DEFAULT_TOKENIZER = RegexTokenizer()\n\n\ndef tokenize(\n    text: str, tokenizer: Tokenizer = _DEFAULT_TOKENIZER\n) -> TokenizedText:\n  \"\"\"Splits text into tokens using the provided tokenizer (default: RegexTokenizer).\n\n  Args:\n    text: The text to tokenize.\n    tokenizer: The tokenizer instance to use.\n\n  Returns:\n    A TokenizedText object.\n  \"\"\"\n  return tokenizer.tokenize(text)\n\n\n_CJK_PATTERN = regex.compile(\n    r\"\\p{Is_Han}|\\p{Is_Hiragana}|\\p{Is_Katakana}|\\p{Is_Hangul}\"\n)\n_NON_SPACED_PATTERN = regex.compile(\n    
r\"\\p{Is_Thai}|\\p{Is_Lao}|\\p{Is_Khmer}|\\p{Is_Myanmar}\"\n)\n\n\nclass Sentinel:\n  \"\"\"Sentinel class for unique object identification.\"\"\"\n\n  def __init__(self, name: str):\n    self.name = name\n\n  def __repr__(self) -> str:\n    return f\"<{self.name}>\"\n\n\n_NO_GROUP_SCRIPT = Sentinel(\"NO_GROUP\")\n_UNKNOWN_SCRIPT = Sentinel(\"UNKNOWN\")\n_LATIN_SCRIPT = \"Latin\"\n\n\n# Returns the script of `char`, using an ASCII fast path before the cached\n# regex lookup.\ndef _get_script_fast(char: str) -> str | Sentinel:\n  # Fast path for ASCII: Avoids regex and unicodedata lookups.\n  if ord(char) < 128:\n    return _LATIN_SCRIPT\n\n  # Fall back to the cached regex-based script lookup.\n  return _get_common_script_cached(char)\n\n\ndef _classify_grapheme(g: str) -> TokenType:\n  if not g:\n    return TokenType.PUNCTUATION\n  c = g[0]\n  cat = unicodedata.category(c)\n  if cat.startswith(\"L\"):\n    return TokenType.WORD\n  if cat.startswith(\"N\"):\n    return TokenType.NUMBER\n  return TokenType.PUNCTUATION\n\n\n_COMMON_SCRIPTS = [\n    \"Latin\",\n    \"Cyrillic\",\n    \"Greek\",\n    \"Arabic\",\n    \"Hebrew\",\n    \"Devanagari\",\n]\n\n_COMMON_SCRIPTS_PATTERN = regex.compile(\n    \"|\".join(\n        rf\"(?P<{script}>\\p{{Script={script}}})\" for script in _COMMON_SCRIPTS\n    )\n)\n\n_GRAPHEME_CLUSTER_PATTERN = regex.compile(r\"\\X\")\n\n\n@functools.lru_cache(maxsize=4096)\ndef _get_common_script_cached(c: str) -> str | Sentinel:\n  \"\"\"Determines script using regex, cached for performance.\"\"\"\n  match = _COMMON_SCRIPTS_PATTERN.match(c)\n  if match:\n    return match.lastgroup\n  return _UNKNOWN_SCRIPT\n\n\nclass UnicodeTokenizer(Tokenizer):\n  \"\"\"Unicode-aware tokenizer for better non-English support.\n\n  This tokenizer segments text into grapheme clusters (Unicode Standard\n  Annex #29) using the `regex` library's `\\\\X` pattern, so clusters such as\n  emoji and Hangul syllables stay intact, and classifies each cluster by the\n  Unicode general category of its first character.\n\n  Unlike some Unicode tokenizers, this class does NOT normalize text to 
NFC.\n  This ensures that token indices exactly match the original input string.\n\n  Note: Grapheme clustering makes this tokenizer slower than RegexTokenizer.\n  \"\"\"\n\n  @debug_utils.debug_log_calls\n  def tokenize(self, text: str) -> TokenizedText:\n    \"\"\"Splits text into tokens using Unicode properties.\n\n    Args:\n      text: The text to tokenize.\n\n    Returns:\n      A TokenizedText object.\n    \"\"\"\n    tokens: list[Token] = []\n\n    current_start = 0\n    current_type = None\n    current_script = None\n    previous_end = 0\n\n    for match in regex.finditer(r\"\\X\", text):\n      grapheme = match.group()\n      start, _ = match.span()\n\n      # 1. Handle Whitespace\n      if grapheme.isspace():\n        if current_type is not None:\n          self._emit_token(\n              tokens, text, current_start, start, current_type, previous_end\n          )\n          previous_end = start\n          current_type = None\n          current_script = None\n        # Keep `previous_end` to detect newlines within the whitespace gap.\n        continue\n\n      g_type = _classify_grapheme(grapheme)\n\n      # 2. 
Determine if we should merge with the current token\n      should_merge = False\n      if current_type is not None:\n        if current_type == g_type:\n          if current_type == TokenType.WORD:\n            # Script Check\n            first_char = grapheme[0]\n\n            # Fast path: Explicit NO_GROUP (CJK/Thai) never merges.\n            if current_script is _NO_GROUP_SCRIPT:\n              should_merge = False\n\n            # CJK and Non-Spaced scripts require fragmentation.\n            elif _CJK_PATTERN.match(first_char) or _NON_SPACED_PATTERN.match(\n                first_char\n            ):\n              should_merge = False\n\n            else:\n              g_script = _get_script_fast(first_char)\n              # Safety: Never merge graphemes whose script is unknown.\n              if (\n                  current_script == g_script\n                  and current_script is not _UNKNOWN_SCRIPT\n              ):\n                should_merge = True\n\n          elif current_type == TokenType.NUMBER:\n            should_merge = True\n\n          elif current_type == TokenType.PUNCTUATION:\n            # Heuristic: Only extend runs of the same grapheme (e.g. \"!!\" or\n            # \"...\"), so the pending token text must end with this grapheme.\n            current_token_text = text[current_start:start]\n            should_merge = current_token_text.endswith(grapheme)\n\n      # 3. 
State Transition\n      if should_merge:\n        # Extend current token\n        pass\n      else:\n        # Flush previous token if exists\n        if current_type is not None:\n          self._emit_token(\n              tokens, text, current_start, start, current_type, previous_end\n          )\n          previous_end = start\n\n        # Start new token\n        current_start = start\n        current_type = g_type\n\n        # Determine script for the new token\n        if current_type == TokenType.WORD:\n          c = grapheme[0]\n          if _CJK_PATTERN.match(c) or _NON_SPACED_PATTERN.match(c):\n            current_script = _NO_GROUP_SCRIPT\n          else:\n            current_script = _get_script_fast(c)\n        else:\n          current_script = None\n\n    # 4. Flush final token\n    if current_type is not None:\n      self._emit_token(\n          tokens, text, current_start, len(text), current_type, previous_end\n      )\n\n    return TokenizedText(text=text, tokens=tokens)\n\n  def _emit_token(\n      self,\n      tokens: list[Token],\n      text: str,\n      start: int,\n      end: int,\n      token_type: TokenType,\n      previous_end: int,\n  ):\n    \"\"\"Helper to create and append a token.\"\"\"\n    token = Token(\n        index=len(tokens),\n        char_interval=CharInterval(start_pos=start, end_pos=end),\n        token_type=token_type,\n        first_token_after_newline=False,\n    )\n\n    # Check for newlines in the gap between the previous token and this one\n    if start > previous_end:\n      gap = text[previous_end:start]\n      if \"\\n\" in gap or \"\\r\" in gap:\n        token.first_token_after_newline = True\n\n    tokens.append(token)\n\n\ndef tokens_text(\n    tokenized_text: TokenizedText,\n    token_interval: TokenInterval,\n) -> str:\n  \"\"\"Reconstructs the substring of the original text spanning a given token interval.\n\n  Args:\n    tokenized_text: A TokenizedText object containing token data.\n    token_interval: The 
interval specifying the range [start_index, end_index)\n      of tokens.\n\n  Returns:\n    The exact substring of the original text corresponding to the token\n    interval, or an empty string if the interval is empty.\n\n  Raises:\n    InvalidTokenIntervalError: If the token_interval is invalid or out of range.\n  \"\"\"\n  if token_interval.start_index == token_interval.end_index:\n    return \"\"\n\n  if (\n      token_interval.start_index < 0\n      or token_interval.end_index > len(tokenized_text.tokens)\n      or token_interval.start_index > token_interval.end_index\n  ):\n    raise InvalidTokenIntervalError(\n        f\"Invalid token interval. start_index={token_interval.start_index}, \"\n        f\"end_index={token_interval.end_index}, \"\n        f\"total_tokens={len(tokenized_text.tokens)}.\"\n    )\n\n  start_token = tokenized_text.tokens[token_interval.start_index]\n  end_token = tokenized_text.tokens[token_interval.end_index - 1]\n  return tokenized_text.text[\n      start_token.char_interval.start_pos : end_token.char_interval.end_pos\n  ]\n\n\ndef _is_end_of_sentence_token(\n    text: str,\n    tokens: Sequence[Token],\n    current_idx: int,\n    known_abbreviations: Set[str] = _KNOWN_ABBREVIATIONS,\n) -> bool:\n  \"\"\"Checks if the punctuation token at `current_idx` ends a sentence.\n\n  A token is considered a sentence terminator if its text matches\n  _END_OF_SENTENCE_PATTERN, unless it completes a known abbreviation together\n  with the preceding token (e.g., \"Dr\" + \".\"). 
Only searches the text corresponding to the current token.\n\n  Args:\n    text: The entire input text.\n    tokens: The sequence of Token objects.\n    current_idx: The current token index to check.\n    known_abbreviations: Abbreviations that should not count as sentence enders\n      (e.g., \"Dr.\").\n\n  Returns:\n    True if the token at `current_idx` ends a sentence, otherwise False.\n  \"\"\"\n  current_token_text = text[\n      tokens[current_idx]\n      .char_interval.start_pos : tokens[current_idx]\n      .char_interval.end_pos\n  ]\n  if _END_OF_SENTENCE_PATTERN.search(current_token_text):\n    if current_idx > 0:\n      prev_token_text = text[\n          tokens[current_idx - 1]\n          .char_interval.start_pos : tokens[current_idx - 1]\n          .char_interval.end_pos\n      ]\n      if f\"{prev_token_text}{current_token_text}\" in known_abbreviations:\n        return False\n    return True\n  return False\n\n\ndef _is_sentence_break_after_newline(\n    text: str,\n    tokens: Sequence[Token],\n    current_idx: int,\n) -> bool:\n  \"\"\"Checks if the next token follows a newline and does not start lowercase.\n\n  Args:\n    text: The entire input text.\n    tokens: The sequence of Token objects.\n    current_idx: The current token index.\n\n  Returns:\n    True if the token at current_idx + 1 exists, is the first token after a\n    newline, and does not begin with a lowercase character (so uppercase\n    letters, digits, and quotes all count as a break); otherwise False.\n  \"\"\"\n  if current_idx + 1 >= len(tokens):\n    return False\n\n  next_token = tokens[current_idx + 1]\n\n  if not next_token.first_token_after_newline:\n    return False\n\n  next_token_text = text[\n      next_token.char_interval.start_pos : next_token.char_interval.end_pos\n  ]\n  # Assume break unless lowercase (covers numbers/quotes).\n  return bool(next_token_text) and not next_token_text[0].islower()\n\n\ndef find_sentence_range(\n    text: str,\n    tokens: Sequence[Token],\n    start_token_index: int,\n    known_abbreviations: Set[str] = 
_KNOWN_ABBREVIATIONS,\n) -> TokenInterval:\n  \"\"\"Finds a 'sentence' interval from a given start index.\n\n  A sentence ends at the first of:\n    - a punctuation token matching _END_OF_SENTENCE_PATTERN that does not\n      complete a known abbreviation (e.g., \"Dr.\"), extended over any closing\n      punctuation that immediately follows it; or\n    - a token followed by a newline, when the next token does not start with a\n      lowercase character.\n\n  This favors terminating a sentence prematurely over missing a sentence\n  boundary: a sentence that wraps onto a new line is cut early whenever the\n  next line does not begin with a lowercase character.\n\n  Args:\n    text: The text to analyze.\n    tokens: The tokens that make up `text`.\n      Note: `text` must be the exact string that was tokenized, since token\n      character intervals index into it.\n    start_token_index: The index of the token to start the sentence from.\n    known_abbreviations: A set of strings that are known abbreviations and\n      should not be treated as sentence boundaries.\n\n  Returns:\n    A TokenInterval representing the sentence range [start_token_index, end). If\n    no sentence boundary is found, the end index will be the length of\n    `tokens`. If `tokens` is empty, returns an empty interval [0, 0).\n\n  Raises:\n    SentenceRangeError: If `start_token_index` is out of range.\n  \"\"\"\n  if not tokens:\n    return TokenInterval(start_index=0, end_index=0)\n\n  if start_token_index < 0 or start_token_index >= len(tokens):\n    raise SentenceRangeError(\n        f\"start_token_index={start_token_index} out of range. \"\n        f\"Total tokens: {len(tokens)}.\"\n    )\n\n  i = start_token_index\n  while i < len(tokens):\n    if tokens[i].token_type == TokenType.PUNCTUATION:\n      if _is_end_of_sentence_token(text, tokens, i, known_abbreviations):\n        end_index = i + 1\n        # Consume any trailing closing punctuation (e.g. 
quotes, parens)\n        while end_index < len(tokens):\n          next_token_text = text[\n              tokens[end_index]\n              .char_interval.start_pos : tokens[end_index]\n              .char_interval.end_pos\n          ]\n          if (\n              tokens[end_index].token_type == TokenType.PUNCTUATION\n              and next_token_text in _CLOSING_PUNCTUATION\n          ):\n            end_index += 1\n          else:\n            break\n        return TokenInterval(start_index=start_token_index, end_index=end_index)\n    if _is_sentence_break_after_newline(text, tokens, i):\n      return TokenInterval(start_index=start_token_index, end_index=i + 1)\n    i += 1\n\n  return TokenInterval(start_index=start_token_index, end_index=len(tokens))\n"
  },
  {
    "path": "langextract/core/types.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Core data types for LangExtract.\"\"\"\nfrom __future__ import annotations\n\nimport dataclasses\nimport enum\nimport textwrap\n\n__all__ = [\n    'ScoredOutput',\n    'FormatType',\n    'ConstraintType',\n    'Constraint',\n]\n\n\nclass FormatType(enum.Enum):\n  \"\"\"Enumeration of prompt output formats.\"\"\"\n\n  YAML = 'yaml'\n  JSON = 'json'\n\n\nclass ConstraintType(enum.Enum):\n  \"\"\"Enumeration of constraint types.\"\"\"\n\n  NONE = 'none'\n\n\n@dataclasses.dataclass\nclass Constraint:\n  \"\"\"Represents a constraint for model output decoding.\n\n  Attributes:\n    constraint_type: The type of constraint applied.\n  \"\"\"\n\n  constraint_type: ConstraintType = ConstraintType.NONE\n\n\n@dataclasses.dataclass(frozen=True)\nclass ScoredOutput:\n  \"\"\"Scored output from language model inference.\"\"\"\n\n  score: float | None = None\n  output: str | None = None\n\n  def __str__(self) -> str:\n    score_str = '-' if self.score is None else f'{self.score:.2f}'\n    if self.output is None:\n      return f'Score: {score_str}\\nOutput: None'\n    formatted_lines = textwrap.indent(self.output, prefix='  ')\n    return f'Score: {score_str}\\nOutput:\\n{formatted_lines}'\n"
  },
  {
    "path": "langextract/data.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.data imports.\n\nThis module provides backward compatibility for code that imports from\nlangextract.data. All functionality has moved to langextract.core.data.\n\"\"\"\n\nfrom __future__ import annotations\n\n# Re-export everything from core.data for backward compatibility\n# pylint: disable=unused-wildcard-import\nfrom langextract.core.data import *\n"
  },
  {
    "path": "langextract/data_lib.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Library for data conversion between AnnotatedDocument and JSON.\"\"\"\nfrom __future__ import annotations\n\nimport dataclasses\nimport enum\nimport numbers\nfrom typing import Any, Iterable, Mapping\n\nfrom langextract.core import data\nfrom langextract.core import tokenizer\n\n\ndef enum_asdict_factory(items: Iterable[tuple[str, Any]]) -> dict[str, Any]:\n  \"\"\"Custom dict_factory for dataclasses.asdict.\n\n  Recursively converts dataclass instances, converts enum values to their\n  underlying values, converts integral numeric types to int, and skips any\n  field whose name starts with an underscore.\n\n  Args:\n    items: An iterable of (key, value) pairs from fields of a dataclass.\n\n  Returns:\n    A mapping of field names to their values, with special handling for\n    dataclasses, enums, and numeric types.\n  \"\"\"\n  result: dict[str, Any] = {}\n  for key, value in items:\n    # Skip internal fields.\n    if key.startswith(\"_\"):\n      continue\n    if dataclasses.is_dataclass(value):\n      result[key] = dataclasses.asdict(value, dict_factory=enum_asdict_factory)\n    elif isinstance(value, enum.Enum):\n      result[key] = value.value\n    elif isinstance(value, numbers.Integral) and not isinstance(value, bool):\n      result[key] = int(value)\n    else:\n      result[key] = value\n  return result\n\n\ndef annotated_document_to_dict(\n    adoc: 
data.AnnotatedDocument | None,\n) -> dict[str, Any]:\n  \"\"\"Converts an AnnotatedDocument into a Python dict.\n\n  This function converts an AnnotatedDocument object into a Python dict so\n  that it can be easily serialized (e.g., to JSON). Enum values and integral\n  numeric types (such as NumPy integers) are converted to their underlying\n  Python values, while other data types are left unchanged. Private fields with\n  an underscore prefix are not included in the output.\n\n  Args:\n    adoc: The AnnotatedDocument object to convert.\n\n  Returns:\n    A Python dict representing the AnnotatedDocument, or an empty dict if\n    `adoc` is None.\n  \"\"\"\n\n  if not adoc:\n    return {}\n\n  result = dataclasses.asdict(adoc, dict_factory=enum_asdict_factory)\n\n  result[\"document_id\"] = adoc.document_id\n\n  return result\n\n\ndef dict_to_annotated_document(\n    adoc_dic: Mapping[str, Any],\n) -> data.AnnotatedDocument:\n  \"\"\"Converts a Python dict back to an AnnotatedDocument.\n\n  Args:\n    adoc_dic: A Python dict representing an AnnotatedDocument, for example as\n      produced by annotated_document_to_dict. It is not modified.\n\n  Returns:\n    An AnnotatedDocument object.\n  \"\"\"\n  if not adoc_dic:\n    return data.AnnotatedDocument()\n\n  extractions = []\n  for extraction_dict in adoc_dic.get(\"extractions\", []):\n    # Work on a copy so the caller's dict is left untouched.\n    ext = dict(extraction_dict)\n\n    token_int = ext.get(\"token_interval\")\n    ext[\"token_interval\"] = (\n        tokenizer.TokenInterval(**token_int) if token_int else None\n    )\n\n    char_int = ext.get(\"char_interval\")\n    ext[\"char_interval\"] = data.CharInterval(**char_int) if char_int else None\n\n    status_str = ext.get(\"alignment_status\")\n    ext[\"alignment_status\"] = (\n        data.AlignmentStatus(status_str) if status_str else None\n    )\n\n    extractions.append(data.Extraction(**ext))\n\n  return data.AnnotatedDocument(\n      document_id=adoc_dic.get(\"document_id\"),\n      text=adoc_dic.get(\"text\"),\n      extractions=extractions,\n  )\n"
  },
  {
    "path": "langextract/exceptions.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Public exceptions API for LangExtract.\n\nThis module re-exports exceptions from core.exceptions for backward compatibility.\nAll new code should import directly from langextract.core.exceptions.\n\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nfrom langextract.core import exceptions as core_exceptions\n\n# Backward compat re-exports\nInferenceConfigError = core_exceptions.InferenceConfigError\nInferenceError = core_exceptions.InferenceError\nInferenceOutputError = core_exceptions.InferenceOutputError\nInferenceRuntimeError = core_exceptions.InferenceRuntimeError\nLangExtractError = core_exceptions.LangExtractError\nProviderError = core_exceptions.ProviderError\nSchemaError = core_exceptions.SchemaError\n\n__all__ = [\n    \"LangExtractError\",\n    \"InferenceError\",\n    \"InferenceConfigError\",\n    \"InferenceRuntimeError\",\n    \"InferenceOutputError\",\n    \"ProviderError\",\n    \"SchemaError\",\n]\n"
  },
  {
    "path": "langextract/extraction.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Main extraction API for LangExtract.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Iterable\nimport typing\nfrom typing import cast\nimport warnings\n\nfrom langextract import annotation\nfrom langextract import factory\nfrom langextract import io\nfrom langextract import prompt_validation as pv\nfrom langextract import prompting\nfrom langextract import resolver\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import format_handler as fh\nfrom langextract.core import tokenizer as tokenizer_lib\n\n\ndef extract(\n    text_or_documents: typing.Any,\n    prompt_description: str | None = None,\n    examples: typing.Sequence[typing.Any] | None = None,\n    model_id: str = \"gemini-2.5-flash\",\n    api_key: str | None = None,\n    language_model_type: typing.Type[typing.Any] | None = None,\n    format_type: typing.Any = None,\n    max_char_buffer: int = 1000,\n    temperature: float | None = None,\n    fence_output: bool | None = None,\n    use_schema_constraints: bool = True,\n    batch_length: int = 10,\n    max_workers: int = 10,\n    additional_context: str | None = None,\n    resolver_params: dict | None = None,\n    language_model_params: dict | None = None,\n    debug: bool = False,\n    model_url: str | None = None,\n    extraction_passes: int = 1,\n    context_window_chars: int 
| None = None,\n    config: typing.Any = None,\n    model: typing.Any = None,\n    *,\n    fetch_urls: bool = True,\n    prompt_validation_level: pv.PromptValidationLevel = pv.PromptValidationLevel.WARNING,\n    prompt_validation_strict: bool = False,\n    show_progress: bool = True,\n    tokenizer: tokenizer_lib.Tokenizer | None = None,\n) -> list[data.AnnotatedDocument] | data.AnnotatedDocument:\n  \"\"\"Extracts structured information from text.\n\n  Retrieves structured information from the provided text or documents using a\n  language model based on the instructions in prompt_description and guided by\n  examples. Supports sequential extraction passes to improve recall at the cost\n  of additional API calls.\n\n  Args:\n      text_or_documents: The source text to extract information from, a URL to\n        download text from (starting with http:// or https:// when fetch_urls\n        is True), or an iterable of Document objects.\n      prompt_description: Instructions for what to extract from the text.\n      examples: List of ExampleData objects to guide the extraction.\n      tokenizer: Optional Tokenizer instance to use for chunking and alignment.\n        If None, defaults to RegexTokenizer.\n      api_key: API key for Gemini or other LLM services (can also use\n        environment variable LANGEXTRACT_API_KEY). Cost considerations: Most\n        APIs charge by token volume. Smaller max_char_buffer values increase the\n        number of API calls, while extraction_passes > 1 reprocesses tokens\n        multiple times. Note that max_workers improves processing speed without\n        additional token costs. 
Refer to your API provider's pricing details and\n        monitor usage with small test runs to estimate costs.\n      model_id: The model ID to use for extraction (e.g., 'gemini-2.5-flash').\n        If your model ID is not recognized or you need to use a custom provider,\n        use the 'config' parameter with factory.ModelConfig to specify the\n        provider explicitly.\n      language_model_type: [DEPRECATED] Not used to select the model. Passing\n        any value emits a FutureWarning when neither model nor config is\n        provided. This parameter will be removed in v2.0.0. Use the model,\n        config, or model_id parameters instead.\n      format_type: The format type for the output (JSON or YAML). Defaults to\n        JSON when None.\n      max_char_buffer: Maximum number of characters in each text chunk sent\n        for inference.\n      temperature: The sampling temperature for generation. When None (default),\n        uses the model's default temperature. Set to 0.0 for the most\n        deterministic output or higher values for more variation.\n      fence_output: Whether to expect/generate fenced output (```json or\n        ```yaml). When True, the model is prompted to generate fenced output\n        and the resolver expects it. When False, raw JSON/YAML is expected.\n        When None, it is determined automatically from provider schema\n        capabilities: if a schema is applied and requires_raw_output is True,\n        it defaults to False; otherwise True. If your model uses schema\n        constraints, this can generally be set to False unless the constraint\n        also accounts for code fence delimiters.\n      use_schema_constraints: Whether to generate schema constraints for models.\n        For supported models, this enables structured outputs. Defaults to True.\n      batch_length: Number of text chunks processed per batch. 
Higher values\n        enable greater parallelization when batch_length >= max_workers.\n        Defaults to 10.\n      max_workers: Maximum parallel workers for concurrent processing. Effective\n        parallelization is limited by min(batch_length, max_workers). Supported\n        by Gemini models. Defaults to 10.\n      additional_context: Additional context to be added to the prompt during\n        inference.\n      resolver_params: Parameters for the `resolver.Resolver`, which parses the\n        raw language model output string (e.g., extracting JSON from ```json ...\n        ``` blocks) into structured `data.Extraction` objects. This dictionary\n        overrides default settings. Keys include:\n          - 'extraction_index_suffix' (str | None): Suffix for keys indicating\n            extraction order. Default is None (order by appearance).\n        Additional alignment parameters can be included:\n          - 'enable_fuzzy_alignment' (bool): Whether to use fuzzy matching if\n            exact matching fails. Disabling this can improve performance but\n            may reduce recall. Default is True.\n          - 'fuzzy_alignment_threshold' (float): Minimum token overlap ratio\n            for a fuzzy match (0.0-1.0). Default is 0.75.\n          - 'accept_match_lesser' (bool): Whether to accept partial exact\n            matches. Default is True.\n          - 'suppress_parse_errors' (bool): Whether to suppress parsing errors\n            and continue the pipeline. Default is False.\n      language_model_params: Additional parameters for the language model.\n      debug: Whether to enable debug logging. When True, enables detailed\n        logging of function calls, arguments, return values, and timing for the\n        langextract namespace. Note: Debug logging remains enabled for the\n        process once activated.\n      model_url: Endpoint URL for self-hosted or on-prem models. When neither\n        model nor config is provided, it is passed to the provider as both\n        `model_url` and `base_url`.\n      extraction_passes: Number of sequential extraction attempts to improve\n        recall and find additional entities. Defaults to 1 (standard single\n        extraction). When > 1, the system performs multiple independent\n        extractions and merges non-overlapping results (first extraction wins\n        for overlaps). WARNING: Each additional pass reprocesses tokens,\n        potentially increasing API costs. For example, extraction_passes=3\n        reprocesses tokens 3x.\n      context_window_chars: Number of characters from the previous chunk to\n        include as context for the current chunk. This helps with coreference\n        resolution across chunk boundaries (e.g., resolving \"She\" to a person\n        mentioned in the previous chunk). Defaults to None (disabled).\n      config: Model configuration to use for extraction. Takes precedence over\n        model_id, api_key, and language_model_type parameters. When both model\n        and config are provided, model takes precedence.\n      model: Pre-configured language model to use for extraction. Takes\n        precedence over all other parameters including config.\n      fetch_urls: Whether to automatically download content when the input is a\n        URL string. When True (default), strings starting with http:// or\n        https:// are fetched. When False, all strings are treated as literal\n        text to analyze. This is a keyword-only parameter.\n      prompt_validation_level: Controls pre-flight alignment checks on few-shot\n        examples. OFF skips validation, WARNING logs issues but continues, ERROR\n        raises on failures. Defaults to WARNING.\n      prompt_validation_strict: When True and prompt_validation_level is ERROR,\n        raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to\n        False.\n      show_progress: Whether to show a progress bar during extraction. 
Defaults to True.\n\n  Returns:\n      An AnnotatedDocument with the extracted information when input is a\n      string or URL, or a list of AnnotatedDocuments when input is an\n      iterable of Documents.\n\n  Raises:\n      ValueError: If examples is None or empty.\n      ValueError: If no API key is provided or found in environment variables.\n      requests.RequestException: If URL download fails.\n      pv.PromptAlignmentError: If validation fails in ERROR mode.\n  \"\"\"\n  if not examples:\n    raise ValueError(\n        \"Examples are required for reliable extraction. Please provide at least\"\n        \" one ExampleData object with sample extractions.\"\n    )\n\n  if prompt_validation_level is not pv.PromptValidationLevel.OFF:\n    report = pv.validate_prompt_alignment(\n        examples=examples,\n        aligner=resolver.WordAligner(),\n        policy=pv.AlignmentPolicy(),\n        tokenizer=tokenizer,\n    )\n    pv.handle_alignment_report(\n        report,\n        level=prompt_validation_level,\n        strict_non_exact=prompt_validation_strict,\n    )\n\n  if debug:\n    # pylint: disable=import-outside-toplevel\n    from langextract.core import debug_utils\n\n    debug_utils.configure_debug_logging()\n\n  if format_type is None:\n    format_type = data.FormatType.JSON\n\n  if max_workers is not None and batch_length < max_workers:\n    warnings.warn(\n        f\"batch_length ({batch_length}) < max_workers ({max_workers}). \"\n        f\"Only {batch_length} workers will be used. 
\"\n        \"Set batch_length >= max_workers for optimal parallelization.\",\n        UserWarning,\n    )\n\n  if (\n      fetch_urls\n      and isinstance(text_or_documents, str)\n      and io.is_url(text_or_documents)\n  ):\n    text_or_documents = io.download_text_from_url(text_or_documents)\n\n  prompt_template = prompting.PromptTemplateStructured(\n      description=prompt_description\n  )\n  prompt_template.examples.extend(examples)\n\n  language_model: base_model.BaseLanguageModel | None = None\n\n  if model:\n    language_model = model\n    if fence_output is not None:\n      language_model.set_fence_output(fence_output)\n    if use_schema_constraints:\n      warnings.warn(\n          \"'use_schema_constraints' is ignored when 'model' is provided. \"\n          \"The model should already be configured with schema constraints.\",\n          UserWarning,\n          stacklevel=2,\n      )\n  elif config:\n    if use_schema_constraints:\n      warnings.warn(\n          \"With 'config', schema constraints are still applied via examples. \"\n          \"Or pass explicit schema in config.provider_kwargs.\",\n          UserWarning,\n          stacklevel=2,\n      )\n\n    language_model = factory.create_model(\n        config=config,\n        examples=prompt_template.examples if use_schema_constraints else None,\n        use_schema_constraints=use_schema_constraints,\n        fence_output=fence_output,\n    )\n  else:\n    if language_model_type is not None:\n      warnings.warn(\n          \"'language_model_type' is deprecated and will be removed in v2.0.0. 
\"\n          \"Use model, config, or model_id parameters instead.\",\n          FutureWarning,\n          stacklevel=2,\n      )\n\n    base_lm_kwargs: dict[str, typing.Any] = {\n        \"api_key\": api_key,\n        \"format_type\": format_type,\n        \"temperature\": temperature,\n        \"model_url\": model_url,\n        \"base_url\": model_url,\n        \"max_workers\": max_workers,\n    }\n\n    # TODO(v2.0.0): Remove gemini_schema parameter\n    if \"gemini_schema\" in (language_model_params or {}):\n      warnings.warn(\n          \"'gemini_schema' is deprecated. Schema constraints are now \"\n          \"automatically handled. This parameter will be ignored.\",\n          FutureWarning,\n          stacklevel=2,\n      )\n      language_model_params = dict(language_model_params or {})\n      language_model_params.pop(\"gemini_schema\", None)\n\n    base_lm_kwargs.update(language_model_params or {})\n    filtered_kwargs = {k: v for k, v in base_lm_kwargs.items() if v is not None}\n\n    config = factory.ModelConfig(\n        model_id=model_id, provider_kwargs=filtered_kwargs\n    )\n\n    language_model = factory.create_model(\n        config=config,\n        examples=prompt_template.examples if use_schema_constraints else None,\n        use_schema_constraints=use_schema_constraints,\n        fence_output=fence_output,\n    )\n\n  format_handler, remaining_params = fh.FormatHandler.from_resolver_params(\n      resolver_params=resolver_params,\n      base_format_type=format_type,\n      base_use_fences=language_model.requires_fence_output,\n      base_attribute_suffix=data.ATTRIBUTE_SUFFIX,\n      base_use_wrapper=True,\n      base_wrapper_key=data.EXTRACTIONS_KEY,\n  )\n\n  if language_model.schema is not None:\n    language_model.schema.validate_format(format_handler)\n\n  # Pull alignment settings from normalized params\n  alignment_kwargs = {}\n  for key in resolver.ALIGNMENT_PARAM_KEYS:\n    val = remaining_params.pop(key, None)\n    if val is not 
None:\n      alignment_kwargs[key] = val\n\n  effective_params = {\"format_handler\": format_handler, **remaining_params}\n\n  try:\n    res = resolver.Resolver(**effective_params)\n  except TypeError as e:\n    msg = str(e)\n    if (\n        \"unexpected keyword argument\" in msg\n        or \"got an unexpected keyword argument\" in msg\n    ):\n      raise TypeError(\n          f\"Unknown key in resolver_params; check spelling: {e}\"\n      ) from e\n    raise\n\n  annotator = annotation.Annotator(\n      language_model=language_model,\n      prompt_template=prompt_template,\n      format_handler=format_handler,\n  )\n\n  if isinstance(text_or_documents, str):\n    result = annotator.annotate_text(\n        text=text_or_documents,\n        resolver=res,\n        max_char_buffer=max_char_buffer,\n        batch_length=batch_length,\n        additional_context=additional_context,\n        debug=debug,\n        extraction_passes=extraction_passes,\n        context_window_chars=context_window_chars,\n        show_progress=show_progress,\n        max_workers=max_workers,\n        tokenizer=tokenizer,\n        **alignment_kwargs,\n    )\n    return result\n  else:\n    documents = cast(Iterable[data.Document], text_or_documents)\n    result = annotator.annotate_documents(\n        documents=documents,\n        resolver=res,\n        max_char_buffer=max_char_buffer,\n        batch_length=batch_length,\n        debug=debug,\n        extraction_passes=extraction_passes,\n        context_window_chars=context_window_chars,\n        show_progress=show_progress,\n        max_workers=max_workers,\n        tokenizer=tokenizer,\n        **alignment_kwargs,\n    )\n    return list(result)\n"
  },
  {
    "path": "langextract/factory.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Factory for creating language model instances.\n\nThis module provides a factory pattern for instantiating language models\nbased on configuration, with support for environment variable resolution\nand provider-specific defaults.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport dataclasses\nimport os\nimport typing\nimport warnings\n\nfrom langextract import providers\nfrom langextract.core import base_model\nfrom langextract.core import exceptions\nfrom langextract.providers import router\n\n\n@dataclasses.dataclass(slots=True, frozen=True)\nclass ModelConfig:\n  \"\"\"Configuration for instantiating a language model provider.\n\n  Attributes:\n    model_id: The model identifier (e.g., \"gemini-2.5-flash\", \"gpt-4o\").\n    provider: Optional explicit provider name or class name. 
Use this to\n      disambiguate when multiple providers support the same model_id.\n    provider_kwargs: Optional provider-specific keyword arguments.\n  \"\"\"\n\n  model_id: str | None = None\n  provider: str | None = None\n  provider_kwargs: dict[str, typing.Any] = dataclasses.field(\n      default_factory=dict\n  )\n\n\ndef _kwargs_with_environment_defaults(\n    model_id: str, kwargs: dict[str, typing.Any]\n) -> dict[str, typing.Any]:\n  \"\"\"Add environment-based defaults to provider kwargs.\n\n  Args:\n    model_id: The model identifier.\n    kwargs: Existing keyword arguments.\n\n  Returns:\n    Updated kwargs with environment defaults.\n  \"\"\"\n  resolved = dict(kwargs)\n\n  if \"api_key\" not in resolved and not resolved.get(\"vertexai\", False):\n    model_lower = model_id.lower()\n    env_vars_by_provider = {\n        \"gemini\": (\"GEMINI_API_KEY\", \"LANGEXTRACT_API_KEY\"),\n        \"gpt\": (\"OPENAI_API_KEY\", \"LANGEXTRACT_API_KEY\"),\n    }\n\n    for provider_prefix, env_vars in env_vars_by_provider.items():\n      if provider_prefix in model_lower:\n        found_keys = []\n        for env_var in env_vars:\n          key_val = os.getenv(env_var)\n          if key_val:\n            found_keys.append((env_var, key_val))\n\n        if found_keys:\n          resolved[\"api_key\"] = found_keys[0][1]\n\n          if len(found_keys) > 1:\n            keys_list = \", \".join(k[0] for k in found_keys)\n            warnings.warn(\n                f\"Multiple API keys detected in environment: {keys_list}. 
\"\n                f\"Using {found_keys[0][0]} and ignoring others.\",\n                UserWarning,\n                stacklevel=3,\n            )\n        break\n\n  if \"ollama\" in model_id.lower() and \"base_url\" not in resolved:\n    resolved[\"base_url\"] = os.getenv(\n        \"OLLAMA_BASE_URL\", \"http://localhost:11434\"\n    )\n\n  return resolved\n\n\ndef create_model(\n    config: ModelConfig,\n    examples: typing.Sequence[typing.Any] | None = None,\n    use_schema_constraints: bool = False,\n    fence_output: bool | None = None,\n    return_fence_output: bool = False,\n) -> base_model.BaseLanguageModel | tuple[base_model.BaseLanguageModel, bool]:\n  \"\"\"Create a language model instance from configuration.\n\n  Args:\n    config: Model configuration with optional model_id and/or provider.\n    examples: Optional examples for schema generation (if use_schema_constraints=True).\n    use_schema_constraints: Whether to apply schema constraints from examples.\n    fence_output: Explicit fence output preference. 
If None, computed from schema.\n    return_fence_output: If True, also return computed fence_output value.\n\n  Returns:\n    An instantiated language model provider.\n    If return_fence_output=True: Tuple of (model, model.requires_fence_output).\n\n  Raises:\n    ValueError: If neither model_id nor provider is specified.\n    ValueError: If no provider is registered for the model_id.\n    InferenceConfigError: If provider instantiation fails.\n  \"\"\"\n  if use_schema_constraints or fence_output is not None:\n    model = _create_model_with_schema(\n        config=config,\n        examples=examples,\n        use_schema_constraints=use_schema_constraints,\n        fence_output=fence_output,\n    )\n    if return_fence_output:\n      return model, model.requires_fence_output\n    return model\n\n  if not config.model_id and not config.provider:\n    raise ValueError(\"Either model_id or provider must be specified\")\n\n  providers.load_builtins_once()\n  providers.load_plugins_once()\n\n  try:\n    if config.provider:\n      provider_class = router.resolve_provider(config.provider)\n    else:\n      provider_class = router.resolve(config.model_id)\n  except (ModuleNotFoundError, ImportError) as e:\n    raise exceptions.InferenceConfigError(\n        \"Failed to load provider. \"\n        \"This may be due to missing dependencies. \"\n        f\"Check that all required packages are installed. 
Error: {e}\"\n    ) from e\n\n  model_id = config.model_id\n\n  kwargs = _kwargs_with_environment_defaults(\n      model_id or config.provider or \"\", config.provider_kwargs\n  )\n\n  if model_id:\n    kwargs[\"model_id\"] = model_id\n\n  try:\n    model = provider_class(**kwargs)\n    if return_fence_output:\n      return model, model.requires_fence_output\n    return model\n  except (ValueError, TypeError) as e:\n    raise exceptions.InferenceConfigError(\n        f\"Failed to create provider {provider_class.__name__}: {e}\"\n    ) from e\n\n\ndef create_model_from_id(\n    model_id: str | None = None,\n    provider: str | None = None,\n    **provider_kwargs: typing.Any,\n) -> base_model.BaseLanguageModel:\n  \"\"\"Convenience function to create a model.\n\n  Args:\n    model_id: The model identifier (e.g., \"gemini-2.5-flash\").\n    provider: Optional explicit provider name to disambiguate.\n    **provider_kwargs: Optional provider-specific keyword arguments.\n\n  Returns:\n    An instantiated language model provider.\n  \"\"\"\n  config = ModelConfig(\n      model_id=model_id, provider=provider, provider_kwargs=provider_kwargs\n  )\n  return create_model(config)\n\n\ndef _create_model_with_schema(\n    config: ModelConfig,\n    examples: typing.Sequence[typing.Any] | None = None,\n    use_schema_constraints: bool = True,\n    fence_output: bool | None = None,\n) -> base_model.BaseLanguageModel:\n  \"\"\"Internal helper to create a model with optional schema constraints.\n\n  This function creates a language model and optionally configures it with\n  schema constraints derived from the provided examples. 
It also computes\n  appropriate fence defaulting based on the schema's capabilities.\n\n  Args:\n    config: Model configuration with model_id and/or provider.\n    examples: Optional sequence of ExampleData for schema generation.\n    use_schema_constraints: Whether to generate and apply schema constraints.\n    fence_output: Whether to wrap output in markdown fences. If None,\n      will be computed based on schema's requires_raw_output.\n\n  Returns:\n    A model instance with fence_output configured appropriately.\n  \"\"\"\n\n  if config.provider:\n    provider_class = router.resolve_provider(config.provider)\n  else:\n    providers.load_builtins_once()\n    providers.load_plugins_once()\n    provider_class = router.resolve(config.model_id)\n\n  schema_instance = None\n  if use_schema_constraints and examples:\n    schema_class = provider_class.get_schema_class()\n    if schema_class is not None:\n      schema_instance = schema_class.from_examples(examples)\n\n  if schema_instance:\n    kwargs = schema_instance.to_provider_config()\n    kwargs.update(config.provider_kwargs)\n  else:\n    kwargs = dict(config.provider_kwargs)\n\n  if schema_instance:\n    schema_instance.sync_with_provider_kwargs(kwargs)\n\n  # Add environment defaults\n  model_id = config.model_id\n  kwargs = _kwargs_with_environment_defaults(\n      model_id or config.provider or \"\", kwargs\n  )\n\n  if model_id:\n    kwargs[\"model_id\"] = model_id\n\n  try:\n    model = provider_class(**kwargs)\n  except (ValueError, TypeError) as e:\n    raise exceptions.InferenceConfigError(\n        f\"Failed to create provider {provider_class.__name__}: {e}\"\n    ) from e\n\n  model.apply_schema(schema_instance)\n  model.set_fence_output(fence_output)\n\n  return model\n"
  },
  {
    "path": "langextract/inference.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Language model inference compatibility layer.\n\nThis module provides backward compatibility for the inference module.\nNew code should import from langextract.core.base_model instead.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom langextract._compat import inference\n\n\ndef __getattr__(name: str):\n  \"\"\"Forward to _compat.inference for backward compatibility.\"\"\"\n  # Handle InferenceType specially since it's defined in _compat\n  if name == \"InferenceType\":\n    return inference.InferenceType\n\n  return inference.__getattr__(name)\n"
  },
  {
    "path": "langextract/io.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Supports Input and Output Operations for Data Annotations.\"\"\"\nfrom __future__ import annotations\n\nimport abc\nimport dataclasses\nimport ipaddress\nimport json\nimport os\nimport pathlib\nfrom typing import Any, Iterator\nfrom urllib import parse as urlparse\n\nimport pandas as pd\nimport requests\n\nfrom langextract import data_lib\nfrom langextract import progress\nfrom langextract.core import data\nfrom langextract.core import exceptions\n\nDEFAULT_TIMEOUT_SECONDS = 30\n\n\nclass InvalidDatasetError(exceptions.LangExtractError):\n  \"\"\"Error raised when Dataset is empty or invalid.\"\"\"\n\n\n@dataclasses.dataclass(frozen=True)\nclass Dataset(abc.ABC):\n  \"\"\"A dataset for inputs to LLM Labeler.\"\"\"\n\n  input_path: pathlib.Path\n  id_key: str\n  text_key: str\n\n  def load(self, delimiter: str = ',') -> Iterator[data.Document]:\n    \"\"\"Loads the dataset from a CSV file.\n\n    Args:\n      delimiter: The delimiter to use when reading the CSV file.\n\n    Yields:\n      A Document for each row in the dataset.\n\n    Raises:\n      IOError: If the file does not exist.\n      InvalidDatasetError: If the dataset is empty or invalid.\n      NotImplementedError: If the file type is not supported.\n    \"\"\"\n    if not os.path.exists(self.input_path):\n      raise IOError(f'File does not exist: {self.input_path}')\n\n    if 
str(self.input_path).endswith('.csv'):\n      try:\n        csv_data = _read_csv(\n            self.input_path,\n            column_names=[self.text_key, self.id_key],\n            delimiter=delimiter,\n        )\n      except InvalidDatasetError as e:\n        raise InvalidDatasetError(f'Empty dataset: {self.input_path}') from e\n      for row in csv_data:\n        yield data.Document(\n            text=row[self.text_key],\n            document_id=row[self.id_key],\n        )\n    else:\n      raise NotImplementedError(f'Unsupported file type: {self.input_path}')\n\n\ndef save_annotated_documents(\n    annotated_documents: Iterator[data.AnnotatedDocument],\n    output_dir: pathlib.Path | str | None = None,\n    output_name: str = 'data.jsonl',\n    show_progress: bool = True,\n) -> None:\n  \"\"\"Saves annotated documents to a JSON Lines file.\n\n  Args:\n    annotated_documents: Iterator over AnnotatedDocument objects to save.\n    output_dir: The directory to which the JSONL file should be written.\n      Can be a Path object or a string. 
Defaults to 'test_output/' if None.\n    output_name: File name for the JSONL file.\n    show_progress: Whether to show a progress bar during saving.\n\n  Raises:\n    IOError: If the output directory cannot be created.\n    InvalidDatasetError: If no documents are produced.\n  \"\"\"\n  if output_dir is None:\n    output_dir = pathlib.Path('test_output')\n  else:\n    output_dir = pathlib.Path(output_dir)\n\n  output_dir.mkdir(parents=True, exist_ok=True)\n\n  output_file = output_dir / output_name\n  has_data = False\n  doc_count = 0\n\n  # Create progress bar\n  progress_bar = progress.create_save_progress_bar(\n      output_path=str(output_file), disable=not show_progress\n  )\n\n  with open(output_file, 'w', encoding='utf-8') as f:\n    for adoc in annotated_documents:\n      if not adoc.document_id:\n        continue\n\n      doc_dict = data_lib.annotated_document_to_dict(adoc)\n      f.write(json.dumps(doc_dict, ensure_ascii=False) + '\\n')\n      has_data = True\n      doc_count += 1\n      progress_bar.update(1)\n\n  progress_bar.close()\n\n  if not has_data:\n    raise InvalidDatasetError(f'No documents to save in: {output_file}')\n\n  if show_progress:\n    progress.print_save_complete(doc_count, str(output_file))\n\n\ndef load_annotated_documents_jsonl(\n    jsonl_path: pathlib.Path,\n    show_progress: bool = True,\n) -> Iterator[data.AnnotatedDocument]:\n  \"\"\"Loads annotated documents from a JSON Lines file.\n\n  Args:\n    jsonl_path: The file path to the JSON Lines file.\n    show_progress: Whether to show a progress bar during loading.\n\n  Yields:\n    AnnotatedDocument objects.\n\n  Raises:\n    IOError: If the file does not exist or is invalid.\n  \"\"\"\n  if not os.path.exists(jsonl_path):\n    raise IOError(f'File does not exist: {jsonl_path}')\n\n  # Get file size for progress bar\n  file_size = os.path.getsize(jsonl_path)\n\n  # Create progress bar\n  progress_bar = progress.create_load_progress_bar(\n      file_path=str(jsonl_path),\n   
   total_size=file_size if show_progress else None,\n      disable=not show_progress,\n  )\n\n  doc_count = 0\n  bytes_read = 0\n\n  with open(jsonl_path, 'r', encoding='utf-8') as f:\n    for line in f:\n      line_bytes = len(line.encode('utf-8'))\n      bytes_read += line_bytes\n      progress_bar.update(line_bytes)\n\n      line = line.strip()\n      if not line:\n        continue\n      doc_dict = json.loads(line)\n      doc_count += 1\n      yield data_lib.dict_to_annotated_document(doc_dict)\n\n  progress_bar.close()\n\n  if show_progress:\n    progress.print_load_complete(doc_count, str(jsonl_path))\n\n\ndef _read_csv(\n    filepath: pathlib.Path, column_names: list[str], delimiter: str = ','\n) -> Iterator[dict[str, Any]]:\n  \"\"\"Reads a CSV file and yields rows as dicts.\n\n  Args:\n    filepath: The path to the file.\n    column_names: The names of the columns to read.\n    delimiter: The delimiter to use when reading the CSV file.\n\n  Yields:\n    An iterator of dicts representing each row.\n\n  Raises:\n    IOError: If the file does not exist.\n    InvalidDatasetError: If the dataset is empty or invalid.\n  \"\"\"\n  if not os.path.exists(filepath):\n    raise IOError(f'File does not exist: {filepath}')\n\n  try:\n    with open(filepath, 'r', encoding='utf-8') as f:\n      df = pd.read_csv(f, usecols=column_names, dtype=str, delimiter=delimiter)\n      for _, row in df.iterrows():\n        yield row.to_dict()\n  except pd.errors.EmptyDataError as e:\n    raise InvalidDatasetError(f'Empty dataset: {filepath}') from e\n  except ValueError as e:\n    raise InvalidDatasetError(f'Invalid dataset file: {filepath}') from e\n\n\ndef is_url(text: str) -> bool:\n  \"\"\"Check if the given text is a valid URL.\n\n  Uses urllib.parse to validate that the text is a properly formed URL\n  with http or https scheme and a valid network location.\n\n  Args:\n    text: The string to check.\n\n  Returns:\n    True if the text is a valid URL with http(s) scheme, False 
otherwise.\n  \"\"\"\n  if not text or not isinstance(text, str):\n    return False\n\n  text = text.strip()\n\n  # Reject text with whitespace (not a pure URL)\n  if ' ' in text or '\\n' in text or '\\t' in text:\n    return False\n\n  try:\n    result = urlparse.urlparse(text)\n    hostname = result.hostname\n\n    # Must have valid scheme, netloc, and hostname\n    if not (result.scheme in ('http', 'https') and result.netloc and hostname):\n      return False\n\n    # Accept IPs, localhost, or domains with dots\n    try:\n      ipaddress.ip_address(hostname)\n      return True\n    except ValueError:\n      return hostname == 'localhost' or '.' in hostname\n  except (ValueError, AttributeError):\n    return False\n\n\ndef download_text_from_url(\n    url: str,\n    timeout: int = DEFAULT_TIMEOUT_SECONDS,\n    show_progress: bool = True,\n    chunk_size: int = 8192,\n) -> str:\n  \"\"\"Download text content from a URL with optional progress bar.\n\n  Args:\n    url: The URL to download from.\n    timeout: Request timeout in seconds.\n    show_progress: Whether to show a progress bar during download.\n    chunk_size: Size of chunks to download at a time.\n\n  Returns:\n    The text content of the URL.\n\n  Raises:\n    requests.RequestException: If the download fails.\n    ValueError: If the content is not text-based.\n  \"\"\"\n  try:\n    # Make initial request to get headers\n    response = requests.get(url, stream=True, timeout=timeout)\n    response.raise_for_status()\n\n    # Check content type\n    content_type = response.headers.get('Content-Type', '').lower()\n    if not any(\n        ct in content_type\n        for ct in ['text/', 'application/json', 'application/xml']\n    ):\n      # Try to proceed anyway, but warn\n      print(f\"Warning: Content-Type '{content_type}' may not be text-based\")\n\n    # Get content length for progress bar\n    total_size = int(response.headers.get('Content-Length', 0))\n\n    filename = url.split('/')[-1][:50]\n\n    # 
Download content with progress bar\n    chunks = []\n    if show_progress and total_size > 0:\n      progress_bar = progress.create_download_progress_bar(\n          total_size=total_size, url=url\n      )\n\n      for chunk in response.iter_content(chunk_size=chunk_size):\n        if chunk:\n          chunks.append(chunk)\n          progress_bar.update(len(chunk))\n\n      progress_bar.close()\n    else:\n      # Download without progress bar\n      for chunk in response.iter_content(chunk_size=chunk_size):\n        if chunk:\n          chunks.append(chunk)\n\n    # Combine chunks and decode\n    content = b''.join(chunks)\n\n    # Try to decode as text\n    encodings = ['utf-8', 'latin-1', 'ascii', 'utf-16']\n    text_content = None\n    for encoding in encodings:\n      try:\n        text_content = content.decode(encoding)\n        break\n      except UnicodeDecodeError:\n        continue\n\n    if text_content is None:\n      raise ValueError(f'Could not decode content from {url} as text')\n\n    # Show content summary with clean formatting\n    if show_progress:\n      char_count = len(text_content)\n      word_count = len(text_content.split())\n      progress.print_download_complete(char_count, word_count, filename)\n\n    return text_content\n\n  except requests.RequestException as e:\n    raise requests.RequestException(\n        f'Failed to download from {url}: {str(e)}'\n    ) from e\n"
  },
  {
    "path": "langextract/plugins.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Provider discovery and registration system.\n\nThis module provides centralized provider discovery without circular imports.\nIt supports both built-in providers and third-party providers via entry points.\n\"\"\"\nfrom __future__ import annotations\n\nimport functools\nimport importlib\nfrom importlib import metadata\n\nfrom absl import logging\n\nfrom langextract.core import base_model\n\n__all__ = [\"available_providers\", \"get_provider_class\"]\n\n# Static mapping for built-in providers (always available)\n_BUILTINS: dict[str, str] = {\n    \"gemini\": \"langextract.providers.gemini:GeminiLanguageModel\",\n    \"ollama\": \"langextract.providers.ollama:OllamaLanguageModel\",\n}\n\n# Optional built-in providers (require extra dependencies)\n_OPTIONAL_BUILTINS: dict[str, str] = {\n    \"openai\": \"langextract.providers.openai:OpenAILanguageModel\",\n}\n\n\ndef _safe_entry_points(group: str) -> list:\n  \"\"\"Get entry points with Python 3.8-3.12 compatibility.\n\n  Args:\n    group: Entry point group name.\n\n  Returns:\n    List of entry points in the specified group.\n  \"\"\"\n  eps = metadata.entry_points()\n  try:\n    # Python 3.10+\n    return list(eps.select(group=group))\n  except AttributeError:\n    # Python 3.8-3.9\n    return list(getattr(eps, \"get\")(group, []))\n\n\n@functools.lru_cache(maxsize=1)\ndef _discovered() -> dict[str, str]:\n  
\"\"\"Cache discovered third-party providers.\n\n  Returns:\n    Dictionary mapping provider names to import specs.\n  \"\"\"\n  discovered: dict[str, str] = {}\n  for ep in _safe_entry_points(\"langextract.providers\"):\n    # Handle both old and new entry_points API\n    if hasattr(ep, \"value\"):\n\n      discovered.setdefault(ep.name, ep.value)\n    else:\n      # Legacy API - construct from module and attr\n      value = f\"{ep.module}:{ep.attr}\" if ep.attr else ep.module\n      discovered.setdefault(ep.name, value)\n\n  if discovered:\n    logging.debug(\n        \"Discovered third-party providers: %s\", list(discovered.keys())\n    )\n\n  return discovered\n\n\ndef available_providers(\n    allow_override: bool = False, include_optional: bool = True\n) -> dict[str, str]:\n  \"\"\"Get all available providers (built-in + optional + third-party).\n\n  Args:\n    allow_override: If True, third-party providers can override built-ins.\n                   If False (default), built-ins take precedence.\n    include_optional: If True (default), include optional built-in providers\n                     that may require extra dependencies.\n\n  Returns:\n    Dictionary mapping provider names to import specifications.\n  \"\"\"\n\n  providers = dict(_discovered())\n\n  if include_optional:\n    if allow_override:\n      # Third-party can override optional built-ins\n      providers.update(_OPTIONAL_BUILTINS)\n    else:\n      # Optional built-ins override third-party\n      providers = {**providers, **_OPTIONAL_BUILTINS}\n\n  # Always add core built-ins with highest precedence (unless allow_override)\n  if allow_override:\n    # Third-party and optional can override core built-ins\n    providers.update(_BUILTINS)\n  else:\n    # Core built-ins take precedence over everything\n    providers = {**providers, **_BUILTINS}\n\n  return providers\n\n\ndef _load_class(spec: str) -> type[base_model.BaseLanguageModel]:\n  \"\"\"Load a provider class from module:Class 
specification.\n\n  Args:\n    spec: Import specification in format \"module.path:ClassName\".\n\n  Returns:\n    The loaded provider class.\n\n  Raises:\n    ImportError: If the spec is invalid or module cannot be imported.\n    TypeError: If the loaded class is not a BaseLanguageModel.\n  \"\"\"\n  module_path, _, class_name = spec.partition(\":\")\n  if not module_path or not class_name:\n    raise ImportError(\n        f\"Invalid provider spec '{spec}' - expected 'module:Class'\"\n    )\n\n  try:\n    module = importlib.import_module(module_path)\n  except ImportError as e:\n    raise ImportError(\n        f\"Failed to import provider module '{module_path}': {e}\"\n    ) from e\n\n  try:\n    cls = getattr(module, class_name)\n  except AttributeError as e:\n    raise ImportError(\n        f\"Provider class '{class_name}' not found in module '{module_path}'\"\n    ) from e\n\n  # Validate it's a language model\n  if not isinstance(cls, type) or not issubclass(\n      cls, base_model.BaseLanguageModel\n  ):\n    # Fallback: check structural compatibility for non-ABC classes\n    missing = []\n    for method in (\"infer\", \"parse_output\"):\n      if not hasattr(cls, method):\n        missing.append(method)\n\n    if missing:\n      raise TypeError(\n          f\"{cls} is not a BaseLanguageModel and missing required methods:\"\n          f\" {missing}\"\n      )\n\n    logging.warning(\n        \"Provider %s does not inherit from BaseLanguageModel but appears\"\n        \" compatible\",\n        cls,\n    )\n\n  return cls\n\n\n@functools.lru_cache(maxsize=None)  # Cache all loaded classes\ndef get_provider_class(\n    name: str, allow_override: bool = False, include_optional: bool = True\n) -> type[base_model.BaseLanguageModel]:\n  \"\"\"Get a provider class by name.\n\n  Args:\n    name: Provider name (e.g., \"gemini\", \"openai\", \"ollama\").\n    allow_override: If True, allow third-party providers to override built-ins.\n    include_optional: If True 
(default), include optional providers that\n                     may require extra dependencies.\n\n  Returns:\n    The provider class.\n\n  Raises:\n    KeyError: If the provider name is not found.\n    ImportError: If the provider module cannot be imported (including\n                missing optional dependencies).\n    TypeError: If the provider class is not compatible.\n  \"\"\"\n  providers = available_providers(allow_override, include_optional)\n\n  if name not in providers:\n    available = sorted(providers.keys())\n    raise KeyError(\n        f\"Unknown provider '{name}'. Available providers:\"\n        f\" {', '.join(available) if available else 'none'}.\\nHint: Did you\"\n        \" install the necessary extras (e.g., pip install\"\n        f\" langextract[{name}])?\"\n    )\n\n  return _load_class(providers[name])\n"
  },
  {
    "path": "langextract/progress.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Progress and visualization utilities for LangExtract.\"\"\"\nfrom __future__ import annotations\n\nfrom typing import Any\nimport urllib.parse\n\nimport tqdm\n\n# ANSI color codes for terminal output\nBLUE = \"\\033[94m\"\nGREEN = \"\\033[92m\"\nCYAN = \"\\033[96m\"\nBOLD = \"\\033[1m\"\nRESET = \"\\033[0m\"\n\n# Google Blue color for progress bars\nGOOGLE_BLUE = \"#4285F4\"\n\n\ndef create_download_progress_bar(\n    total_size: int, url: str, ncols: int = 100, max_url_length: int = 50\n) -> tqdm.tqdm:\n  \"\"\"Create a styled progress bar for downloads.\n\n  Args:\n    total_size: Total size in bytes.\n    url: The URL being downloaded.\n    ncols: Number of columns for the progress bar.\n    max_url_length: Maximum length to show for the URL.\n\n  Returns:\n    A configured tqdm progress bar.\n  \"\"\"\n  # Truncate URL if too long, keeping the domain and end\n  if len(url) > max_url_length:\n    parsed = urllib.parse.urlparse(url)\n    domain = parsed.netloc or parsed.hostname or \"unknown\"\n\n    path_parts = parsed.path.strip(\"/\").split(\"/\")\n    filename = path_parts[-1] if path_parts and path_parts[-1] else \"file\"\n\n    available = max_url_length - len(domain) - len(filename) - 5\n    if available > 0:\n      url_display = f\"{domain}/.../{filename}\"\n    else:\n      url_display = url[: max_url_length - 3] + \"...\"\n  else:\n    url_display = 
url\n\n  return tqdm.tqdm(\n      total=total_size,\n      unit=\"B\",\n      unit_scale=True,\n      desc=(\n          f\"{BLUE}{BOLD}LangExtract{RESET}: Downloading\"\n          f\" {GREEN}{url_display}{RESET}\"\n      ),\n      bar_format=(\n          \"{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}\"\n          \" [{elapsed}<{remaining}, {rate_fmt}]\"\n      ),\n      colour=GOOGLE_BLUE,\n      ncols=ncols,\n  )\n\n\ndef create_extraction_progress_bar(\n    iterable: Any, model_info: str | None = None, disable: bool = False\n) -> tqdm.tqdm:\n  \"\"\"Create a styled progress bar for extraction.\n\n  Args:\n    iterable: The iterable to wrap with progress bar.\n    model_info: Optional model information to display (e.g., \"gemini-1.5-pro\").\n    disable: Whether to disable the progress bar.\n\n  Returns:\n    A configured tqdm progress bar.\n  \"\"\"\n  desc = format_extraction_progress(model_info)\n\n  return tqdm.tqdm(\n      iterable,\n      desc=desc,\n      bar_format=\"{desc} [{elapsed}]\",\n      disable=disable,\n      dynamic_ncols=True,\n  )\n\n\ndef print_download_complete(\n    char_count: int, word_count: int, filename: str\n) -> None:\n  \"\"\"Print a styled download completion message.\n\n  Args:\n    char_count: Number of characters downloaded.\n    word_count: Number of words downloaded.\n    filename: Name of the downloaded file.\n  \"\"\"\n  print(\n      f\"{GREEN}✓{RESET} Downloaded {BOLD}{char_count:,}{RESET} characters \"\n      f\"({BOLD}{word_count:,}{RESET} words) from {BLUE}{filename}{RESET}\",\n      flush=True,\n  )\n\n\ndef print_extraction_complete() -> None:\n  \"\"\"Print a generic extraction completion message.\"\"\"\n  print(f\"{GREEN}✓{RESET} Extraction processing complete\", flush=True)\n\n\ndef print_extraction_summary(\n    num_extractions: int,\n    unique_classes: int,\n    elapsed_time: float | None = None,\n    chars_processed: int | None = None,\n    num_chunks: int | None = None,\n) -> None:\n  \"\"\"Print a 
styled extraction summary with optional performance metrics.\n\n  Args:\n    num_extractions: Total number of extractions.\n    unique_classes: Number of unique extraction classes.\n    elapsed_time: Optional elapsed time in seconds.\n    chars_processed: Optional number of characters processed.\n    num_chunks: Optional number of chunks processed.\n  \"\"\"\n  print(\n      f\"{GREEN}✓{RESET} Extracted {BOLD}{num_extractions}{RESET} entities \"\n      f\"({BOLD}{unique_classes}{RESET} unique types)\",\n      flush=True,\n  )\n\n  if elapsed_time is not None:\n    metrics = []\n\n    # Time\n    metrics.append(f\"Time: {BOLD}{elapsed_time:.2f}s{RESET}\")\n\n    # Speed\n    if chars_processed is not None and elapsed_time > 0:\n      speed = chars_processed / elapsed_time\n      metrics.append(f\"Speed: {BOLD}{speed:,.0f}{RESET} chars/sec\")\n\n    if num_chunks is not None:\n      metrics.append(f\"Chunks: {BOLD}{num_chunks}{RESET}\")\n\n    for metric in metrics:\n      print(f\"  {CYAN}•{RESET} {metric}\", flush=True)\n\n\ndef create_save_progress_bar(\n    output_path: str, disable: bool = False\n) -> tqdm.tqdm:\n  \"\"\"Create a progress bar for saving documents.\n\n  Args:\n    output_path: The output file path.\n    disable: Whether to disable the progress bar.\n\n  Returns:\n    A configured tqdm progress bar.\n  \"\"\"\n  filename = output_path.split(\"/\")[-1]\n  return tqdm.tqdm(\n      desc=(\n          f\"{BLUE}{BOLD}LangExtract{RESET}: Saving to {GREEN}{filename}{RESET}\"\n      ),\n      unit=\" docs\",\n      disable=disable,\n  )\n\n\ndef create_load_progress_bar(\n    file_path: str, total_size: int | None = None, disable: bool = False\n) -> tqdm.tqdm:\n  \"\"\"Create a progress bar for loading documents.\n\n  Args:\n    file_path: The file path being loaded.\n    total_size: Optional total file size in bytes.\n    disable: Whether to disable the progress bar.\n\n  Returns:\n    A configured tqdm progress bar.\n  \"\"\"\n  filename = 
file_path.split(\"/\")[-1]\n  if total_size:\n    return tqdm.tqdm(\n        total=total_size,\n        desc=(\n            f\"{BLUE}{BOLD}LangExtract{RESET}: Loading {GREEN}{filename}{RESET}\"\n        ),\n        unit=\"B\",\n        unit_scale=True,\n        disable=disable,\n    )\n  else:\n    return tqdm.tqdm(\n        desc=(\n            f\"{BLUE}{BOLD}LangExtract{RESET}: Loading {GREEN}{filename}{RESET}\"\n        ),\n        unit=\" docs\",\n        disable=disable,\n    )\n\n\ndef print_save_complete(num_docs: int, file_path: str) -> None:\n  \"\"\"Print a save completion message.\n\n  Args:\n    num_docs: Number of documents saved.\n    file_path: Path to the saved file.\n  \"\"\"\n  filename = file_path.split(\"/\")[-1]\n  print(\n      f\"{GREEN}✓{RESET} Saved {BOLD}{num_docs}{RESET} documents to\"\n      f\" {GREEN}{filename}{RESET}\",\n      flush=True,\n  )\n\n\ndef print_load_complete(num_docs: int, file_path: str) -> None:\n  \"\"\"Print a load completion message.\n\n  Args:\n    num_docs: Number of documents loaded.\n    file_path: Path to the loaded file.\n  \"\"\"\n  filename = file_path.split(\"/\")[-1]\n  print(\n      f\"{GREEN}✓{RESET} Loaded {BOLD}{num_docs}{RESET} documents from\"\n      f\" {GREEN}{filename}{RESET}\",\n      flush=True,\n  )\n\n\ndef get_model_info(language_model: Any) -> str | None:\n  \"\"\"Extract model information from a language model instance.\n\n  Args:\n    language_model: A language model instance.\n\n  Returns:\n    A string describing the model, or None if not available.\n  \"\"\"\n  if hasattr(language_model, \"model_id\"):\n    return language_model.model_id\n\n  if hasattr(language_model, \"model_url\"):\n    return language_model.model_url\n\n  return None\n\n\ndef format_extraction_stats(current_chars: int, processed_chars: int) -> str:\n  \"\"\"Format extraction progress statistics with colors.\n\n  Args:\n    current_chars: Number of characters in current batch.\n    processed_chars: Total number of 
characters processed so far.\n\n  Returns:\n    Formatted string with colored statistics.\n  \"\"\"\n  current_str = f\"{GREEN}{current_chars:,}{RESET}\"\n  processed_str = f\"{GREEN}{processed_chars:,}{RESET}\"\n  return f\"current={current_str} chars, processed={processed_str} chars\"\n\n\ndef create_extraction_postfix(current_chars: int, processed_chars: int) -> str:\n  \"\"\"Create a formatted postfix string for extraction progress.\n\n  Args:\n    current_chars: Number of characters in current batch.\n    processed_chars: Total number of characters processed so far.\n\n  Returns:\n    Formatted string with statistics.\n  \"\"\"\n  current_str = f\"{GREEN}{current_chars:,}{RESET}\"\n  processed_str = f\"{GREEN}{processed_chars:,}{RESET}\"\n  return f\"current={current_str} chars, processed={processed_str} chars\"\n\n\ndef format_extraction_progress(\n    model_info: str | None,\n    current_chars: int | None = None,\n    processed_chars: int | None = None,\n) -> str:\n  \"\"\"Format the complete extraction progress bar description.\n\n  Args:\n    model_info: Optional model information (e.g., \"gemini-2.0-flash\").\n    current_chars: Number of characters in current batch (optional).\n    processed_chars: Total number of characters processed so far (optional).\n\n  Returns:\n    Formatted description string.\n  \"\"\"\n  # Base description\n  if model_info:\n    desc = f\"{BLUE}{BOLD}LangExtract{RESET}: model={GREEN}{model_info}{RESET}\"\n  else:\n    desc = f\"{BLUE}{BOLD}LangExtract{RESET}: Processing\"\n\n  # Add stats if provided\n  if current_chars is not None and processed_chars is not None:\n    current_str = f\"{GREEN}{current_chars:,}{RESET}\"\n    processed_str = f\"{GREEN}{processed_chars:,}{RESET}\"\n    desc += f\", current={current_str} chars, processed={processed_str} chars\"\n\n  return desc\n\n\ndef create_pass_progress_bar(\n    total_passes: int, disable: bool = False\n) -> tqdm.tqdm:\n  \"\"\"Create a progress bar for sequential extraction 
passes.\n\n  Args:\n    total_passes: Total number of sequential passes.\n    disable: Whether to disable the progress bar.\n\n  Returns:\n    A configured tqdm progress bar.\n  \"\"\"\n  desc = f\"{BLUE}{BOLD}LangExtract{RESET}: Extraction passes\"\n  return tqdm.tqdm(\n      total=total_passes,\n      desc=desc,\n      bar_format=(\n          \"{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}]\"\n      ),\n      disable=disable,\n      colour=GOOGLE_BLUE,\n      ncols=100,\n  )\n"
  },
  {
    "path": "langextract/prompt_validation.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Prompt validation for alignment checks on few-shot examples.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nimport copy\nimport dataclasses\nimport enum\n\nfrom absl import logging\n\nfrom langextract import resolver\nfrom langextract.core import data\nfrom langextract.core import tokenizer as tokenizer_lib\n\n__all__ = [\n    \"PromptValidationLevel\",\n    \"ValidationIssue\",\n    \"ValidationReport\",\n    \"PromptAlignmentError\",\n    \"AlignmentPolicy\",\n    \"validate_prompt_alignment\",\n    \"handle_alignment_report\",\n]\n\n\n_FUZZY_ALIGNMENT_MIN_THRESHOLD = 0.75\n\n\nclass PromptValidationLevel(enum.Enum):\n  \"\"\"Validation levels for prompt alignment checks.\"\"\"\n\n  OFF = \"off\"\n  WARNING = \"warning\"\n  ERROR = \"error\"\n\n\nclass _IssueKind(enum.Enum):\n  \"\"\"Internal categorization of alignment issues.\"\"\"\n\n  FAILED = \"failed\"  # alignment_status is None\n  NON_EXACT = \"non_exact\"  # MATCH_FUZZY or MATCH_LESSER\n\n\n@dataclasses.dataclass(frozen=True)\nclass ValidationIssue:\n  \"\"\"Represents a single validation issue found during alignment.\"\"\"\n\n  example_index: int\n  example_id: str | None\n  extraction_class: str\n  extraction_text_preview: str\n  alignment_status: data.AlignmentStatus | None\n  issue_kind: _IssueKind\n  char_interval: tuple[int, int] | None = None\n  
token_interval: tuple[int, int] | None = None\n\n  def short_msg(self) -> str:\n    \"\"\"Returns a concise message describing the issue.\"\"\"\n    ex_id = f\" id={self.example_id}\" if self.example_id else \"\"\n    span = \"\"\n    if self.char_interval:\n      span = f\" char_span={self.char_interval}\"\n    return (\n        f\"[example#{self.example_index}{ex_id}] \"\n        f\"class='{self.extraction_class}' \"\n        f\"status={self.alignment_status} \"\n        f\"text='{self.extraction_text_preview}'{span}\"\n    )\n\n\n@dataclasses.dataclass\nclass ValidationReport:\n  \"\"\"Collection of validation issues from prompt alignment checks.\"\"\"\n\n  issues: list[ValidationIssue]\n\n  @property\n  def has_failed(self) -> bool:\n    \"\"\"Returns True if any extraction failed to align.\"\"\"\n    return any(i.issue_kind is _IssueKind.FAILED for i in self.issues)\n\n  @property\n  def has_non_exact(self) -> bool:\n    \"\"\"Returns True if any extraction has non-exact alignment.\"\"\"\n    return any(i.issue_kind is _IssueKind.NON_EXACT for i in self.issues)\n\n\nclass PromptAlignmentError(RuntimeError):\n  \"\"\"Raised when prompt alignment validation fails under ERROR mode.\"\"\"\n\n\n@dataclasses.dataclass(frozen=True)\nclass AlignmentPolicy:\n  \"\"\"Configuration for alignment validation behavior.\"\"\"\n\n  enable_fuzzy_alignment: bool = True\n  fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD\n  accept_match_lesser: bool = True\n\n\ndef _preview(s: str, n: int = 120) -> str:\n  \"\"\"Creates a preview of text for logging, collapsing whitespace.\"\"\"\n  s = \" \".join(s.split())  # Collapse whitespace for logs\n  return s if len(s) <= n else s[: n - 1] + \"…\"\n\n\ndef validate_prompt_alignment(\n    examples: Sequence[data.ExampleData],\n    aligner: resolver.WordAligner | None = None,\n    policy: AlignmentPolicy | None = None,\n    tokenizer: tokenizer_lib.Tokenizer | None = None,\n) -> ValidationReport:\n  \"\"\"Align extractions 
to their own example text and collect issues.\n\n  Args:\n    examples: The few-shot examples to validate.\n    aligner: WordAligner instance to use (creates new if None).\n    policy: Alignment configuration (uses defaults if None).\n    tokenizer: Optional tokenizer to use for alignment. If None, defaults to\n      RegexTokenizer.\n\n  Returns:\n    ValidationReport containing any alignment issues found.\n  \"\"\"\n  if not examples:\n    return ValidationReport(issues=[])\n\n  aligner = aligner or resolver.WordAligner()\n  policy = policy or AlignmentPolicy()\n\n  issues: list[ValidationIssue] = []\n\n  for idx, ex in enumerate(examples):\n    # Defensive copy so validation never mutates user examples.\n    copied_extractions = [[copy.deepcopy(e) for e in ex.extractions]]\n    aligned_groups = aligner.align_extractions(\n        extraction_groups=copied_extractions,\n        source_text=ex.text,\n        token_offset=0,\n        char_offset=0,\n        enable_fuzzy_alignment=policy.enable_fuzzy_alignment,\n        fuzzy_alignment_threshold=policy.fuzzy_alignment_threshold,\n        accept_match_lesser=policy.accept_match_lesser,\n        tokenizer_impl=tokenizer,\n    )\n\n    for aligned in aligned_groups[0]:\n      status = getattr(aligned, \"alignment_status\", None)\n      char_interval = getattr(aligned, \"char_interval\", None)\n      token_interval = getattr(aligned, \"token_interval\", None)\n      klass = getattr(aligned, \"extraction_class\", \"<unknown>\")\n      text = getattr(aligned, \"extraction_text\", \"\")\n\n      if status is None:\n        issues.append(\n            ValidationIssue(\n                example_index=idx,\n                example_id=getattr(ex, \"example_id\", None),\n                extraction_class=klass,\n                extraction_text_preview=_preview(text),\n                alignment_status=None,\n                issue_kind=_IssueKind.FAILED,\n                char_interval=None,\n                token_interval=None,\n     
       )\n        )\n      elif status in (\n          data.AlignmentStatus.MATCH_FUZZY,\n          data.AlignmentStatus.MATCH_LESSER,\n      ):\n        char_interval_tuple = None\n        token_interval_tuple = None\n        if char_interval:\n          char_interval_tuple = (char_interval.start_pos, char_interval.end_pos)\n        if token_interval:\n          token_interval_tuple = (\n              token_interval.start_index,\n              token_interval.end_index,\n          )\n\n        issues.append(\n            ValidationIssue(\n                example_index=idx,\n                example_id=getattr(ex, \"example_id\", None),\n                extraction_class=klass,\n                extraction_text_preview=_preview(text),\n                alignment_status=status,\n                issue_kind=_IssueKind.NON_EXACT,\n                char_interval=char_interval_tuple,\n                token_interval=token_interval_tuple,\n            )\n        )\n\n  return ValidationReport(issues=issues)\n\n\ndef handle_alignment_report(\n    report: ValidationReport,\n    level: PromptValidationLevel,\n    *,\n    strict_non_exact: bool = False,\n) -> None:\n  \"\"\"Log or raise based on validation level.\n\n  Args:\n    report: The validation report to handle.\n    level: The validation level determining behavior.\n    strict_non_exact: If True, treat non-exact matches as errors in ERROR mode.\n\n  Raises:\n    PromptAlignmentError: If validation fails in ERROR mode.\n  \"\"\"\n  if level is PromptValidationLevel.OFF:\n    return\n\n  for issue in report.issues:\n    if issue.issue_kind is _IssueKind.NON_EXACT:\n      logging.warning(\n          \"Prompt alignment: non-exact match: %s\", issue.short_msg()\n      )\n    else:\n      logging.warning(\n          \"Prompt alignment: FAILED to align: %s\", issue.short_msg()\n      )\n\n  if level is PromptValidationLevel.ERROR:\n    failed = [i for i in report.issues if i.issue_kind is _IssueKind.FAILED]\n    non_exact = [\n     
   i for i in report.issues if i.issue_kind is _IssueKind.NON_EXACT\n    ]\n\n    if failed:\n      sample = failed[0].short_msg()\n      raise PromptAlignmentError(\n          f\"Prompt alignment validation failed: {len(failed)} extraction(s) \"\n          f\"could not be aligned (e.g., {sample})\"\n      )\n    if strict_non_exact and non_exact:\n      sample = non_exact[0].short_msg()\n      raise PromptAlignmentError(\n          \"Prompt alignment validation failed under strict mode: \"\n          f\"{len(non_exact)} non-exact match(es) found (e.g., {sample})\"\n      )\n"
  },
  {
    "path": "langextract/prompting.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Library for building prompts.\"\"\"\nfrom __future__ import annotations\n\nimport dataclasses\nimport json\nimport pathlib\n\nimport pydantic\nfrom typing_extensions import override\nimport yaml\n\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import format_handler\n\n\nclass PromptBuilderError(exceptions.LangExtractError):\n  \"\"\"Failure to build prompt.\"\"\"\n\n\nclass ParseError(PromptBuilderError):\n  \"\"\"Prompt template cannot be parsed.\"\"\"\n\n\n@dataclasses.dataclass\nclass PromptTemplateStructured:\n  \"\"\"A structured prompt template for few-shot examples.\n\n  Attributes:\n    description: Instructions or guidelines for the LLM.\n    examples: ExampleData objects demonstrating expected input→output behavior.\n  \"\"\"\n\n  description: str\n  examples: list[data.ExampleData] = dataclasses.field(default_factory=list)\n\n\ndef read_prompt_template_structured_from_file(\n    prompt_path: str,\n    format_type: data.FormatType = data.FormatType.YAML,\n) -> PromptTemplateStructured:\n  \"\"\"Reads a structured prompt template from a file.\n\n  Args:\n    prompt_path: Path to a file containing PromptTemplateStructured data.\n    format_type: The format of the file; YAML or JSON.\n\n  Returns:\n    A PromptTemplateStructured object loaded from the file.\n\n  Raises:\n    ParseError: If the file 
cannot be parsed successfully.\n  \"\"\"\n  adapter = pydantic.TypeAdapter(PromptTemplateStructured)\n  try:\n    with pathlib.Path(prompt_path).open(\"rt\") as f:\n      data_dict = {}\n      prompt_content = f.read()\n      if format_type == data.FormatType.YAML:\n        data_dict = yaml.safe_load(prompt_content)\n      elif format_type == data.FormatType.JSON:\n        data_dict = json.loads(prompt_content)\n      return adapter.validate_python(data_dict)\n  except Exception as e:\n    raise ParseError(\n        f\"Failed to parse prompt template from file: {prompt_path}\"\n    ) from e\n\n\n@dataclasses.dataclass\nclass QAPromptGenerator:\n  \"\"\"Generates question-answer prompts from the provided template.\"\"\"\n\n  template: PromptTemplateStructured\n  format_handler: format_handler.FormatHandler\n  examples_heading: str = \"Examples\"\n  question_prefix: str = \"Q: \"\n  answer_prefix: str = \"A: \"\n\n  def __str__(self) -> str:\n    \"\"\"Returns a string representation of the prompt with an empty question.\"\"\"\n    return self.render(\"\")\n\n  def format_example_as_text(self, example: data.ExampleData) -> str:\n    \"\"\"Formats a single example for the prompt.\n\n    Args:\n      example: The example data to format.\n\n    Returns:\n      A string representation of the example, including the question and answer.\n    \"\"\"\n    question = example.text\n    answer = self.format_handler.format_extraction_example(example.extractions)\n\n    return \"\\n\".join([\n        f\"{self.question_prefix}{question}\",\n        f\"{self.answer_prefix}{answer}\\n\",\n    ])\n\n  def render(self, question: str, additional_context: str | None = None) -> str:\n    \"\"\"Generates a text representation of the prompt.\n\n    Args:\n      question: The question that will be presented to the model.\n      additional_context: Additional context to include in the prompt. 
An empty\n        string is ignored.\n\n    Returns:\n      Text prompt with a question to be presented to a language model.\n    \"\"\"\n    prompt_lines: list[str] = [f\"{self.template.description}\\n\"]\n\n    if additional_context:\n      prompt_lines.append(f\"{additional_context}\\n\")\n\n    if self.template.examples:\n      prompt_lines.append(self.examples_heading)\n      for ex in self.template.examples:\n        prompt_lines.append(self.format_example_as_text(ex))\n\n    prompt_lines.append(f\"{self.question_prefix}{question}\")\n    prompt_lines.append(self.answer_prefix)\n    return \"\\n\".join(prompt_lines)\n\n\nclass PromptBuilder:\n  \"\"\"Builds prompts for text chunks using a QAPromptGenerator.\n\n  This base class provides a simple interface for prompt generation. Subclasses\n  can extend this to add stateful behavior like cross-chunk context tracking.\n  \"\"\"\n\n  def __init__(self, generator: QAPromptGenerator):\n    \"\"\"Initializes the builder with the given prompt generator.\n\n    Args:\n      generator: The underlying prompt generator to use.\n    \"\"\"\n    self._generator = generator\n\n  def build_prompt(\n      self,\n      chunk_text: str,\n      document_id: str,\n      additional_context: str | None = None,\n  ) -> str:\n    \"\"\"Builds a prompt for the given chunk.\n\n    Args:\n      chunk_text: The text of the current chunk to process.\n      document_id: Identifier for the source document.\n      additional_context: Optional additional context from the document.\n\n    Returns:\n      The rendered prompt string ready for the language model.\n    \"\"\"\n    del document_id  # Unused in base class.\n    return self._generator.render(\n        question=chunk_text,\n        additional_context=additional_context,\n    )\n\n\nclass ContextAwarePromptBuilder(PromptBuilder):\n  \"\"\"Prompt builder with cross-chunk context tracking.\n\n  Extends PromptBuilder to inject text from the previous chunk into each\n  prompt. 
This helps language models resolve coreferences across chunk\n  boundaries (e.g., connecting \"She\" to \"Dr. Sarah Johnson\" from the\n  previous chunk).\n\n  Context is tracked per document_id, so multiple documents can be processed\n  without context bleeding between them.\n  \"\"\"\n\n  _CONTEXT_PREFIX = \"[Previous text]: ...\"\n\n  def __init__(\n      self,\n      generator: QAPromptGenerator,\n      context_window_chars: int | None = None,\n  ):\n    \"\"\"Initializes the builder with context tracking configuration.\n\n    Args:\n      generator: The underlying prompt generator to use.\n      context_window_chars: Number of characters from the previous chunk's\n          tail to include as context. Defaults to None (disabled).\n    \"\"\"\n    super().__init__(generator)\n    self._context_window_chars = context_window_chars\n    self._prev_chunk_by_doc_id: dict[str, str] = {}\n\n  @property\n  def context_window_chars(self) -> int | None:\n    \"\"\"Number of trailing characters from previous chunk to include.\"\"\"\n    return self._context_window_chars\n\n  @override\n  def build_prompt(\n      self,\n      chunk_text: str,\n      document_id: str,\n      additional_context: str | None = None,\n  ) -> str:\n    \"\"\"Builds a prompt, injecting previous chunk context if enabled.\n\n    Args:\n      chunk_text: The text of the current chunk to process.\n      document_id: Identifier for the source document (used to track context\n          per document).\n      additional_context: Optional additional context from the document.\n\n    Returns:\n      The rendered prompt string ready for the language model.\n    \"\"\"\n    effective_context = self._build_effective_context(\n        document_id, additional_context\n    )\n    prompt = self._generator.render(\n        question=chunk_text,\n        additional_context=effective_context,\n    )\n    self._update_state(document_id, chunk_text)\n    return prompt\n\n  def _build_effective_context(\n      self,\n   
   document_id: str,\n      additional_context: str | None,\n  ) -> str | None:\n    \"\"\"Combines previous chunk context with any additional context.\n\n    Args:\n      document_id: Identifier for the source document.\n      additional_context: Optional additional context from the document.\n\n    Returns:\n      Combined context string, or None if no context is available.\n    \"\"\"\n    context_parts: list[str] = []\n\n    if self._context_window_chars and document_id in self._prev_chunk_by_doc_id:\n      prev_text = self._prev_chunk_by_doc_id[document_id]\n      window = prev_text[-self._context_window_chars :]\n      context_parts.append(f\"{self._CONTEXT_PREFIX}{window}\")\n\n    if additional_context:\n      context_parts.append(additional_context)\n\n    return \"\\n\\n\".join(context_parts) if context_parts else None\n\n  def _update_state(self, document_id: str, chunk_text: str) -> None:\n    \"\"\"Stores current chunk as context for the next chunk in this document.\n\n    Args:\n      document_id: Identifier for the source document.\n      chunk_text: The current chunk text to store.\n    \"\"\"\n    if self._context_window_chars:\n      self._prev_chunk_by_doc_id[document_id] = chunk_text\n"
  },
  {
    "path": "langextract/providers/README.md",
    "content": "# LangExtract Provider System\n\nThis directory contains the provider system for LangExtract, which enables support for different Large Language Model (LLM) backends.\n\n**Quick Start**: Use the [provider plugin generator script](../../scripts/create_provider_plugin.py) to create a new provider in minutes:\n```bash\npython scripts/create_provider_plugin.py MyProvider --with-schema\n```\n\n## Architecture Overview\n\nThe provider system uses a **registry pattern** with **automatic discovery**:\n\n1. **Registry** (`registry.py`): Maps model ID patterns to provider classes\n2. **Factory** (`../factory.py`): Creates provider instances based on model IDs\n3. **Providers**: Implement the `BaseLanguageModel` interface\n\n### Provider Resolution Flow\n\n```\nUser Code                    LangExtract                      Provider\n─────────                    ───────────                      ────────\n    |                             |                              |\n    | lx.extract(                 |                              |\n    |   model_id=\"gemini-2.5-flash\")                             |\n    |─────────────────────────────>                              |\n    |                             |                              |\n    |                    factory.create_model()                  |\n    |                             |                              |\n    |                    registry.resolve(\"gemini-2.5-flash\")    |\n    |                       Pattern match: ^gemini               |\n    |                             ↓                              |\n    |                       GeminiLanguageModel                  |\n    |                             |                              |\n    |                    Instantiate provider                    |\n    |                             |─────────────────────────────>|\n    |                             |                              |\n    |                             |       Provider API 
calls     |\n    |                             |<─────────────────────────────|\n    |                             |                              |\n    |<────────────────────────────                               |\n    | AnnotatedDocument           |                              |\n```\n\n### Explicit Provider Selection\n\nWhen multiple providers might support the same model ID, or when you want to use a specific provider, you can explicitly specify the provider:\n\n```python\nimport langextract as lx\n\n# Method 1: Using factory directly with provider parameter\nconfig = lx.factory.ModelConfig(\n    model_id=\"gpt-4\",\n    provider=\"OpenAILanguageModel\",  # Explicit provider\n    provider_kwargs={\"api_key\": \"...\"}\n)\nmodel = lx.factory.create_model(config)\n\n# Method 2: Using provider without model_id (uses provider's default)\nconfig = lx.factory.ModelConfig(\n    provider=\"GeminiLanguageModel\",  # Will use default gemini-2.5-flash\n    provider_kwargs={\"api_key\": \"...\"}\n)\nmodel = lx.factory.create_model(config)\n\n# Method 3: Auto-detection (when no conflicts exist)\nconfig = lx.factory.ModelConfig(\n    model_id=\"gemini-2.5-flash\"  # Provider auto-detected\n)\nmodel = lx.factory.create_model(config)\n```\n\nProvider names can be:\n- Full class name: `\"GeminiLanguageModel\"`, `\"OpenAILanguageModel\"`, `\"OllamaLanguageModel\"`\n- Partial match: `\"gemini\"`, `\"openai\"`, `\"ollama\"` (case-insensitive)\n\n## Provider Types\n\n### 1. Core Providers (Always Available)\nShips with langextract, dependencies included:\n- **Gemini** (`gemini.py`): Google's Gemini models\n- **Ollama** (`ollama.py`): Local models via Ollama\n\n### 2. Built-in Provider with Optional Dependencies\nShips with langextract, but requires extra installation:\n- **OpenAI** (`openai.py`): OpenAI's GPT models\n  - Code included in package\n  - Requires: `pip install langextract[openai]` to install OpenAI SDK\n  - Future: May be moved to external plugin package\n\n### 3. 
External Plugins (Third-party)\nSeparate packages that extend LangExtract with new providers:\n- **Installed separately**: `pip install langextract-yourprovider`\n- **Auto-discovered**: Uses Python entry points for automatic registration\n- **Zero configuration**: Import langextract and the provider is available\n- **Independent updates**: Update providers without touching core\n\n```bash\n# Install a third-party provider\npip install langextract-yourprovider\n```\n\n```python\n# Use it immediately; no provider-specific import is needed\nimport langextract as lx\n\nresult = lx.extract(\n    text=\"...\",\n    model_id=\"yourmodel-latest\",  # Automatically finds the provider\n)\n```\n\n#### How Plugin Discovery Works\n\n```\n1. pip install langextract-yourprovider\n   └── Installs package containing:\n       • Provider class with @lx.providers.registry.register decorator\n       • Python entry point pointing to this class\n\n2. import langextract\n   └── Loads providers/__init__.py\n       └── Plugin loading is lazy (on-demand)\n\n3. lx.extract(model_id=\"yourmodel-latest\")\n   └── Triggers plugin discovery via entry points\n       └── @lx.providers.registry.register decorator fires\n           └── Provider patterns added to registry\n               └── Registry matches pattern and uses your provider\n```\n\n**Important Notes:**\n- Plugin loading is **lazy**: plugins are discovered when first needed\n- To manually trigger plugin loading: `lx.providers.load_plugins_once()`\n- Set `LANGEXTRACT_DISABLE_PLUGINS=1` to disable plugin loading\n- Registry entries are tuples: `(patterns_list, priority_int)`\n\n## How Provider Selection Works\n\nWhen you call `lx.extract(model_id=\"gemini-2.5-flash\", ...)`, here's what happens:\n\n1. **Factory receives model_id**: \"gemini-2.5-flash\"\n2. **Registry searches patterns**: Each provider registers regex patterns\n3. **First match wins**: Returns the matching provider class\n4. **Provider instantiated**: With model_id and any kwargs\n5. 
**Inference runs**: Using the selected provider\n\n### Pattern Registration Example\n\n```python\nimport langextract as lx\n\n# Gemini provider registration:\n@lx.providers.registry.register(\n    r'^GeminiLanguageModel$',  # Explicit: model_id=\"GeminiLanguageModel\"\n    r'^gemini',                # Prefix: model_id=\"gemini-2.5-flash\"\n    r'^palm',                  # Legacy: model_id=\"palm-2\"\n)\nclass GeminiLanguageModel(lx.inference.BaseLanguageModel):\n    def __init__(self, model_id: str, api_key: str | None = None, **kwargs):\n        # Initialize Gemini client\n        ...\n\n    def infer(self, batch_prompts, **kwargs):\n        # Call Gemini API\n        ...\n```\n\n## Usage Examples\n\n### Using Default Provider Selection\n```python\nimport langextract as lx\n\n# Automatically selects the Gemini provider\nresult = lx.extract(\n    text=\"...\",\n    model_id=\"gemini-2.5-flash\"\n)\n```\n\n### Passing Parameters to Providers\n\nParameters flow from `lx.extract()` to providers through several mechanisms:\n\n```python\n# 1. Common parameters handled by lx.extract itself:\nresult = lx.extract(\n    text=\"Your document\",\n    model_id=\"gemini-2.5-flash\",\n    prompt_description=\"Extract key facts\",\n    examples=[...],           # Used for few-shot prompting\n    num_workers=4,            # Parallel processing\n    max_chunk_size=3000,      # Document chunking\n)\n\n# 2. 
Provider-specific parameters passed via **kwargs:\nresult = lx.extract(\n    text=\"Your document\",\n    model_id=\"gemini-2.5-flash\",\n    prompt_description=\"Extract entities\",\n    # These go directly to the Gemini provider:\n    temperature=0.7,          # Sampling temperature\n    api_key=\"your-key\",      # Override environment variable\n    max_output_tokens=1000,  # Token limit\n)\n```\n\n### Using the Factory for Advanced Control\n```python\n# When you need explicit provider selection or advanced configuration\nfrom langextract import factory\n\n# Specify both model and provider (useful when multiple providers support same model)\nconfig = factory.ModelConfig(\n    model_id=\"gemma2:2b\",\n    provider=\"OllamaLanguageModel\",  # Explicitly use Ollama\n    provider_kwargs={\n        \"model_url\": \"http://localhost:11434\"\n    }\n)\nmodel = factory.create_model(config)\n```\n\n### Direct Provider Usage\n```python\nimport langextract as lx\n\n# Direct import if you prefer (optional)\nfrom langextract.providers.gemini import GeminiLanguageModel\n\nmodel = GeminiLanguageModel(\n    model_id=\"gemini-2.5-flash\",\n    api_key=\"your-key\"\n)\noutputs = model.infer([\"prompt1\", \"prompt2\"])\n```\n\n## Creating a New Provider\n\n**📁 Complete Example**: See [examples/custom_provider_plugin/](../../examples/custom_provider_plugin/) for a fully-functional plugin template with testing and documentation.\n\n### Quick Start Checklist\n\nCreating a provider plugin? Follow this checklist:\n\n#### ☐ **1. 
Setup Package Structure**\n```\nlangextract-yourprovider/\n├── pyproject.toml              # Package config with entry point\n├── README.md                    # Documentation\n├── LICENSE                      # License file\n└── langextract_yourprovider/   # Package directory\n    ├── __init__.py             # Exports provider class\n    ├── provider.py             # Provider implementation\n    └── schema.py               # (Optional) Custom schema\n```\n\n#### ☐ **2. Configure Entry Point** (`pyproject.toml`)\n```toml\n[build-system]\nrequires = [\"setuptools>=61.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"langextract-yourprovider\"\nversion = \"0.1.0\"\ndependencies = [\"langextract>=1.0.0\"]\n\n[project.entry-points.\"langextract.providers\"]\nyourprovider = \"langextract_yourprovider:YourProviderLanguageModel\"\n```\n\n#### ☐ **3. Implement Provider** (`provider.py`)\n- [ ] Import required modules\n- [ ] Add `@lx.providers.registry.register()` decorator with patterns\n- [ ] Inherit from `lx.inference.BaseLanguageModel`\n- [ ] Implement `__init__()` method\n- [ ] Implement `infer()` method returning `ScoredOutput` objects\n- [ ] Export class from `__init__.py`\n\n#### ☐ **4. (Optional) Add Schema Support** (`schema.py`)\n- [ ] Create schema class inheriting from `lx.schema.BaseSchema`\n- [ ] Implement `from_examples()` class method\n- [ ] Implement `to_provider_config()` method\n- [ ] Add `get_schema_class()` to provider\n- [ ] Handle schema in provider's `__init__()` and `infer()`\n\n#### ☐ **5. Testing**\n- [ ] Install plugin with `pip install -e .`\n- [ ] Test that your provider loads and handles basic inference\n- [ ] Verify schema support works (if implemented)\n\n#### ☐ **6. Documentation**\n- [ ] Document supported model IDs and patterns\n- [ ] List required environment variables\n- [ ] Provide usage examples\n- [ ] Document any provider-specific parameters\n\n#### ☐ **7. 
Distribution & Community**\n- [ ] Test installation with `pip install -e .`\n- [ ] Build package with `python -m build`\n- [ ] Test in clean environment\n- [ ] Publish to PyPI with `twine upload dist/*`\n- [ ] Share your provider by opening an issue on [LangExtract GitHub](https://github.com/google/langextract/issues) to get feedback and help others discover it\n- [ ] Consider submitting a PR to add your provider to the community providers list (coming soon)\n\n### Option 1: External Plugin (Recommended)\n\nExternal plugins are the recommended approach for adding new providers. They're easy to maintain, distribute, and don't require changes to the core package.\n\n#### For Users (Installing an External Plugin)\nSimply install the plugin package:\n```bash\npip install langextract-yourprovider\n# That's it! The provider is now available in langextract\n```\n\n#### For Developers (Creating an External Plugin)\n\n1. Create a new package:\n```\nlangextract-myprovider/\n├── pyproject.toml\n├── README.md\n└── langextract_myprovider/\n    └── __init__.py\n```\n\n2. Configure entry point in `pyproject.toml`:\n```toml\n[build-system]\nrequires = [\"setuptools>=61.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"langextract-myprovider\"\nversion = \"0.1.0\"\ndependencies = [\"langextract>=1.0.0\", \"your-sdk\"]\n\n[project.entry-points.\"langextract.providers\"]\n# Pattern 1: Register the class directly\nmyprovider = \"langextract_myprovider:MyProviderLanguageModel\"\n\n# Pattern 2: Register a module that self-registers\n# myprovider = \"langextract_myprovider\"\n```\n\n3. 
Implement your provider:\n```python\n# langextract_myprovider/__init__.py\nimport os\nimport langextract as lx\n\n@lx.providers.registry.register(r'^mymodel', r'^custom', priority=10)\nclass MyProviderLanguageModel(lx.inference.BaseLanguageModel):\n    def __init__(self, model_id: str, api_key: str = None, **kwargs):\n        super().__init__()\n        self.model_id = model_id\n        self.api_key = api_key or os.environ.get('MYPROVIDER_API_KEY')\n        # Initialize your client (MyProviderClient stands in for your SDK's client)\n        self.client = MyProviderClient(api_key=self.api_key)\n\n    def infer(self, batch_prompts, **kwargs):\n        # Implement inference\n        for prompt in batch_prompts:\n            result = self.client.generate(prompt, **kwargs)\n            yield [lx.inference.ScoredOutput(score=1.0, output=result)]\n```\n\n**Pattern Registration Explained:**\n- The `@register` decorator patterns (e.g., `r'^mymodel'`, `r'^custom'`) define which model IDs your provider supports\n- When users call `lx.extract(model_id=\"mymodel-3b\")`, the registry matches against these patterns\n- Your provider will handle any model_id starting with \"mymodel\" or \"custom\"\n- Users can explicitly select your provider using its class name:\n  ```python\n  config = lx.factory.ModelConfig(provider=\"MyProviderLanguageModel\")\n  # Or partial match: provider=\"myprovider\" (matches class name)\n  ```\n\n4. Publish your package to PyPI:\n```bash\npip install build twine\npython -m build\ntwine upload dist/*\n```\n\nNow users can install and use your provider with just `pip install langextract-myprovider`!\n\n### Adding Schema Support\n\nSchemas enable structured output with strict JSON constraints. Here's how to add schema support to your provider:\n\n#### 1.
Create a Schema Class\n\n```python\n# langextract_myprovider/schema.py\nimport langextract as lx\nfrom langextract import schema\n\nclass MyProviderSchema(lx.schema.BaseSchema):\n    def __init__(self, schema_dict: dict):\n        self._schema_dict = schema_dict\n\n    @property\n    def schema_dict(self) -> dict:\n        return self._schema_dict\n\n    @classmethod\n    def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n        \"\"\"Build schema from example extractions.\"\"\"\n        # Analyze examples to determine structure\n        extraction_types = {}\n        for example in examples_data:\n            for extraction in example.extractions:\n                class_name = extraction.extraction_class\n                if class_name not in extraction_types:\n                    extraction_types[class_name] = set()\n                if extraction.attributes:\n                    extraction_types[class_name].update(extraction.attributes.keys())\n\n        # Build JSON schema\n        schema_dict = {\n            \"type\": \"object\",\n            \"properties\": {\n                \"extractions\": {\n                    \"type\": \"array\",\n                    \"items\": {\"type\": \"object\"}  # Simplified\n                }\n            }\n        }\n        return cls(schema_dict)\n\n    def to_provider_config(self) -> dict:\n        \"\"\"Convert to provider-specific configuration.\"\"\"\n        return {\n            \"response_schema\": self._schema_dict,\n            \"structured_output\": True\n        }\n\n    @property\n    def supports_strict_mode(self) -> bool:\n        \"\"\"Return True if provider enforces valid JSON output.\"\"\"\n        return True\n```\n\n#### 2. 
Update Your Provider\n\n```python\n# langextract_myprovider/provider.py\nimport langextract as lx\n\nclass MyProviderLanguageModel(lx.inference.BaseLanguageModel):\n    def __init__(self, model_id: str, **kwargs):\n        super().__init__()\n        self.model_id = model_id\n        # Initialize your client (MyProviderClient is a placeholder for your SDK's client)\n        self.client = MyProviderClient(api_key=kwargs.get('api_key'))\n        # Schema config will be in kwargs when use_schema_constraints=True\n        self.response_schema = kwargs.get('response_schema')\n        self.structured_output = kwargs.get('structured_output', False)\n\n    @classmethod\n    def get_schema_class(cls):\n        \"\"\"Tell LangExtract about our schema support.\"\"\"\n        from langextract_myprovider.schema import MyProviderSchema\n        return MyProviderSchema\n\n    def apply_schema(self, schema_instance):\n        \"\"\"Apply or clear schema configuration.\"\"\"\n        super().apply_schema(schema_instance)\n        if schema_instance:\n            config = schema_instance.to_provider_config()\n            self.response_schema = config.get('response_schema')\n            self.structured_output = config.get('structured_output', False)\n        else:\n            self.response_schema = None\n            self.structured_output = False\n\n    def infer(self, batch_prompts, **kwargs):\n        for prompt in batch_prompts:\n            # Use schema in API call if available\n            api_params = {}\n            if self.response_schema:\n                api_params['response_schema'] = self.response_schema\n\n            result = self.client.generate(prompt, **api_params)\n            yield [lx.inference.ScoredOutput(score=1.0, output=result)]\n```\n\n#### 3. Schema Usage\n\nWhen users set `use_schema_constraints=True`, LangExtract will:\n1. Call your provider's `get_schema_class()`\n2. Use `from_examples()` to build a schema from provided examples\n3. Call `to_provider_config()` to get provider-specific kwargs\n4. Pass these kwargs to your provider's `__init__()`\n5.
Your provider uses the schema for structured output\n\n### Option 2: Built-in Provider (Requires Core Team Approval)\n\n**⚠️ Note**: Adding a provider to the core package requires:\n- Significant community demand and support\n- Commitment to long-term maintenance\n- Approval from the LangExtract maintainers\n- A pull request to the main repository\n\nThis approach should only be used for providers that benefit a large portion of the user base.\n\n1. Create your provider file:\n```python\n# langextract/providers/myprovider.py\nimport langextract as lx\n\n@lx.providers.registry.register(r'^mymodel', r'^custom')\nclass MyProviderLanguageModel(lx.inference.BaseLanguageModel):\n    # Implementation same as above\n    ...\n```\n\n2. Import it in `providers/__init__.py`:\n```python\n# In langextract/providers/__init__.py\nfrom langextract.providers import myprovider  # noqa: F401\n```\n\n3. Submit a pull request with:\n   - Provider implementation\n   - Comprehensive tests\n   - Documentation\n   - Justification for inclusion in core\n\n## Environment Variables\n\nThe factory automatically resolves API keys (and, for Ollama, the server URL) from the environment:\n\n| Provider | Environment Variables (in priority order) |\n|----------|------------------------------------------|\n| Gemini   | `GEMINI_API_KEY`, `LANGEXTRACT_API_KEY` |\n| OpenAI   | `OPENAI_API_KEY`, `LANGEXTRACT_API_KEY` |\n| Ollama   | `OLLAMA_BASE_URL` (default: http://localhost:11434) |\n\n## Design Principles\n\n1. **Zero Configuration**: Providers auto-register when imported\n2. **Extensible**: Easy to add new providers without modifying core\n3. **Lazy Loading**: Optional dependencies only loaded when needed\n4. **Explicit Control**: Users can force specific providers when needed\n5.
**Pattern Priority**: All patterns have equal priority (0) by default\n\n## Common Issues\n\n### Provider Not Found\n```python\nValueError: No provider registered for model_id='unknown-model'\n```\n**Solution**: Check available patterns with `registry.list_entries()`\n\n### Plugin Not Loading\n```python\n# Your plugin isn't being discovered\n```\n**Solutions**:\n1. Manually trigger loading: `lx.providers.load_plugins_once()`\n2. Check that entry points are installed: `pip show -f your-package`\n3. Verify there are no typos in the `pyproject.toml` entry point\n4. Ensure the package is installed: `pip list | grep your-package`\n\n### Missing Dependencies\n```python\nInferenceConfigError: OpenAI provider requires openai package\n```\n**Solution**: Install optional dependencies: `pip install langextract[openai]`\n\n### Schema Not Working\n```python\n# Schema constraints not being applied\n```\n**Solutions**:\n1. Ensure your provider implements `get_schema_class()`\n2. Check that `use_schema_constraints=True` is set\n3. Verify your schema's `supports_strict_mode` returns `True`\n4. Test schema creation with `Schema.from_examples(examples)`\n\n### Pattern Conflicts\n```python\n# Multiple providers match the same model_id\n```\n**Solution**: Use explicit provider selection:\n```python\nconfig = lx.factory.ModelConfig(\n    model_id=\"model-name\",\n    provider=\"YourProviderClass\"  # Explicit selection\n)\n```\n"
  },
  {
    "path": "langextract/providers/__init__.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Provider package for LangExtract.\n\nThis package contains provider implementations for various LLM backends.\nEach provider can be imported independently for fine-grained dependency\nmanagement in build systems.\n\"\"\"\n\nimport importlib\nfrom importlib import metadata\nimport os\n\nfrom absl import logging\n\nfrom langextract.providers import builtin_registry\nfrom langextract.providers import router\n\nregistry = router  # Backward compat alias\n\n__all__ = [\n    \"gemini\",\n    \"openai\",\n    \"ollama\",\n    \"router\",\n    \"registry\",  # Backward compat\n    \"schemas\",\n    \"load_plugins_once\",\n    \"load_builtins_once\",\n]\n\n# Track provider loading for lazy initialization\n_plugins_loaded = False  # pylint: disable=invalid-name\n_builtins_loaded = False  # pylint: disable=invalid-name\n\n\ndef load_builtins_once() -> None:\n  \"\"\"Load built-in providers to register their patterns.\n\n  Idempotent function that ensures provider patterns are available\n  for model resolution. 
Uses lazy registration to ensure providers\n  can be re-registered after registry.clear() even if their modules\n  are already in sys.modules.\n  \"\"\"\n  global _builtins_loaded  # pylint: disable=global-statement\n\n  if _builtins_loaded:\n    return\n\n  # Register built-ins lazily so they can be re-registered after a registry.clear()\n  # even if their modules were already imported earlier in the test run.\n  for config in builtin_registry.BUILTIN_PROVIDERS:\n    router.register_lazy(\n        *config[\"patterns\"],\n        target=config[\"target\"],\n        priority=config[\"priority\"],\n    )\n\n  _builtins_loaded = True\n\n\ndef load_plugins_once() -> None:\n  \"\"\"Load provider plugins from installed packages.\n\n  Discovers and loads langextract provider plugins using entry points.\n  This function is idempotent - multiple calls have no effect.\n  \"\"\"\n  global _plugins_loaded  # pylint: disable=global-statement\n  if _plugins_loaded:\n    return\n\n  if os.environ.get(\"LANGEXTRACT_DISABLE_PLUGINS\", \"\").lower() in (\n      \"1\",\n      \"true\",\n      \"yes\",\n  ):\n    logging.info(\"Plugin loading disabled via LANGEXTRACT_DISABLE_PLUGINS\")\n    _plugins_loaded = True\n    return\n\n  load_builtins_once()\n\n  try:\n\n    eps = metadata.entry_points()\n\n    # Try different APIs based on what's available\n    if hasattr(eps, \"select\"):\n      # Python 3.10+ API\n      provider_eps = eps.select(group=\"langextract.providers\")\n    elif hasattr(eps, \"get\"):\n      # Python 3.9 API\n      provider_eps = eps.get(\"langextract.providers\", [])\n    else:\n      # Fallback for older versions\n      provider_eps = [\n          ep\n          for ep in eps\n          if getattr(ep, \"group\", None) == \"langextract.providers\"\n      ]\n\n    for entry_point in provider_eps:\n      try:\n\n        provider_class = entry_point.load()\n        logging.info(\"Loaded provider plugin: %s\", entry_point.name)\n\n        if hasattr(provider_class, 
\"get_model_patterns\"):\n          patterns = provider_class.get_model_patterns()\n          for pattern in patterns:\n            router.register(\n                pattern,\n                priority=getattr(\n                    provider_class,\n                    \"pattern_priority\",\n                    20,  # Default plugin priority\n                ),\n            )(provider_class)\n          logging.info(\n              \"Registered %d patterns for %s\", len(patterns), entry_point.name\n          )\n      except Exception as e:\n        logging.warning(\n            \"Failed to load provider plugin %s: %s\", entry_point.name, e\n        )\n\n  except Exception as e:\n    logging.warning(\"Error discovering provider plugins: %s\", e)\n\n  _plugins_loaded = True\n\n\ndef _reset_for_testing() -> None:\n  \"\"\"Reset plugin loading state for testing. Should only be used in tests.\"\"\"\n  global _plugins_loaded, _builtins_loaded  # pylint: disable=global-statement\n  _plugins_loaded = False\n  _builtins_loaded = False\n\n\ndef __getattr__(name: str):\n  \"\"\"Lazy loading for submodules.\"\"\"\n  if name == \"router\":\n    return importlib.import_module(\"langextract.providers.router\")\n  elif name == \"schemas\":\n    return importlib.import_module(\"langextract.providers.schemas\")\n  elif name == \"_plugins_loaded\":\n    return _plugins_loaded\n  elif name == \"_builtins_loaded\":\n    return _builtins_loaded\n  raise AttributeError(f\"module {__name__!r} has no attribute {name!r}\")\n"
  },
  {
    "path": "langextract/providers/builtin_registry.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Built-in provider registration configuration.\n\nThis module defines the registration details for all built-in providers,\nusing patterns from the centralized patterns module.\n\"\"\"\n\nfrom typing import TypedDict\n\nfrom langextract.providers import patterns\n\n\nclass ProviderConfig(TypedDict):\n  \"\"\"Configuration for a provider registration.\"\"\"\n\n  patterns: tuple[str, ...]\n  target: str\n  priority: int\n\n\n# Built-in provider configurations using centralized patterns\nBUILTIN_PROVIDERS: list[ProviderConfig] = [\n    {\n        'patterns': patterns.GEMINI_PATTERNS,\n        'target': 'langextract.providers.gemini:GeminiLanguageModel',\n        'priority': patterns.GEMINI_PRIORITY,\n    },\n    {\n        'patterns': patterns.OLLAMA_PATTERNS,\n        'target': 'langextract.providers.ollama:OllamaLanguageModel',\n        'priority': patterns.OLLAMA_PRIORITY,\n    },\n    {\n        'patterns': patterns.OPENAI_PATTERNS,\n        'target': 'langextract.providers.openai:OpenAILanguageModel',\n        'priority': patterns.OPENAI_PRIORITY,\n    },\n]\n"
  },
  {
    "path": "langextract/providers/gemini.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Gemini provider for LangExtract.\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport concurrent.futures\nimport dataclasses\nfrom typing import Any, Final, Iterator, Sequence\n\nfrom absl import logging\n\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import schema\nfrom langextract.core import types as core_types\nfrom langextract.providers import gemini_batch\nfrom langextract.providers import patterns\nfrom langextract.providers import router\nfrom langextract.providers import schemas\n\n_DEFAULT_MODEL_ID = 'gemini-2.5-flash'\n_DEFAULT_LOCATION = 'us-central1'\n_MIME_TYPE_JSON = 'application/json'\n\n_API_CONFIG_KEYS: Final[set[str]] = {\n    'response_mime_type',\n    'response_schema',\n    'safety_settings',\n    'system_instruction',\n    'tools',\n    'stop_sequences',\n    'candidate_count',\n}\n\n\n@router.register(\n    *patterns.GEMINI_PATTERNS,\n    priority=patterns.GEMINI_PRIORITY,\n)\n@dataclasses.dataclass(init=False)\nclass GeminiLanguageModel(base_model.BaseLanguageModel):  # pylint: disable=too-many-instance-attributes\n  \"\"\"Language model inference using Google's Gemini API with structured output.\"\"\"\n\n  model_id: str = _DEFAULT_MODEL_ID\n  api_key: str | None = None\n  vertexai: bool = False\n  
credentials: Any | None = None\n  project: str | None = None\n  location: str | None = None\n  http_options: Any | None = None\n  gemini_schema: schemas.gemini.GeminiSchema | None = None\n  format_type: data.FormatType = data.FormatType.JSON\n  temperature: float = 0.0\n  max_workers: int = 10\n  fence_output: bool = False\n  _extra_kwargs: dict[str, Any] = dataclasses.field(\n      default_factory=dict, repr=False, compare=False\n  )\n\n  @classmethod\n  def get_schema_class(cls) -> type[schema.BaseSchema] | None:\n    \"\"\"Return the GeminiSchema class for structured output support.\n\n    Returns:\n      The GeminiSchema class that supports strict schema constraints.\n    \"\"\"\n    return schemas.gemini.GeminiSchema\n\n  def apply_schema(self, schema_instance: schema.BaseSchema | None) -> None:\n    \"\"\"Apply a schema instance to this provider.\n\n    Args:\n      schema_instance: The schema instance to apply, or None to clear.\n    \"\"\"\n    super().apply_schema(schema_instance)\n    if isinstance(schema_instance, schemas.gemini.GeminiSchema):\n      self.gemini_schema = schema_instance\n\n  def __init__(\n      self,\n      model_id: str = _DEFAULT_MODEL_ID,\n      api_key: str | None = None,\n      vertexai: bool = False,\n      credentials: Any | None = None,\n      project: str | None = None,\n      location: str | None = None,\n      http_options: Any | None = None,\n      gemini_schema: schemas.gemini.GeminiSchema | None = None,\n      format_type: data.FormatType = data.FormatType.JSON,\n      temperature: float = 0.0,\n      max_workers: int = 10,\n      fence_output: bool = False,\n      **kwargs,\n  ) -> None:\n    \"\"\"Initialize the Gemini language model.\n\n    Args:\n      model_id: The Gemini model ID to use.\n      api_key: API key for Gemini service.\n      vertexai: Whether to use Vertex AI instead of API key authentication.\n      credentials: Optional Google auth credentials for Vertex AI.\n      project: Google Cloud project ID for 
Vertex AI.\n      location: Vertex AI location (e.g., 'global', 'us-central1').\n      http_options: Optional HTTP options for the client (e.g., for VPC endpoints).\n      gemini_schema: Optional schema for structured output.\n      format_type: Output format (JSON or YAML).\n      temperature: Sampling temperature.\n      max_workers: Maximum number of parallel API calls.\n      fence_output: Whether to wrap output in markdown fences (ignored,\n        Gemini handles this based on schema).\n      **kwargs: Additional Gemini API parameters. Only allowlisted keys are\n        forwarded to the API (response_schema, response_mime_type, tools,\n        safety_settings, stop_sequences, candidate_count, system_instruction).\n        See https://ai.google.dev/api/generate-content for details.\n    \"\"\"\n    try:\n      # pylint: disable=import-outside-toplevel\n      from google import genai\n    except ImportError as e:\n      raise exceptions.InferenceConfigError(\n          'google-genai is required for Gemini. 
Install it with: pip install'\n          ' google-genai'\n      ) from e\n\n    self.model_id = model_id\n    self.api_key = api_key\n    self.vertexai = vertexai\n    self.credentials = credentials\n    self.project = project\n    self.location = location\n    self.http_options = http_options\n    self.gemini_schema = gemini_schema\n    self.format_type = format_type\n    self.temperature = temperature\n    self.max_workers = max_workers\n    self.fence_output = fence_output\n\n    # Extract batch config before we filter kwargs into _extra_kwargs\n    batch_cfg_dict = kwargs.pop('batch', None)\n    self._batch_cfg = gemini_batch.BatchConfig.from_dict(batch_cfg_dict)\n\n    if not self.api_key and not self.vertexai:\n      raise exceptions.InferenceConfigError(\n          'Gemini models require either:\\n  - An API key via api_key parameter'\n          ' or LANGEXTRACT_API_KEY env var\\n  - Vertex AI configuration with'\n          ' vertexai=True, project, and location'\n      )\n    if self.vertexai and (not self.project or not self.location):\n      raise exceptions.InferenceConfigError(\n          'Vertex AI mode requires both project and location parameters'\n      )\n\n    if self.api_key and self.vertexai:\n      logging.warning(\n          'Both API key and Vertex AI configuration provided. 
'\n          'API key will take precedence for authentication.'\n      )\n\n    self._client = genai.Client(\n        api_key=self.api_key,\n        vertexai=vertexai,\n        credentials=credentials,\n        project=project,\n        location=location,\n        http_options=http_options,\n    )\n\n    super().__init__(\n        constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)\n    )\n    self._extra_kwargs = {\n        k: v for k, v in (kwargs or {}).items() if k in _API_CONFIG_KEYS\n    }\n\n  def _validate_schema_config(self) -> None:\n    \"\"\"Validate that schema configuration is compatible with format type.\n\n    Raises:\n      InferenceConfigError: If gemini_schema is set but format_type is not JSON.\n    \"\"\"\n    if self.gemini_schema and self.format_type != data.FormatType.JSON:\n      raise exceptions.InferenceConfigError(\n          'Gemini structured output only supports JSON format. '\n          'Set format_type=JSON or use_schema_constraints=False.'\n      )\n\n  def _process_single_prompt(\n      self, prompt: str, config: dict\n  ) -> core_types.ScoredOutput:\n    \"\"\"Process a single prompt and return a ScoredOutput.\"\"\"\n    try:\n      # Apply stored kwargs that weren't already set in config\n      for key, value in self._extra_kwargs.items():\n        if key not in config and value is not None:\n          config[key] = value\n\n      if self.gemini_schema:\n        self._validate_schema_config()\n        config.setdefault('response_mime_type', 'application/json')\n        config.setdefault('response_schema', self.gemini_schema.schema_dict)\n\n      response = self._client.models.generate_content(\n          model=self.model_id, contents=prompt, config=config\n      )\n\n      return core_types.ScoredOutput(score=1.0, output=response.text)\n\n    except Exception as e:\n      raise exceptions.InferenceRuntimeError(\n          f'Gemini API error: {str(e)}', original=e\n      ) from e\n\n  def infer(\n      self, 
batch_prompts: Sequence[str], **kwargs\n  ) -> Iterator[Sequence[core_types.ScoredOutput]]:\n    \"\"\"Runs inference on a list of prompts via Gemini's API.\n\n    Args:\n      batch_prompts: A list of string prompts.\n      **kwargs: Additional generation params (temperature, top_p, top_k, etc.)\n\n    Yields:\n      Lists of ScoredOutputs.\n    \"\"\"\n    merged_kwargs = self.merge_kwargs(kwargs)\n\n    config = {\n        'temperature': merged_kwargs.get('temperature', self.temperature),\n    }\n    for key in ('max_output_tokens', 'top_p', 'top_k'):\n      if key in merged_kwargs:\n        config[key] = merged_kwargs[key]\n\n    handled_keys = {'temperature', 'max_output_tokens', 'top_p', 'top_k'}\n    for key, value in merged_kwargs.items():\n      if (\n          key not in handled_keys\n          and key in _API_CONFIG_KEYS\n          and value is not None\n      ):\n        config[key] = value\n\n    # Use batch API if threshold met\n    if self._batch_cfg and self._batch_cfg.enabled:\n      if len(batch_prompts) >= self._batch_cfg.threshold:\n        try:\n          if self.gemini_schema:\n            self._validate_schema_config()\n          schema_dict = (\n              self.gemini_schema.schema_dict if self.gemini_schema else None\n          )\n          # Remove schema fields from config for batch API - they're handled via schema_dict\n          batch_config = dict(config)\n          batch_config.pop('response_mime_type', None)\n          batch_config.pop('response_schema', None)\n          # Extract top-level fields that don't belong in generationConfig\n          system_instruction = batch_config.pop('system_instruction', None)\n          safety_settings = batch_config.pop('safety_settings', None)\n          outputs = gemini_batch.infer_batch(\n              client=self._client,\n              model_id=self.model_id,\n              prompts=batch_prompts,\n              schema_dict=schema_dict,\n              gen_config=batch_config,\n              
cfg=self._batch_cfg,\n              system_instruction=system_instruction,\n              safety_settings=safety_settings,\n              project=self.project,\n              location=self.location,\n          )\n        except exceptions.InferenceRuntimeError:\n          raise\n        except Exception as e:\n          raise exceptions.InferenceRuntimeError(\n              f'Gemini Batch API error: {e}', original=e\n          ) from e\n\n        for text in outputs:\n          yield [core_types.ScoredOutput(score=1.0, output=text)]\n        return\n      else:\n        logging.info(\n            'Gemini batch mode enabled but prompt count (%d) is below the'\n            ' threshold (%d); using real-time API. Submit at least %d prompts'\n            ' to trigger batch mode.',\n            len(batch_prompts),\n            self._batch_cfg.threshold,\n            self._batch_cfg.threshold,\n        )\n\n    # Use parallel processing for batches larger than 1\n    if len(batch_prompts) > 1 and self.max_workers > 1:\n      with concurrent.futures.ThreadPoolExecutor(\n          max_workers=min(self.max_workers, len(batch_prompts))\n      ) as executor:\n        future_to_index = {\n            executor.submit(\n                self._process_single_prompt, prompt, config.copy()\n            ): i\n            for i, prompt in enumerate(batch_prompts)\n        }\n\n        results: list[core_types.ScoredOutput | None] = [None] * len(\n            batch_prompts\n        )\n        for future in concurrent.futures.as_completed(future_to_index):\n          index = future_to_index[future]\n          try:\n            results[index] = future.result()\n          except Exception as e:\n            raise exceptions.InferenceRuntimeError(\n                f'Parallel inference error: {str(e)}', original=e\n            ) from e\n\n        for result in results:\n          if result is None:\n            raise exceptions.InferenceRuntimeError(\n                'Failed to process one 
or more prompts'\n            )\n          yield [result]\n    else:\n      # Sequential processing for single prompt or worker\n      for prompt in batch_prompts:\n        result = self._process_single_prompt(prompt, config.copy())\n        yield [result]  # pylint: disable=duplicate-code\n"
  },
  {
    "path": "langextract/providers/gemini_batch.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Gemini Batch API helper module for LangExtract.\n\nThis module provides batch inference support using the google-genai SDK.\nIt handles:\n- File-based batch submission for all batch sizes\n- Job polling and result extraction\n- Schema-based structured output\n- Order preservation across batch processing\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Iterator, Sequence\nimport concurrent.futures\nimport dataclasses\nimport enum\nimport hashlib\nimport json\nimport logging as std_logging\nimport os\nimport re\nimport tempfile\nimport time\nfrom typing import Any, Callable, Protocol\nimport uuid\n\nfrom absl import logging\nfrom google import genai\nfrom google.api_core import exceptions as google_exceptions\nfrom google.cloud import storage\n\nfrom langextract.core import exceptions\n\n_MIME_TYPE_JSON = \"application/json\"\n_DEFAULT_LOCATION = \"us-central1\"\n_EXT_JSON = \".json\"\n_EXT_JSONL = \".jsonl\"\n_KEY_IDX = \"idx-\"\n_CACHE_PREFIX = \"cache\"\n_UNSET = object()\n\n\n@dataclasses.dataclass(slots=True, frozen=True)\nclass BatchConfig:\n  \"\"\"Define and validate Gemini Batch API configuration.\n\n  Attributes:\n    enabled: Whether batch mode is enabled.\n    threshold: Minimum prompts to trigger batch processing.\n    poll_interval: Seconds between job status checks.\n    timeout: Maximum seconds to wait for job 
completion.\n    max_prompts_per_job: Max prompts allowed in one batch job.\n    ignore_item_errors: If True, continue on per-item errors.\n    enable_caching: If True, use GCS-based caching for inference results.\n    retention_days: Days to keep GCS data (default 30). None for permanent.\n    on_job_create: Optional callback invoked with each created batch job.\n  \"\"\"\n\n  enabled: bool = False\n  threshold: int = 50\n  poll_interval: int = 30\n  timeout: int = 3600\n  max_prompts_per_job: int = 20000\n  ignore_item_errors: bool = False\n  enable_caching: bool | None = _UNSET  # type: ignore\n  retention_days: int | None = _UNSET  # type: ignore\n  on_job_create: Callable[[Any], None] | None = None\n\n  def __post_init__(self):\n    \"\"\"Validate numeric knobs and required batch settings early.\"\"\"\n\n    validations = [\n        (self.threshold >= 1, \"batch.threshold must be >= 1\"),\n        (self.poll_interval > 0, \"batch.poll_interval must be > 0\"),\n        (self.timeout > 0, \"batch.timeout must be > 0\"),\n        (self.max_prompts_per_job > 0, \"batch.max_prompts_per_job must be > 0\"),\n    ]\n    for is_valid, error_msg in validations:\n      if not is_valid:\n        raise ValueError(error_msg)\n\n    if self.enabled:\n      if self.enable_caching is _UNSET:\n        raise ValueError(\n            \"batch.enable_caching must be explicitly set when batch is enabled\"\n        )\n      if self.retention_days is _UNSET:\n        raise ValueError(\n            \"batch.retention_days must be explicitly set when batch is enabled\"\n            \" (use None for permanent)\"\n        )\n      if self.retention_days is not None and self.retention_days <= 0:\n        raise ValueError(\n            \"batch.retention_days must be > 0 or None (for permanent). 
\"\n            \"0 (immediate delete) is not allowed.\"\n        )\n\n  @classmethod\n  def from_dict(cls, d: dict | None) -> BatchConfig:\n    \"\"\"Create BatchConfig from dictionary, using defaults for missing keys.\"\"\"\n    if d is None:\n      return cls()\n    valid_keys = {f.name for f in dataclasses.fields(cls)}\n    filtered_dict = {k: v for k, v in d.items() if k in valid_keys}\n\n    unknown = sorted(set(d.keys()) - valid_keys)\n    if unknown:\n      logging.warning(\n          \"Ignoring unknown batch config keys: %s\", \", \".join(unknown)\n      )\n    cfg = cls(**filtered_dict)\n    if cfg.on_job_create is None:\n      object.__setattr__(cfg, \"on_job_create\", _default_job_create_callback)\n    return cfg\n\n\n_TERMINAL_FAIL = frozenset({\n    genai.types.JobState.JOB_STATE_FAILED,\n    genai.types.JobState.JOB_STATE_CANCELLED,\n    genai.types.JobState.JOB_STATE_EXPIRED,\n})\n_TERMINAL_OK = frozenset({\n    genai.types.JobState.JOB_STATE_SUCCEEDED,\n    genai.types.JobState.JOB_STATE_PAUSED,\n})\n\n\ndef _default_job_create_callback(job: Any) -> None:\n  \"\"\"Default callback to log batch job details.\"\"\"\n  logging.info(\"Batch job created successfully: %s\", job.name)\n  logging.info(\"Job State: %s\", job.state)\n  # Extract project and job ID for console URL\n  try:\n    # job.name format: projects/{project}/locations/{location}/batchPredictionJobs/{job_id}\n    parts = job.name.split(\"/\")\n    if len(parts) >= 6:\n      job_id = parts[-1]\n      location = parts[3]\n      project = parts[1]\n      logging.info(\n          \"Job Console URL:\"\n          \" https://console.cloud.google.com/vertex-ai/locations/%s/batch-predictions/%s?project=%s\",\n          location,\n          job_id,\n          project,\n      )\n  except Exception:\n    pass\n\n\ndef _snake_to_camel(key: str) -> str:\n  \"\"\"Convert snake_case to camelCase for REST API compatibility.\"\"\"\n  parts = key.split(\"_\")\n  return parts[0] + \"\".join(p.title() for p 
in parts[1:])\n\n\ndef _is_vertexai_client(client) -> bool:\n  \"\"\"Check if client is configured for Vertex AI with explicit identity check.\n\n  Args:\n    client: The genai.Client instance to check.\n\n  Returns:\n    True if client.vertexai is explicitly True, False otherwise.\n  \"\"\"\n  return getattr(client, \"vertexai\", False) is True\n\n\ndef _get_project_location(\n    client: genai.Client,\n    project: str | None = None,\n    location: str | None = None,\n) -> tuple[str | None, str]:\n  \"\"\"Extract project and location from client or arguments.\"\"\"\n  if project:\n    proj = project\n  else:\n    # Try to get from client (if available in future versions) or env.\n    proj = getattr(client, \"project\", None) or os.getenv(\"GOOGLE_CLOUD_PROJECT\")\n\n  if location:\n    loc = location\n  else:\n    loc = getattr(client, \"location\", None) or _DEFAULT_LOCATION\n\n  return proj, loc\n\n\ndef _get_bucket_name(project: str | None, location: str) -> str:\n  \"\"\"Generate consistent GCS bucket name for batch operations.\"\"\"\n  base = f\"langextract-{project}-{location}-batch\".lower()\n  return re.sub(r\"[^a-z0-9._-]\", \"-\", base)\n\n\ndef _ensure_bucket_lifecycle(\n    bucket: storage.Bucket, retention_days: int | None\n) -> None:\n  \"\"\"Ensure bucket has a lifecycle rule to delete objects after retention_days.\n\n  This is a best-effort optimization to reduce storage costs. It checks if\n  a rule with the exact age exists, and if not, adds it. It does NOT remove\n  existing rules.\n\n  Args:\n    bucket: The GCS bucket to configure.\n    retention_days: Number of days to keep objects. 
If None, no rule is added.\n  \"\"\"\n  if retention_days is None or retention_days <= 0:\n    return\n\n  # Check if rule already exists\n  for rule in bucket.lifecycle_rules:\n    if (\n        rule.get(\"action\", {}).get(\"type\") == \"Delete\"\n        and rule.get(\"condition\", {}).get(\"age\") == retention_days\n    ):\n      return\n\n  # Add new rule\n  bucket.add_lifecycle_delete_rule(age=retention_days)\n  try:\n    bucket.patch()\n    logging.info(\n        \"Added lifecycle rule to bucket %s: delete after %d days\",\n        bucket.name,\n        retention_days,\n    )\n  except Exception as e:\n    logging.warning(\n        \"Failed to update lifecycle rule for bucket %s: %s\", bucket.name, e\n    )\n\n\ndef _build_request(\n    prompt: str,\n    schema_dict: dict | None,\n    gen_config: dict | None,\n    system_instruction: str | None = None,\n    safety_settings: Sequence[Any] | None = None,\n) -> dict:\n  \"\"\"Build a batch request in REST format for file-based submission.\n\n  Constructs a properly formatted request dictionary for batch processing.\n  Per the Gemini Batch API documentation, each request in the JSONL file\n  can include its own generationConfig with schema and generation parameters,\n  as well as top-level systemInstruction and safetySettings.\n\n  Args:\n    prompt: The text prompt to send to the model.\n    schema_dict: Optional JSON schema for structured output.\n    gen_config: Optional generation configuration parameters.\n    system_instruction: Optional system instruction text.\n    safety_settings: Optional safety settings sequence.\n\n  Returns:\n    A dictionary formatted for REST API file-based submission, containing:\n      * contents: The prompt content.\n      * systemInstruction: Optional system instructions.\n      * safetySettings: Optional safety settings.\n      * generationConfig: Optional generation configuration and schema.\n  \"\"\"\n  request = {\"contents\": [{\"role\": \"user\", \"parts\": [{\"text\": 
prompt}]}]}\n\n  if system_instruction:\n    request[\"systemInstruction\"] = {\"parts\": [{\"text\": system_instruction}]}\n\n  if safety_settings:\n    request[\"safetySettings\"] = safety_settings\n\n  if schema_dict or gen_config:\n    generation_config = {}\n    if schema_dict:\n      generation_config[\"responseMimeType\"] = _MIME_TYPE_JSON\n      generation_config[\"responseSchema\"] = schema_dict\n    if gen_config:\n      for k, v in gen_config.items():\n        generation_config[_snake_to_camel(k)] = v\n    request[\"generationConfig\"] = generation_config\n\n  return request\n\n\ndef _submit_file(\n    client: genai.Client,\n    model_id: str,\n    requests: Sequence[dict],\n    display: str,\n    retention_days: int | None,\n    project: str | None = None,\n    location: str | None = None,\n) -> genai.types.BatchJob:\n  \"\"\"Submit a file-based batch job to Vertex AI using GCS storage.\n\n  Batch processing is only supported with Vertex AI because it requires\n  GCS for file upload. Creates JSONL file, uploads to auto-created bucket,\n  and submits job for async processing.\n\n  Args:\n    client: google.genai.Client instance configured for Vertex AI\n        (must have client.vertexai=True).\n    model_id: Model identifier (e.g., \"gemini-2.5-flash\").\n    requests: List of request dictionaries with embedded configuration.\n        Each request contains contents and optional generationConfig\n        (including schema and generation parameters).\n    display: Display name for the batch job, used for identification and\n        as part of the GCS blob name.\n    retention_days: Days to keep GCS data. If set, applies lifecycle rule.\n    project: Optional GCP project ID. If not provided, will attempt to\n        determine from client or environment.\n    location: Optional GCP region/location. 
If not provided, will attempt to\n        determine from client or use default.\n\n  Returns:\n    BatchJob object that can be polled for completion status.\n  \"\"\"\n  path = None\n  try:\n    with tempfile.NamedTemporaryFile(\n        \"w\", suffix=_EXT_JSONL, delete=False, encoding=\"utf-8\"\n    ) as f:\n      path = f.name\n      for idx, req in enumerate(requests):\n        # We use a simple \"idx-{N}\" key format to track the original order\n        # of prompts, as batch processing may return results out of order.\n        line = {\"key\": f\"{_KEY_IDX}{idx}\", \"request\": req}\n        f.write(json.dumps(line, ensure_ascii=False) + \"\\n\")\n\n    project, location = _get_project_location(client, project, location)\n    bucket_name = _get_bucket_name(project, location)\n    blob_name = f\"batch-input/{display}-{uuid.uuid4().hex}.jsonl\"\n\n    storage_client = storage.Client(project=project)\n    try:\n      bucket = storage_client.create_bucket(bucket_name, location=location)\n      logging.info(\"Created GCS bucket: %s\", bucket_name)\n    except google_exceptions.Conflict:\n      # get_bucket() fetches the bucket metadata, so existing lifecycle rules\n      # are visible to _ensure_bucket_lifecycle and preserved by patch().\n      bucket = storage_client.get_bucket(bucket_name)\n      logging.info(\"Using existing GCS bucket: %s\", bucket_name)\n\n    if retention_days:\n      _ensure_bucket_lifecycle(bucket, retention_days)\n\n    blob = bucket.blob(blob_name)\n    blob.upload_from_filename(path)\n\n    gcs_uri = f\"gs://{bucket.name}/{blob.name}\"\n\n    # Create batch job (config and schema are in per-request generationConfig)\n    job = client.batches.create(\n        model=model_id, src=gcs_uri, config={\"display_name\": display}\n    )\n    return job\n  finally:\n    if path:\n      try:\n        os.unlink(path)\n      except OSError:\n        pass\n\n\nclass GCSBatchCache:\n  \"\"\"GCS-based cache for batch inference results.\"\"\"\n\n  def __init__(self, bucket_name: str, project: str | None = None):\n    self.bucket_name = 
bucket_name\n    self.project = project\n    self._client = storage.Client(project=project)\n    self._bucket = self._client.bucket(bucket_name)\n\n  def _compute_hash(self, key_data: dict) -> str:\n    \"\"\"Compute SHA256 hash of the canonicalized request data.\"\"\"\n    canonical_json = json.dumps(key_data, sort_keys=True, ensure_ascii=False)\n    return hashlib.sha256(canonical_json.encode(\"utf-8\")).hexdigest()\n\n  def _get_single(self, key_hash: str) -> str | None:\n    \"\"\"Fetch single item from GCS.\"\"\"\n    blob = self._bucket.blob(f\"{_CACHE_PREFIX}/{key_hash}{_EXT_JSON}\")\n    try:\n      data = json.loads(blob.download_as_text())\n      return data.get(\"text\")\n    except google_exceptions.NotFound:\n      return None\n    except Exception as e:\n      logging.warning(\"Cache read error for %s: %s\", key_hash, e)\n    return None\n\n  def get_multi(self, key_data_list: Sequence[dict]) -> dict[int, str]:\n    \"\"\"Fetch multiple items from GCS in parallel.\n\n    Returns:\n      Dict mapping index in key_data_list to cached text.\n    \"\"\"\n    results = {}\n    # Limit max_workers to 10 to match default HTTP connection pool size.\n    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:\n      future_to_idx = {}\n      for idx, key_data in enumerate(key_data_list):\n        key_hash = self._compute_hash(key_data)\n        future = executor.submit(self._get_single, key_hash)\n        future_to_idx[future] = idx\n\n      for future in concurrent.futures.as_completed(future_to_idx):\n        idx = future_to_idx[future]\n        text = future.result()\n        if text is not None:\n          results[idx] = text\n    return results\n\n  def set_multi(self, items: Sequence[tuple[dict, str]]) -> None:\n    \"\"\"Upload multiple items to GCS in parallel.\n\n    Args:\n      items: List of (key_data, result_text) tuples.\n    \"\"\"\n\n    def _upload(text: str, key_data: dict):\n      key_hash = self._compute_hash(key_data)\n     
 blob = self._bucket.blob(f\"{_CACHE_PREFIX}/{key_hash}{_EXT_JSON}\")\n      try:\n        blob.upload_from_string(\n            json.dumps({\"text\": text}, ensure_ascii=False),\n            content_type=_MIME_TYPE_JSON,\n        )\n      except Exception as e:\n        logging.warning(\n            \"Cache write error for %s: %s\", key_hash, e, exc_info=True\n        )\n\n    def _json_default(obj):\n      if dataclasses.is_dataclass(obj):\n        return dataclasses.asdict(obj)\n      if isinstance(obj, enum.Enum):\n        return obj.value\n      raise TypeError(f\"Object of type {type(obj)} is not JSON serializable\")\n\n    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:\n      for key_data, text in items:\n        # If text is not a string, try to serialize it\n        if not isinstance(text, str):\n          try:\n            text = json.dumps(text, default=_json_default, ensure_ascii=False)\n          except Exception as e:\n            logging.warning(\"Serialization error: %s\", e)\n            continue\n\n        executor.submit(_upload, text, key_data)\n\n  def iter_items(self) -> Iterator[tuple[str, str]]:\n    \"\"\"Iterate over all items in the cache.\n\n    Yields:\n      Tuple of (key_hash, text_content).\n    \"\"\"\n    blobs = self._bucket.list_blobs(prefix=f\"{_CACHE_PREFIX}/\")\n    for blob in blobs:\n      if not blob.name.endswith(_EXT_JSON):\n        continue\n      try:\n        key_hash = blob.name.split(\"/\")[-1].replace(_EXT_JSON, \"\")\n        data = json.loads(blob.download_as_text())\n        text = data.get(\"text\")\n        if text is not None:\n          yield key_hash, text\n      except (json.JSONDecodeError, Exception) as e:\n        logging.warning(\"Failed to read cache item %s: %s\", blob.name, e)\n\n\nclass _TextResponse(Protocol):\n  \"\"\"Protocol for inline response objects with text attribute.\"\"\"\n\n  text: str\n\n\ndef _safe_get_nested(data: dict, *keys) -> Any:\n  \"\"\"Safely traverse 
nested dictionaries/lists.\n\n  Args:\n    data: The dict to traverse.\n    *keys: Keys/indices to access. Use integers for list indices.\n\n  Returns:\n    The value at the path, or None if any key doesn't exist.\n  \"\"\"\n  current = data\n  for key in keys:\n    if current is None:\n      return None\n    if isinstance(key, int):\n      if not isinstance(current, list) or len(current) <= key:\n        return None\n      current = current[key]\n    else:\n      if not isinstance(current, dict):\n        return None\n      current = current.get(key)\n  return current\n\n\ndef _extract_text(resp: _TextResponse | dict[str, Any] | None) -> str | None:\n  \"\"\"Extract text from Vertex AI batch API response.\n\n  Args:\n    resp: Response object (inline) or dict (file) containing text.\n\n  Returns:\n    Extracted text string, or None if not found or invalid.\n  \"\"\"\n  if resp is None:\n    return None\n\n  if hasattr(resp, \"text\"):\n    text = getattr(resp, \"text\", None)\n    return text if isinstance(text, str) else None\n\n  if not isinstance(resp, dict):\n    return None\n\n  # Vertex AI format: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"...\"}]}}]}\n  text = _safe_get_nested(resp, \"candidates\", 0, \"content\", \"parts\", 0, \"text\")\n  return text if isinstance(text, str) else None\n\n\ndef _poll_completion(\n    client: genai.Client, job: genai.types.BatchJob, cfg: BatchConfig\n) -> genai.types.BatchJob:\n  \"\"\"Poll batch job until completion or timeout.\n\n  Args:\n    client: google.genai.Client instance for polling job status.\n    job: Batch job object returned from client.batches.create().\n    cfg: Batch configuration including timeout and poll_interval.\n\n  Returns:\n    Completed batch job object.\n\n  Raises:\n    exceptions.InferenceRuntimeError: If the job enters a failed terminal\n        state, or does not complete within cfg.timeout seconds (in which\n        case cancellation is attempted first).\n  \"\"\"\n  start = time.time()\n  name = job.name\n\n  while True:\n    job = 
client.batches.get(name=name)\n    state = job.state\n\n    if state in _TERMINAL_OK:\n      return job\n\n    if state in _TERMINAL_FAIL:\n      error_details = job.error or \"(no error details)\"\n      raise exceptions.InferenceRuntimeError(\n          f\"Batch job failed: state={state.name}, name={name}, \"\n          f\"error={error_details}\"\n      )\n\n    if time.time() - start > cfg.timeout:\n      try:\n        client.batches.cancel(name=name)\n      except Exception as e:\n        logging.warning(\"Failed to cancel timed-out batch job %s: %s\", name, e)\n      raise exceptions.InferenceRuntimeError(\n          f\"Batch job timed out after {cfg.timeout}s: {name}\"\n      )\n\n    time.sleep(cfg.poll_interval)\n    logging.info(\"Batch job is running... (State: %s)\", state.name)\n\n\ndef _parse_batch_line(\n    line: str, outputs: dict[int, str], cfg: BatchConfig\n) -> None:\n  \"\"\"Parse a single line from batch output JSONL.\"\"\"\n  try:\n    obj = json.loads(line)\n  except json.JSONDecodeError:\n    return\n\n  error = obj.get(\"error\")\n  if error and not cfg.ignore_item_errors:\n    code = error.get(\"code\") if isinstance(error, dict) else None\n    if code not in (None, 0):\n      raise exceptions.InferenceRuntimeError(f\"Batch item error: {error}\")\n\n  resp = obj.get(\"response\", {})\n  text = _extract_text(resp) or \"\"\n\n  key = obj.get(\"key\", \"\")\n  try:\n    # Extract the original index from the key (e.g., \"idx-5\" -> 5)\n    idx = int(str(key).rsplit(_KEY_IDX, maxsplit=1)[-1])\n  except (ValueError, IndexError):\n    idx = max(outputs.keys(), default=-1) + 1\n  outputs[idx] = text\n\n\ndef _extract_from_file(\n    client: genai.Client,\n    job: genai.types.BatchJob,\n    cfg: BatchConfig,\n    expected_count: int,\n) -> list[str]:\n  \"\"\"Extract text outputs from file-based batch results, preserving order.\n\n  Reads results from GCS output directory.\n\n  Args:\n    client: google.genai.Client instance for downloading result 
file.\n    job: Completed batch job object with result location.\n    cfg: Batch configuration including error handling settings.\n    expected_count: Number of prompts submitted (for order preservation).\n\n  Returns:\n    List of text outputs corresponding 1:1 to input prompts. Missing results\n    are padded with empty strings.\n\n  Raises:\n    ValueError: If client is not configured for Vertex AI.\n    exceptions.InferenceRuntimeError: If the job is missing its output\n        location, no output files are found, or an item has an error (when\n        cfg.ignore_item_errors is False).\n  \"\"\"\n  if not _is_vertexai_client(client):\n    raise ValueError(\"Batch API is only supported with Vertex AI.\")\n\n  outputs_by_idx: dict[int, str] = {}\n\n  if not job.dest:\n    raise exceptions.InferenceRuntimeError(\"Vertex AI batch job missing dest\")\n  gcs_uri = getattr(job.dest, \"gcs_uri\", None) or getattr(\n      job.dest, \"gcs_output_directory\", None\n  )\n  if not gcs_uri:\n    raise exceptions.InferenceRuntimeError(\n        \"Vertex AI batch job missing output GCS URI\"\n    )\n\n  if not gcs_uri.startswith(\"gs://\"):\n    raise exceptions.InferenceRuntimeError(f\"Invalid GCS URI format: {gcs_uri}\")\n\n  bucket_name, _, prefix = gcs_uri[5:].partition(\"/\")\n\n  project = getattr(client, \"project\", None) or os.getenv(\n      \"GOOGLE_CLOUD_PROJECT\"\n  )\n  storage_client = storage.Client(project=project)\n  bucket = storage_client.bucket(bucket_name)\n\n  # Vertex AI may write multiple output files.\n  blobs = list(bucket.list_blobs(prefix=prefix))\n  if not blobs:\n    raise exceptions.InferenceRuntimeError(\n        f\"No output files found in {gcs_uri}\"\n    )\n\n  logging.info(\"Batch API: Downloading results from %s\", gcs_uri)\n  logging.info(\"Batch API: Found %d output files\", len(blobs))\n\n  for blob in blobs:\n    if not blob.name.endswith(_EXT_JSONL):\n      continue\n\n    # Stream file line by line to avoid loading entire file into memory.\n    with blob.open(\"r\", encoding=\"utf-8\") as f:\n      for line in f:\n        if not line.strip():\n          continue\n        _parse_batch_line(line, outputs_by_idx, 
cfg)\n\n  logging.info(\"Batch API: Parsed %d results\", len(outputs_by_idx))\n  return [outputs_by_idx.get(i, \"\") for i in range(expected_count)]\n\n\ndef infer_batch(\n    client: genai.Client,\n    model_id: str,\n    prompts: Sequence[str],\n    schema_dict: dict | None,\n    gen_config: dict,\n    cfg: BatchConfig,\n    system_instruction: str | None = None,\n    safety_settings: Sequence[Any] | None = None,\n    project: str | None = None,\n    location: str | None = None,\n) -> list[str]:\n  \"\"\"Execute batch inference on multiple prompts using the Vertex AI Batch API.\n\n  This function provides file-based batch processing via Vertex AI. It:\n  - Uploads prompts to GCS (Google Cloud Storage)\n  - Submits batch job to Vertex AI\n  - Polls for job completion\n  - Extracts and returns results\n\n  Args:\n    client: google.genai.Client instance configured for Vertex AI\n        (must have client.vertexai=True).\n    model_id: Model identifier (e.g., \"gemini-2.5-flash\").\n    prompts: Sequence of prompts to process in batch.\n    schema_dict: Optional JSON schema for structured output. When provided,\n        enables JSON mode with the specified schema constraints.\n    gen_config: Generation configuration parameters (temperature, top_p, etc.).\n    cfg: Batch configuration including thresholds, timeouts, and error handling.\n    system_instruction: Optional system instruction text.\n    safety_settings: Optional safety settings sequence.\n    project: Google Cloud project ID (optional, overrides client/env).\n    location: Vertex AI location (optional, overrides client/env).\n\n  Returns:\n    List of text outputs corresponding 1:1 to input prompts. 
Missing results\n    are padded with empty strings.\n\n  Raises:\n    ValueError: If client is not configured for Vertex AI.\n    exceptions.InferenceRuntimeError: If the batch job fails, does not\n        complete within cfg.timeout seconds, or individual items have errors\n        (when cfg.ignore_item_errors is False).\n  \"\"\"\n  if not prompts:\n    return []\n\n  if not _is_vertexai_client(client):\n    raise ValueError(\n        \"Batch API is only supported with Vertex AI. To use batch mode, create\"\n        \" your client with: genai.Client(vertexai=True, project='YOUR_PROJECT',\"\n        \" location='us-central1'). For Google AI API keys, batch mode is not\"\n        \" currently supported.\"\n    )\n\n  # Suppress verbose HTTP logs from underlying libraries\n  std_logging.getLogger(\"google.auth.transport.requests\").setLevel(\n      std_logging.WARNING\n  )\n  std_logging.getLogger(\"urllib3.connectionpool\").setLevel(std_logging.WARNING)\n  std_logging.getLogger(\"httpx\").setLevel(std_logging.WARNING)\n  std_logging.getLogger(\"httpcore\").setLevel(std_logging.WARNING)\n  # Force disable httpx propagation or handlers if level setting fails\n  std_logging.getLogger(\"httpx\").disabled = True\n\n  logging.info(\"Batch API: Processing %d prompts\", len(prompts))\n\n  display_base = f\"langextract-batch-{int(time.time())}\"\n\n  project, location = _get_project_location(client, project, location)\n  bucket_name = _get_bucket_name(project, location)\n\n  cache = GCSBatchCache(bucket_name, project) if cfg.enable_caching else None\n  if cache:\n    logging.info(\n        \"Batch API: Using GCS bucket:\"\n        \" https://console.cloud.google.com/storage/browser/%s\",\n        bucket_name,\n    )\n\n  prompts_to_process: list[tuple[int, str]] = []\n  cached_results: dict[int, str] = {}\n\n  if cache:\n\n    key_data_list = []\n    for prompt in prompts:\n      key_data_list.append({\n          \"model_id\": model_id,\n          \"prompt\": prompt,\n          \"system_instruction\": system_instruction,\n          
\"gen_config\": gen_config,\n          \"safety_settings\": safety_settings,\n          \"schema\": schema_dict,\n      })\n\n    cached_results = cache.get_multi(key_data_list)\n\n    for idx, prompt in enumerate(prompts):\n      if idx not in cached_results:\n        prompts_to_process.append((idx, prompt))\n  else:\n    prompts_to_process = list(enumerate(prompts))\n\n  if not prompts_to_process:\n    logging.info(\"Batch API: All %d prompts found in cache\", len(prompts))\n    return [cached_results[i] for i in range(len(prompts))]\n\n  logging.info(\n      \"Batch API: %d cached, %d to submit\",\n      len(cached_results),\n      len(prompts_to_process),\n  )\n\n  def _process_batch(\n      batch_items: Sequence[tuple[int, str]], display: str\n  ) -> dict[int, str]:\n    \"\"\"Submit batch job, poll completion, and extract results.\n\n    Returns:\n      Dict mapping original index to result text.\n    \"\"\"\n    batch_prompts = [p for _, p in batch_items]\n    requests = [\n        _build_request(\n            p, schema_dict, gen_config, system_instruction, safety_settings\n        )\n        for p in batch_prompts\n    ]\n    job = _submit_file(\n        client,\n        model_id,\n        requests,\n        display,\n        cfg.retention_days,\n        project,\n        location,\n    )\n    if cfg.on_job_create:\n      try:\n        cfg.on_job_create(job)\n      except Exception as e:\n        logging.warning(\"Batch job creation callback failed: %s\", e)\n    job = _poll_completion(client, job, cfg)\n    logging.info(\"Batch job completed successfully.\")\n    results = _extract_from_file(\n        client, job, cfg, expected_count=len(batch_prompts)\n    )\n\n    # Map results back to original indices\n    mapped_results = {}\n    for (orig_idx, _), result in zip(batch_items, results):\n      mapped_results[orig_idx] = result\n\n    return mapped_results\n\n  new_results: dict[int, str] = {}\n\n  if (\n      cfg.max_prompts_per_job\n      and 
len(prompts_to_process) > cfg.max_prompts_per_job\n  ):\n    chunk_size = cfg.max_prompts_per_job\n    for chunk_num, i in enumerate(\n        range(0, len(prompts_to_process), chunk_size)\n    ):\n      chunk_items = prompts_to_process[i : i + chunk_size]\n      chunk_results = _process_batch(\n          chunk_items, f\"{display_base}-part-{chunk_num}\"\n      )\n      new_results.update(chunk_results)\n  else:\n    new_results = _process_batch(prompts_to_process, display_base)\n\n  if cache:\n    upload_list = []\n    for idx, text in new_results.items():\n      prompt = prompts[idx]\n      key_data = {\n          \"model_id\": model_id,\n          \"prompt\": prompt,\n          \"system_instruction\": system_instruction,\n          \"gen_config\": gen_config,\n          \"safety_settings\": safety_settings,\n          \"schema\": schema_dict,\n      }\n      upload_list.append((key_data, text))\n\n    cache.set_multi(upload_list)\n\n  final_outputs = []\n  for i in range(len(prompts)):\n    if i in cached_results:\n      final_outputs.append(cached_results[i])\n    else:\n      final_outputs.append(new_results.get(i, \"\"))\n\n  return final_outputs\n"
  },
  {
    "path": "langextract/providers/ollama.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Ollama provider for LangExtract.\n\nThis provider enables using local Ollama models with LangExtract's extract() function.\nNo API key is required since Ollama runs locally on your machine.\n\nUsage with extract():\n    import langextract as lx\n    from langextract.data import ExampleData, Extraction\n\n    # Create an example for few-shot learning\n    example = ExampleData(\n        text=\"Marie Curie was a pioneering physicist and chemist.\",\n        extractions=[\n            Extraction(\n                extraction_class=\"person\",\n                extraction_text=\"Marie Curie\",\n                attributes={\"name\": \"Marie Curie\", \"field\": \"physics and chemistry\"}\n            )\n        ]\n    )\n\n    # Basic usage with Ollama\n    result = lx.extract(\n        text_or_documents=\"Isaac Asimov was a prolific science fiction writer.\",\n        model_id=\"gemma2:2b\",\n        prompt_description=\"Extract the person's name and field\",\n        examples=[example],\n    )\n\nDirect provider instantiation (when model ID conflicts with other providers):\n    from langextract.providers.ollama import OllamaLanguageModel\n\n    # Create Ollama provider directly\n    model = OllamaLanguageModel(\n        model_id=\"gemma2:2b\",\n        model_url=\"http://localhost:11434\",  # optional, uses default if not specified\n    )\n\n    # Use with extract by 
passing the model instance\n    result = lx.extract(\n        text_or_documents=\"Your text here\",\n        model=model,  # Pass the model instance directly\n        prompt_description=\"Extract information\",\n        examples=[example],\n    )\n\nUsing pre-configured FormatHandler for manual control:\n    from langextract.providers.ollama import OLLAMA_FORMAT_HANDLER\n\n    # Use the pre-configured Ollama FormatHandler\n    result = lx.extract(\n        text_or_documents=\"Your text here\",\n        model_id=\"gemma2:2b\",\n        prompt_description=\"Extract information\",\n        examples=[example],\n        resolver_params={'format_handler': OLLAMA_FORMAT_HANDLER}\n    )\n\nSupported model ID formats:\n    - Standard Ollama: llama3.2:1b, gemma2:2b, mistral:7b, qwen2.5:7b, etc.\n    - Hugging Face style: meta-llama/Llama-3.2-1B-Instruct, google/gemma-2b, etc.\n\nPrerequisites:\n    1. Install Ollama: https://ollama.ai\n    2. Pull the model: ollama pull gemma2:2b\n    3. Ollama server will start automatically when you use extract()\n\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport dataclasses\nfrom typing import Any, Iterator, Mapping, Sequence\nfrom urllib.parse import urljoin\nfrom urllib.parse import urlparse\nimport warnings\n\nimport requests\n\n# Import from core modules directly\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import format_handler as fh\nfrom langextract.core import schema\nfrom langextract.core import types as core_types\nfrom langextract.providers import patterns\nfrom langextract.providers import router\n\n# Ollama defaults\n_OLLAMA_DEFAULT_MODEL_URL = 'http://localhost:11434'\n_DEFAULT_TEMPERATURE = 0.1\n_DEFAULT_TIMEOUT = 120\n_DEFAULT_KEEP_ALIVE = 5 * 60  # 5 minutes\n_DEFAULT_NUM_CTX = 2048\n\n# Pre-configured FormatHandler for consistent Ollama configuration\n# use_wrapper=True creates 
{\"extractions\": [...]} rather than a bare [...]\n# Ollama's JSON mode expects a dictionary root, not a bare list\nOLLAMA_FORMAT_HANDLER = fh.FormatHandler(\n    format_type=data.FormatType.JSON,\n    use_wrapper=True,\n    wrapper_key=None,\n    use_fences=False,\n    strict_fences=False,\n)\n\n\n@router.register(\n    *patterns.OLLAMA_PATTERNS,\n    priority=patterns.OLLAMA_PRIORITY,\n)\n@dataclasses.dataclass(init=False)\nclass OllamaLanguageModel(base_model.BaseLanguageModel):\n  \"\"\"Language model inference class backed by an Ollama server.\n\n  Timeout can be set via the constructor or passed through lx.extract():\n    lx.extract(..., language_model_params={\"timeout\": 300})\n\n  Authentication is supported for proxied Ollama instances:\n    lx.extract(..., language_model_params={\"api_key\": \"sk-...\"})\n  \"\"\"\n\n  _model: str\n  _model_url: str\n  format_type: core_types.FormatType = core_types.FormatType.JSON\n  _constraint: schema.Constraint = dataclasses.field(\n      default_factory=schema.Constraint, repr=False, compare=False\n  )\n  _extra_kwargs: dict[str, Any] = dataclasses.field(\n      default_factory=dict, repr=False, compare=False\n  )\n  # Authentication\n  _api_key: str | None = None\n  _auth_scheme: str = 'Bearer'\n  _auth_header: str = 'Authorization'\n\n  @classmethod\n  def get_schema_class(cls) -> type[schema.BaseSchema] | None:\n    \"\"\"Return the FormatModeSchema class for JSON output support.\n\n    Returns:\n      The FormatModeSchema class that enables JSON mode (non-strict).\n    \"\"\"\n    return schema.FormatModeSchema\n\n  def __repr__(self) -> str:\n    \"\"\"Return string representation with redacted API key.\"\"\"\n    api_key_display = '[REDACTED]' if self._api_key else None\n    return (\n        f'{self.__class__.__name__}('\n        f'model={self._model!r}, '\n        f'model_url={self._model_url!r}, '\n        f'format_type={self.format_type!r}, '\n        f'api_key={api_key_display})'\n    )\n\n  def __init__(\n      self,\n    
  model_id: str,\n      model_url: str = _OLLAMA_DEFAULT_MODEL_URL,\n      base_url: str | None = None,  # Alias for model_url\n      format_type: core_types.FormatType | None = None,\n      structured_output_format: str | None = None,  # Deprecated\n      constraint: schema.Constraint = schema.Constraint(),\n      timeout: int | None = None,\n      **kwargs,\n  ) -> None:\n    \"\"\"Initialize the Ollama language model.\n\n    Args:\n      model_id: The Ollama model ID to use.\n      model_url: URL of the Ollama server (legacy parameter). Defaults to\n        http://localhost:11434.\n      base_url: Alias for model_url; takes precedence when both are given.\n      format_type: Output format (JSON or YAML). Defaults to JSON.\n      structured_output_format: DEPRECATED - use format_type instead.\n      constraint: Schema constraints.\n      timeout: Request timeout in seconds. Defaults to 120.\n      **kwargs: Additional parameters. api_key, auth_scheme, auth_header and\n        format are consumed here; the rest are stored and forwarded to each\n        inference request.\n    \"\"\"\n    self._requests = requests\n\n    # Handle deprecated structured_output_format parameter\n    if structured_output_format is not None:\n      warnings.warn(\n          \"'structured_output_format' is deprecated and will be removed in \"\n          \"v2.0.0. 
Use 'format_type' instead.\",\n          FutureWarning,\n          stacklevel=2,\n      )\n      if format_type is None:\n        format_type = (\n            core_types.FormatType.JSON\n            if structured_output_format == 'json'\n            else core_types.FormatType.YAML\n        )\n\n    fmt = kwargs.pop('format', None)\n    if format_type is None and fmt in ('json', 'yaml'):\n      format_type = (\n          core_types.FormatType.JSON\n          if fmt == 'json'\n          else core_types.FormatType.YAML\n      )\n\n    if format_type is None:\n      format_type = core_types.FormatType.JSON\n\n    self._model = model_id\n    self._model_url = base_url or model_url or _OLLAMA_DEFAULT_MODEL_URL\n    self.format_type = format_type\n    self._constraint = constraint\n\n    self._api_key = kwargs.pop('api_key', None)\n    self._auth_scheme = kwargs.pop('auth_scheme', 'Bearer')\n    self._auth_header = kwargs.pop('auth_header', 'Authorization')\n\n    if self._api_key:\n      host = urlparse(self._model_url).hostname\n      if host in ('localhost', '127.0.0.1', '::1'):\n        warnings.warn(\n            'API key provided for localhost Ollama instance. '\n            \"Native Ollama doesn't require authentication. 
\"\n            'This is typically only needed for proxied instances.',\n            UserWarning,\n        )\n\n    super().__init__(constraint=constraint)\n    if timeout is not None:\n      kwargs['timeout'] = timeout\n    self._extra_kwargs = kwargs or {}\n\n  def infer(\n      self, batch_prompts: Sequence[str], **kwargs\n  ) -> Iterator[Sequence[core_types.ScoredOutput]]:\n    \"\"\"Runs inference on a list of prompts via Ollama's API.\n\n    Args:\n      batch_prompts: A list of string prompts.\n      **kwargs: Additional generation params.\n\n    Yields:\n      Lists of ScoredOutputs.\n    \"\"\"\n    combined_kwargs = self.merge_kwargs(kwargs)\n\n    for prompt in batch_prompts:\n      try:\n        response = self._ollama_query(\n            prompt=prompt,\n            model=self._model,\n            structured_output_format='json'\n            if self.format_type == core_types.FormatType.JSON\n            else 'yaml',\n            model_url=self._model_url,\n            **combined_kwargs,\n        )\n        yield [core_types.ScoredOutput(score=1.0, output=response['response'])]\n      except Exception as e:\n        raise exceptions.InferenceRuntimeError(\n            f'Ollama API error: {str(e)}', original=e\n        ) from e\n\n  def _ollama_query(\n      self,\n      prompt: str,\n      model: str | None = None,\n      temperature: float | None = None,\n      seed: int | None = None,\n      top_k: int | None = None,\n      top_p: float | None = None,\n      max_output_tokens: int | None = None,\n      structured_output_format: str | None = None,\n      system: str = '',\n      raw: bool = False,\n      model_url: str | None = None,\n      timeout: int | None = None,\n      keep_alive: int | None = None,\n      num_threads: int | None = None,\n      num_ctx: int | None = None,\n      stop: str | list[str] | None = None,\n      **kwargs,\n  ) -> Mapping[str, Any]:\n    \"\"\"Sends a prompt to an Ollama model and returns the generated response.\n\n    
Note: This is a low-level method. Constructor timeout is only used when\n    calling through infer(). Direct calls use the timeout parameter here.\n\n    This function makes an HTTP POST request to the `/api/generate` endpoint of\n    an Ollama server. Ollama loads the model on demand if it is not already in\n    memory. The request is always non-streaming (`stream=False`), and the\n    parsed JSON response is returned.\n\n    Args:\n      prompt: The text prompt to send to the model.\n      model: The name of the model to use. Defaults to self._model.\n      temperature: Sampling temperature. Higher values produce more diverse\n        output. If None, defaults to 0.1.\n      seed: Seed for reproducible generation. If None, a random seed is used.\n      top_k: The top-K parameter for sampling.\n      top_p: The top-P (nucleus) sampling parameter.\n      max_output_tokens: Maximum tokens to generate. If None, the model's\n        default is used.\n      structured_output_format: If set to \"json\" or a JSON schema dict, requests\n        structured outputs from the model. If None, it is derived from\n        self.format_type. See Ollama documentation for details.\n      system: A system prompt to override any system-level instructions.\n      raw: If True, bypasses any internal prompt templating; you provide the\n        entire raw prompt.\n      model_url: The base URL for the Ollama server. Defaults to self._model_url.\n      timeout: Timeout (in seconds) for the HTTP request. Defaults to 120.\n      keep_alive: How long (in seconds) the model remains loaded after\n        generation completes. If None, defaults to 300 (5 minutes).\n      num_threads: Number of CPU threads to use. If None, Ollama uses a default\n        heuristic.\n      num_ctx: Number of context tokens allowed. If None, defaults to 2048.\n      stop: Stop sequences to halt generation. Can be a string or list of strings.\n      **kwargs: Additional parameters. Non-None values (other than reserved\n        top-level keys such as model or prompt) are added to the request\n        `options`.\n\n    Returns:\n      A mapping (dictionary-like) containing the server's JSON response. 
For\n      non-streaming calls, the `\"response\"` key typically contains the entire\n      generated text.\n\n    Raises:\n      InferenceConfigError: If the server returns a 404 (model not found).\n      InferenceRuntimeError: For any other HTTP errors, timeouts, or request\n        exceptions.\n    \"\"\"\n    model = model or self._model\n    model_url = model_url or self._model_url\n    if structured_output_format is None and self.format_type is not None:\n      structured_output_format = (\n          'json' if self.format_type == core_types.FormatType.JSON else 'yaml'\n      )\n\n    options: dict[str, Any] = {}\n    if keep_alive is not None:\n      options['keep_alive'] = keep_alive\n    else:\n      options['keep_alive'] = _DEFAULT_KEEP_ALIVE\n\n    if seed is not None:\n      options['seed'] = seed\n    if temperature is not None:\n      options['temperature'] = temperature\n    else:\n      options['temperature'] = _DEFAULT_TEMPERATURE\n    if top_k is not None:\n      options['top_k'] = top_k\n    if top_p is not None:\n      options['top_p'] = top_p\n    if num_threads is not None:\n      options['num_thread'] = num_threads\n    if max_output_tokens is not None:\n      options['num_predict'] = max_output_tokens\n    if num_ctx is not None:\n      options['num_ctx'] = num_ctx\n    else:\n      options['num_ctx'] = _DEFAULT_NUM_CTX\n\n    reserved_top_level = {\n        'model',\n        'prompt',\n        'system',\n        'stop',\n        'format',\n        'stream',\n        'raw',\n    }\n    for key, value in kwargs.items():\n      if value is None:\n        continue\n      if key in reserved_top_level:\n        continue\n      if key not in options:\n        options[key] = value\n\n    api_url = urljoin(\n        model_url if model_url.endswith('/') else model_url + '/',\n        'api/generate',\n    )\n\n    payload: dict[str, Any] = {\n        'model': model,\n        'prompt': prompt,\n        'system': system,\n        'stream': False,\n       
 'raw': raw,\n        'options': options,\n    }\n\n    if structured_output_format is not None:\n      payload['format'] = structured_output_format\n\n    if stop is not None:\n      payload['stop'] = stop\n\n    request_timeout = timeout if timeout is not None else _DEFAULT_TIMEOUT\n\n    headers = {\n        'Content-Type': 'application/json',\n        'Accept': 'application/json',\n    }\n\n    if self._api_key:\n      if self._auth_scheme:\n        headers[self._auth_header] = f'{self._auth_scheme} {self._api_key}'\n      else:\n        headers[self._auth_header] = self._api_key\n\n    try:\n      response = self._requests.post(\n          api_url,\n          headers=headers,\n          json=payload,\n          timeout=request_timeout,\n      )\n    except self._requests.exceptions.RequestException as e:\n      if isinstance(e, self._requests.exceptions.ReadTimeout):\n        msg = (\n            f'Ollama Model timed out (timeout={request_timeout},'\n            f' num_threads={num_threads})'\n        )\n        raise exceptions.InferenceRuntimeError(\n            msg, original=e, provider='Ollama'\n        ) from e\n      raise exceptions.InferenceRuntimeError(\n          f'Ollama request failed: {str(e)}', original=e, provider='Ollama'\n      ) from e\n\n    response.encoding = 'utf-8'\n    if response.status_code == 200:\n      return response.json()\n    if response.status_code == 404:\n      raise exceptions.InferenceConfigError(\n          f\"Can't find Ollama {model}. Try: ollama run {model}\"\n      )\n    else:\n      msg = f'Bad status code from Ollama: {response.status_code}'\n      raise exceptions.InferenceRuntimeError(msg, provider='Ollama')\n"
  },
  {
    "path": "langextract/providers/openai.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"OpenAI provider for LangExtract.\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport concurrent.futures\nimport dataclasses\nfrom typing import Any, Iterator, Sequence\n\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import schema\nfrom langextract.core import types as core_types\nfrom langextract.providers import patterns\nfrom langextract.providers import router\n\n\n@router.register(\n    *patterns.OPENAI_PATTERNS,\n    priority=patterns.OPENAI_PRIORITY,\n)\n@dataclasses.dataclass(init=False)\nclass OpenAILanguageModel(base_model.BaseLanguageModel):\n  \"\"\"Language model inference using OpenAI's API with structured output.\"\"\"\n\n  model_id: str = 'gpt-4o-mini'\n  api_key: str | None = None\n  base_url: str | None = None\n  organization: str | None = None\n  format_type: data.FormatType = data.FormatType.JSON\n  temperature: float | None = None\n  max_workers: int = 10\n  _client: Any = dataclasses.field(default=None, repr=False, compare=False)\n  _extra_kwargs: dict[str, Any] = dataclasses.field(\n      default_factory=dict, repr=False, compare=False\n  )\n\n  @property\n  def requires_fence_output(self) -> bool:\n    \"\"\"OpenAI JSON mode returns raw JSON without fences.\"\"\"\n    if self.format_type == 
data.FormatType.JSON:\n      return False\n    return super().requires_fence_output\n\n  def __init__(\n      self,\n      model_id: str = 'gpt-4o-mini',\n      api_key: str | None = None,\n      base_url: str | None = None,\n      organization: str | None = None,\n      format_type: data.FormatType = data.FormatType.JSON,\n      temperature: float | None = None,\n      max_workers: int = 10,\n      **kwargs,\n  ) -> None:\n    \"\"\"Initialize the OpenAI language model.\n\n    Args:\n      model_id: The OpenAI model ID to use (e.g., 'gpt-4o-mini', 'gpt-4o').\n      api_key: API key for the OpenAI service (required).\n      base_url: Optional base URL for the OpenAI service.\n      organization: Optional OpenAI organization ID.\n      format_type: Output format (JSON or YAML).\n      temperature: Sampling temperature.\n      max_workers: Maximum number of parallel API calls.\n      **kwargs: Extra parameters stored as defaults for inference calls.\n        Recognized generation keys (e.g., top_p, seed, stop) are applied and\n        others are ignored, so callers can pass a superset of arguments shared\n        across back-ends without raising ``TypeError``.\n    \"\"\"\n    # Lazy import: OpenAI package required\n    try:\n      # pylint: disable=import-outside-toplevel\n      import openai\n    except ImportError as e:\n      raise exceptions.InferenceConfigError(\n          'OpenAI provider requires openai package. 
'\n          'Install with: pip install langextract[openai]'\n      ) from e\n\n    self.model_id = model_id\n    self.api_key = api_key\n    self.base_url = base_url\n    self.organization = organization\n    self.format_type = format_type\n    self.temperature = temperature\n    self.max_workers = max_workers\n\n    if not self.api_key:\n      raise exceptions.InferenceConfigError('API key not provided.')\n\n    # Initialize the OpenAI client\n    self._client = openai.OpenAI(\n        api_key=self.api_key,\n        base_url=self.base_url,\n        organization=self.organization,\n    )\n\n    super().__init__(\n        constraint=schema.Constraint(constraint_type=schema.ConstraintType.NONE)\n    )\n    self._extra_kwargs = kwargs or {}\n\n  def _normalize_reasoning_params(self, config: dict) -> dict:\n    \"\"\"Normalize reasoning parameters for API compatibility.\n\n    Converts flat 'reasoning_effort' to nested 'reasoning' structure.\n    Merges with existing reasoning dict if present.\n    \"\"\"\n    result = config.copy()\n\n    if 'reasoning_effort' in result:\n      effort = result.pop('reasoning_effort')\n      reasoning = result.get('reasoning', {}) or {}\n      reasoning.setdefault('effort', effort)\n      result['reasoning'] = reasoning\n\n    return result\n\n  def _process_single_prompt(\n      self, prompt: str, config: dict\n  ) -> core_types.ScoredOutput:\n    \"\"\"Process a single prompt and return a ScoredOutput.\"\"\"\n    try:\n      normalized_config = self._normalize_reasoning_params(config)\n\n      system_message = ''\n      if self.format_type == data.FormatType.JSON:\n        system_message = (\n            'You are a helpful assistant that responds in JSON format.'\n        )\n      elif self.format_type == data.FormatType.YAML:\n        system_message = (\n            'You are a helpful assistant that responds in YAML format.'\n        )\n\n      messages = [{'role': 'user', 'content': prompt}]\n      if system_message:\n        
messages.insert(0, {'role': 'system', 'content': system_message})\n\n      api_params = {\n          'model': self.model_id,\n          'messages': messages,\n          'n': 1,\n      }\n\n      temp = normalized_config.get('temperature', self.temperature)\n      if temp is not None:\n        api_params['temperature'] = temp\n\n      if self.format_type == data.FormatType.JSON:\n        api_params.setdefault('response_format', {'type': 'json_object'})\n\n      if (v := normalized_config.get('max_output_tokens')) is not None:\n        api_params['max_tokens'] = v\n      if (v := normalized_config.get('top_p')) is not None:\n        api_params['top_p'] = v\n      for key in [\n          'frequency_penalty',\n          'presence_penalty',\n          'seed',\n          'stop',\n          'logprobs',\n          'top_logprobs',\n          'reasoning',\n          'response_format',\n      ]:\n        if (v := normalized_config.get(key)) is not None:\n          api_params[key] = v\n\n      response = self._client.chat.completions.create(**api_params)\n\n      # Extract the response text using the v1.x response format\n      output_text = response.choices[0].message.content\n\n      return core_types.ScoredOutput(score=1.0, output=output_text)\n\n    except Exception as e:\n      raise exceptions.InferenceRuntimeError(\n          f'OpenAI API error: {str(e)}', original=e\n      ) from e\n\n  def infer(\n      self, batch_prompts: Sequence[str], **kwargs\n  ) -> Iterator[Sequence[core_types.ScoredOutput]]:\n    \"\"\"Runs inference on a list of prompts via OpenAI's API.\n\n    Args:\n      batch_prompts: A list of string prompts.\n      **kwargs: Additional generation params (temperature, top_p, etc.)\n\n    Yields:\n      Lists of ScoredOutputs.\n    \"\"\"\n    merged_kwargs = self.merge_kwargs(kwargs)\n\n    config = {}\n\n    temp = merged_kwargs.get('temperature', self.temperature)\n    if temp is not None:\n      config['temperature'] = temp\n    if 'max_output_tokens' 
in merged_kwargs:\n      config['max_output_tokens'] = merged_kwargs['max_output_tokens']\n    if 'top_p' in merged_kwargs:\n      config['top_p'] = merged_kwargs['top_p']\n\n    for key in [\n        'frequency_penalty',\n        'presence_penalty',\n        'seed',\n        'stop',\n        'logprobs',\n        'top_logprobs',\n        'reasoning_effort',\n        'reasoning',\n        'response_format',\n    ]:\n      if key in merged_kwargs:\n        config[key] = merged_kwargs[key]\n\n    # Process prompts in parallel when the batch has more than one prompt and\n    # max_workers > 1; results are still yielded in input order.\n    if len(batch_prompts) > 1 and self.max_workers > 1:\n      with concurrent.futures.ThreadPoolExecutor(\n          max_workers=min(self.max_workers, len(batch_prompts))\n      ) as executor:\n        future_to_index = {\n            executor.submit(\n                self._process_single_prompt, prompt, config.copy()\n            ): i\n            for i, prompt in enumerate(batch_prompts)\n        }\n\n        results: list[core_types.ScoredOutput | None] = [None] * len(\n            batch_prompts\n        )\n        for future in concurrent.futures.as_completed(future_to_index):\n          index = future_to_index[future]\n          try:\n            results[index] = future.result()\n          except Exception as e:\n            raise exceptions.InferenceRuntimeError(\n                f'Parallel inference error: {str(e)}', original=e\n            ) from e\n\n        for result in results:\n          if result is None:\n            raise exceptions.InferenceRuntimeError(\n                'Failed to process one or more prompts'\n            )\n          yield [result]\n    else:\n      # Process sequentially for a single prompt or when max_workers <= 1\n      for prompt in batch_prompts:\n        result = self._process_single_prompt(prompt, config.copy())\n        yield [result]  # pylint: disable=duplicate-code\n"
  },
  {
    "path": "langextract/providers/patterns.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Centralized pattern definitions for built-in providers.\n\nThis module defines all patterns and priorities for built-in providers\nin one place to avoid duplication.\n\"\"\"\n\n# Gemini provider patterns\nGEMINI_PATTERNS = (r'^gemini',)\nGEMINI_PRIORITY = 10\n\n# OpenAI provider patterns\nOPENAI_PATTERNS = (\n    r'^gpt-4',\n    r'^gpt4\\.',\n    r'^gpt-5',\n    r'^gpt5\\.',\n)\nOPENAI_PRIORITY = 10\n\n# Ollama provider patterns\nOLLAMA_PATTERNS = (\n    # Standard Ollama naming patterns\n    r'^gemma',  # gemma2:2b, gemma2:9b, etc.\n    r'^llama',  # llama3.2:1b, llama3.1:8b, etc.\n    r'^mistral',  # mistral:7b, mistral-nemo:12b, etc.\n    r'^mixtral',  # mixtral:8x7b, mixtral:8x22b, etc.\n    r'^phi',  # phi3:3.8b, phi3:14b, etc.\n    r'^qwen',  # qwen2.5:0.5b to 72b\n    r'^deepseek',  # deepseek-coder-v2, etc.\n    r'^command-r',  # command-r:35b, command-r-plus:104b\n    r'^starcoder',  # starcoder2:3b, starcoder2:7b, etc.\n    r'^codellama',  # codellama:7b, codellama:13b, etc.\n    r'^codegemma',  # codegemma:2b, codegemma:7b\n    r'^tinyllama',  # tinyllama:1.1b\n    r'^wizardcoder',  # wizardcoder:7b, wizardcoder:13b, etc.\n    r'^gpt-oss',  # OpenAI's open-weight gpt-oss models\n    # Hugging Face model ID patterns\n    r'^meta-llama/[Ll]lama',\n    r'^google/gemma',\n    r'^mistralai/[Mm]istral',\n    r'^mistralai/[Mm]ixtral',\n    r'^microsoft/phi',\n    
r'^Qwen/',\n    r'^deepseek-ai/',\n    r'^bigcode/starcoder',\n    r'^codellama/',\n    r'^TinyLlama/',\n    r'^WizardLM/',\n)\nOLLAMA_PRIORITY = 10\n"
  },
  {
    "path": "langextract/providers/router.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Runtime registry that maps model-ID patterns to provider classes.\n\nThis module provides a lazy registration system for LLM providers, allowing\nproviders to be registered without importing their dependencies until needed.\n\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nimport dataclasses\nimport functools\nimport importlib\nimport re\nimport typing\n\nfrom absl import logging\n\nfrom langextract.core import base_model\nfrom langextract.core import exceptions\n\nTLanguageModel = typing.TypeVar(\n    \"TLanguageModel\", bound=base_model.BaseLanguageModel\n)\n\n\n@dataclasses.dataclass(frozen=True, slots=True)\nclass _Entry:\n  \"\"\"Registry entry for a provider.\"\"\"\n\n  patterns: tuple[re.Pattern[str], ...]\n  loader: typing.Callable[[], type[base_model.BaseLanguageModel]]\n  priority: int\n\n\n_entries: list[_Entry] = []\n_entry_keys: set[tuple[str, tuple[str, ...], int]] = (\n    set()\n)  # (provider_id, patterns, priority)\n\n\ndef _add_entry(\n    *,\n    provider_id: str,\n    patterns: tuple[re.Pattern[str], ...],\n    loader: typing.Callable[[], type[base_model.BaseLanguageModel]],\n    priority: int,\n) -> None:\n  \"\"\"Add an entry to the registry with deduplication.\"\"\"\n  key = (provider_id, tuple(p.pattern for p in patterns), priority)\n  if key in _entry_keys:\n    logging.debug(\n        \"Skipping 
duplicate registration for %s with patterns %s at\"\n        \" priority %d\",\n        provider_id,\n        [p.pattern for p in patterns],\n        priority,\n    )\n    return\n  _entry_keys.add(key)\n  _entries.append(_Entry(patterns=patterns, loader=loader, priority=priority))\n  logging.debug(\n      \"Registered provider %s with patterns %s at priority %d\",\n      provider_id,\n      [p.pattern for p in patterns],\n      priority,\n  )\n\n\ndef register_lazy(\n    *patterns: str | re.Pattern[str], target: str, priority: int = 0\n) -> None:\n  \"\"\"Register a provider lazily using string import path.\n\n  Args:\n    *patterns: One or more regex patterns to match model IDs.\n    target: Import path in format \"module.path:ClassName\".\n    priority: Priority for resolution (higher wins on conflicts).\n  \"\"\"\n  compiled = tuple(re.compile(p) if isinstance(p, str) else p for p in patterns)\n\n  def _loader() -> type[base_model.BaseLanguageModel]:\n    module_path, class_name = target.rsplit(\":\", 1)\n    module = importlib.import_module(module_path)\n    return getattr(module, class_name)\n\n  _add_entry(\n      provider_id=target,\n      patterns=compiled,\n      loader=_loader,\n      priority=priority,\n  )\n\n\ndef register(\n    *patterns: str | re.Pattern[str], priority: int = 0\n) -> typing.Callable[[type[TLanguageModel]], type[TLanguageModel]]:\n  \"\"\"Decorator to register a provider class directly.\n\n  Args:\n    *patterns: One or more regex patterns to match model IDs.\n    priority: Priority for resolution (higher wins on conflicts).\n\n  Returns:\n    Decorator function that registers the class.\n  \"\"\"\n  compiled = tuple(re.compile(p) if isinstance(p, str) else p for p in patterns)\n\n  def _decorator(cls: type[TLanguageModel]) -> type[TLanguageModel]:\n    def _loader() -> type[base_model.BaseLanguageModel]:\n      return cls\n\n    provider_id = f\"{cls.__module__}:{cls.__name__}\"\n    _add_entry(\n        provider_id=provider_id,\n   
     patterns=compiled,\n        loader=_loader,\n        priority=priority,\n    )\n    return cls\n\n  return _decorator\n\n\n@functools.lru_cache(maxsize=128)\ndef resolve(model_id: str) -> type[base_model.BaseLanguageModel]:\n  \"\"\"Resolve a model ID to a provider class.\n\n  Args:\n    model_id: The model identifier to resolve.\n\n  Returns:\n    The provider class that handles this model ID.\n\n  Raises:\n    InferenceConfigError: If no provider is registered for the model ID.\n  \"\"\"\n  # Providers should be loaded by the caller (e.g., factory.create_model)\n  # Router doesn't load providers to avoid circular dependencies\n\n  sorted_entries = sorted(_entries, key=lambda e: e.priority, reverse=True)\n\n  for entry in sorted_entries:\n    if any(pattern.search(model_id) for pattern in entry.patterns):\n      return entry.loader()\n\n  available_patterns = [str(p.pattern) for e in _entries for p in e.patterns]\n  raise exceptions.InferenceConfigError(\n      f\"No provider registered for model_id={model_id!r}. 
\"\n      f\"Available patterns: {available_patterns}\\n\"\n      \"Tip: You can explicitly specify a provider using 'config' parameter \"\n      \"with factory.ModelConfig and a provider class.\"\n  )\n\n\n@functools.lru_cache(maxsize=128)\ndef resolve_provider(provider_name: str) -> type[base_model.BaseLanguageModel]:\n  \"\"\"Resolve a provider name to a provider class.\n\n  This allows explicit provider selection by name or class name.\n\n  Args:\n    provider_name: The provider name (e.g., \"gemini\", \"openai\") or\n      class name (e.g., \"GeminiLanguageModel\").\n\n  Returns:\n    The provider class.\n\n  Raises:\n    InferenceConfigError: If no provider matches the name.\n  \"\"\"\n  # Providers should be loaded by the caller (e.g., factory.create_model)\n  # Router doesn't load providers to avoid circular dependencies\n\n  for entry in _entries:\n    for pattern in entry.patterns:\n      if pattern.pattern == f\"^{re.escape(provider_name)}$\":\n        return entry.loader()\n\n  for entry in _entries:\n    try:\n      provider_class = entry.loader()\n      class_name = provider_class.__name__\n      if provider_name.lower() in class_name.lower():\n        return provider_class\n    except (ImportError, AttributeError):\n      continue\n\n  try:\n    pattern = re.compile(f\"^{provider_name}$\", re.IGNORECASE)\n    for entry in _entries:\n      for entry_pattern in entry.patterns:\n        if pattern.pattern == entry_pattern.pattern:\n          return entry.loader()\n  except re.error:\n    pass\n\n  raise exceptions.InferenceConfigError(\n      f\"No provider found matching: {provider_name!r}. \"\n      \"Available providers can be listed with list_providers()\"\n  )\n\n\ndef clear() -> None:\n  \"\"\"Clear all registered providers. 
Mainly for testing.\"\"\"\n  global _entries  # pylint: disable=global-statement\n  _entries = []\n  _entry_keys.clear()  # Also clear dedup keys to allow re-registration\n  resolve.cache_clear()\n  resolve_provider.cache_clear()\n\n\ndef list_providers() -> list[tuple[tuple[str, ...], int]]:\n  \"\"\"List all registered providers with their patterns and priorities.\n\n  Returns:\n    List of (patterns, priority) tuples for debugging.\n  \"\"\"\n  return [\n      (tuple(p.pattern for p in entry.patterns), entry.priority)\n      for entry in _entries\n  ]\n\n\ndef list_entries() -> list[tuple[list[str], int]]:\n  \"\"\"List all registered patterns and priorities. Mainly for debugging.\n\n  Returns:\n    List of (patterns, priority) tuples.\n  \"\"\"\n  return [([p.pattern for p in e.patterns], e.priority) for e in _entries]\n"
  },
  {
    "path": "langextract/providers/schemas/__init__.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Provider-specific schema implementations.\"\"\"\nfrom __future__ import annotations\n\nfrom langextract.providers.schemas import gemini\n\nGeminiSchema = gemini.GeminiSchema  # Backward compat\n\n__all__ = [\"GeminiSchema\"]\n"
  },
  {
    "path": "langextract/providers/schemas/gemini.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Gemini provider schema implementation.\"\"\"\n# pylint: disable=duplicate-code\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nimport dataclasses\nfrom typing import Any\nimport warnings\n\nfrom langextract.core import data\nfrom langextract.core import format_handler as fh\nfrom langextract.core import schema\n\n\n@dataclasses.dataclass\nclass GeminiSchema(schema.BaseSchema):\n  \"\"\"Schema implementation for Gemini structured output.\n\n  Converts ExampleData objects into an OpenAPI/JSON-schema definition\n  that Gemini can interpret via 'response_schema'.\n  \"\"\"\n\n  _schema_dict: dict[str, Any]\n\n  @property\n  def schema_dict(self) -> dict[str, Any]:\n    \"\"\"Returns the schema dictionary.\"\"\"\n    return self._schema_dict\n\n  @schema_dict.setter\n  def schema_dict(self, schema_dict: dict[str, Any]) -> None:\n    \"\"\"Sets the schema dictionary.\"\"\"\n    self._schema_dict = schema_dict\n\n  def to_provider_config(self) -> dict[str, Any]:\n    \"\"\"Convert schema to Gemini-specific configuration.\n\n    Returns:\n      Dictionary with response_schema and response_mime_type for Gemini API.\n    \"\"\"\n    return {\n        \"response_schema\": self._schema_dict,\n        \"response_mime_type\": \"application/json\",\n    }\n\n  @property\n  def requires_raw_output(self) -> bool:\n    \"\"\"Gemini outputs raw JSON 
via response_mime_type.\"\"\"\n    return True\n\n  def validate_format(self, format_handler: fh.FormatHandler) -> None:\n    \"\"\"Validate Gemini's format requirements.\n\n    Gemini requires:\n    - No fence markers (outputs raw JSON via response_mime_type)\n    - Wrapper with EXTRACTIONS_KEY (built into response_schema)\n    \"\"\"\n    # Check for fence usage with raw JSON output\n    if format_handler.use_fences:\n      warnings.warn(\n          \"Gemini outputs native JSON via\"\n          \" response_mime_type='application/json'. Using fence_output=True may\"\n          \" cause parsing issues. Set fence_output=False.\",\n          UserWarning,\n          stacklevel=3,\n      )\n\n    # Verify wrapper is enabled with correct key\n    if (\n        not format_handler.use_wrapper\n        or format_handler.wrapper_key != data.EXTRACTIONS_KEY\n    ):\n      warnings.warn(\n          \"Gemini's response_schema expects\"\n          f\" wrapper_key='{data.EXTRACTIONS_KEY}'. Current settings:\"\n          f\" use_wrapper={format_handler.use_wrapper},\"\n          f\" wrapper_key='{format_handler.wrapper_key}'\",\n          UserWarning,\n          stacklevel=3,\n      )\n\n  @classmethod\n  def from_examples(\n      cls,\n      examples_data: Sequence[data.ExampleData],\n      attribute_suffix: str = data.ATTRIBUTE_SUFFIX,\n  ) -> GeminiSchema:\n    \"\"\"Creates a GeminiSchema from example extractions.\n\n    Builds a JSON-based schema with a top-level \"extractions\" array. 
Each\n    element in that array is an object containing the extraction class name\n    and an accompanying \"<class>_attributes\" object for its attributes.\n\n    Args:\n      examples_data: A sequence of ExampleData objects containing extraction\n        classes and attributes.\n      attribute_suffix: String appended to each class name to form the\n        attributes field name (defaults to \"_attributes\").\n\n    Returns:\n      A GeminiSchema whose internal dictionary represents the JSON constraint.\n    \"\"\"\n    # Track attribute types for each category\n    extraction_categories: dict[str, dict[str, set[type]]] = {}\n    for example in examples_data:\n      for extraction in example.extractions:\n        category = extraction.extraction_class\n        if category not in extraction_categories:\n          extraction_categories[category] = {}\n\n        if extraction.attributes:\n          for attr_name, attr_value in extraction.attributes.items():\n            if attr_name not in extraction_categories[category]:\n              extraction_categories[category][attr_name] = set()\n            extraction_categories[category][attr_name].add(type(attr_value))\n\n    extraction_properties: dict[str, dict[str, Any]] = {}\n\n    for category, attrs in extraction_categories.items():\n      extraction_properties[category] = {\"type\": \"string\"}\n\n      attributes_field = f\"{category}{attribute_suffix}\"\n      attr_properties = {}\n\n      # Default property for categories without attributes\n      if not attrs:\n        attr_properties[\"_unused\"] = {\"type\": \"string\"}\n      else:\n        for attr_name, attr_types in attrs.items():\n          # List attributes become arrays\n          if list in attr_types:\n            attr_properties[attr_name] = {\n                \"type\": \"array\",\n                \"items\": {\"type\": \"string\"},  # type: ignore[dict-item]\n            }\n          else:\n            attr_properties[attr_name] = {\"type\": 
\"string\"}\n\n      extraction_properties[attributes_field] = {\n          \"type\": \"object\",\n          \"properties\": attr_properties,\n          \"nullable\": True,\n      }\n\n    extraction_schema = {\n        \"type\": \"object\",\n        \"properties\": extraction_properties,\n    }\n\n    schema_dict = {\n        \"type\": \"object\",\n        \"properties\": {\n            data.EXTRACTIONS_KEY: {\"type\": \"array\", \"items\": extraction_schema}\n        },\n        \"required\": [data.EXTRACTIONS_KEY],\n    }\n\n    return cls(_schema_dict=schema_dict)\n"
  },
  {
    "path": "langextract/py.typed",
    "content": ""
  },
  {
    "path": "langextract/registry.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.registry imports.\n\nThis module redirects to langextract.plugins for backward compatibility.\nWill be removed in v2.0.0.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\n\nfrom langextract import plugins\n\n\ndef __getattr__(name: str):\n  \"\"\"Redirect to plugins module with deprecation warning.\"\"\"\n  warnings.warn(\n      \"`langextract.registry` is deprecated and will be removed in v2.0.0; \"\n      \"use `langextract.plugins` instead.\",\n      FutureWarning,\n      stacklevel=2,\n  )\n  return getattr(plugins, name)\n"
  },
  {
    "path": "langextract/resolver.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Library for resolving LLM output.\n\nIn the context of this module, a \"resolver\" is a component designed to parse and\ntransform the textual output of an LLM into structured data.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport abc\nimport collections\nfrom collections.abc import Iterator, Mapping, Sequence\nimport difflib\nimport functools\nimport itertools\nimport operator\nfrom typing import Final\n\nfrom absl import logging\n\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import format_handler as fh\nfrom langextract.core import schema\nfrom langextract.core import tokenizer as tokenizer_lib\n\n_FUZZY_ALIGNMENT_MIN_THRESHOLD = 0.75\n\n# Default suffix for extraction index keys (e.g., \"entity_index\")\nDEFAULT_INDEX_SUFFIX = \"_index\"  # Suffix for index fields in extraction sorting\n\nALIGNMENT_PARAM_KEYS: Final[frozenset[str]] = frozenset({\n    \"enable_fuzzy_alignment\",\n    \"fuzzy_alignment_threshold\",\n    \"accept_match_lesser\",\n    \"suppress_parse_errors\",\n})\n\n\nclass AbstractResolver(abc.ABC):\n  \"\"\"Resolves LLM text outputs into structured data.\"\"\"\n\n  # TODO: Review value and requirements for abstract class.\n  def __init__(\n      self,\n      fence_output: bool = True,\n      constraint: schema.Constraint = schema.Constraint(),\n      format_type: data.FormatType = 
data.FormatType.JSON,\n  ):\n    \"\"\"Initializes the BaseResolver.\n\n    Delimiters are used for parsing text blocks, and are used primarily for\n    models that do not have constrained-decoding support.\n\n    Args:\n      fence_output: Whether to expect/generate fenced output (```json or\n        ```yaml). When True, the model is prompted to generate fenced output and\n        the resolver expects it. When False, raw JSON/YAML is expected. If your\n        model utilizes schema constraints, this can generally be set to False\n        unless the constraint also accounts for code fence delimiters.\n      constraint: Applies constraint when decoding the output. Defaults to no\n        constraint.\n      format_type: The format type for the output (JSON or YAML).\n    \"\"\"\n    self._fence_output = fence_output\n    self._constraint = constraint\n    self._format_type = format_type\n\n  @property\n  def fence_output(self) -> bool:\n    \"\"\"Returns whether fenced output is expected.\"\"\"\n    return self._fence_output\n\n  @fence_output.setter\n  def fence_output(self, fence_output: bool) -> None:\n    \"\"\"Sets whether fenced output is expected.\n\n    Args:\n      fence_output: Whether to expect fenced output.\n    \"\"\"\n    self._fence_output = fence_output\n\n  @property\n  def format_type(self) -> data.FormatType:\n    \"\"\"Returns the format type.\"\"\"\n    return self._format_type\n\n  @format_type.setter\n  def format_type(self, new_format_type: data.FormatType) -> None:\n    \"\"\"Sets a new format type.\"\"\"\n    self._format_type = new_format_type\n\n  @abc.abstractmethod\n  def resolve(\n      self,\n      input_text: str,\n      **kwargs,\n  ) -> Sequence[data.Extraction]:\n    \"\"\"Run resolve function on input text.\n\n    Args:\n        input_text: The input text to be processed.\n        **kwargs: Additional arguments for subclass implementations.\n\n    Returns:\n        Annotated text in the form of Extractions.\n    \"\"\"\n\n  
@abc.abstractmethod\n  def align(\n      self,\n      extractions: Sequence[data.Extraction],\n      source_text: str,\n      token_offset: int,\n      char_offset: int | None = None,\n      enable_fuzzy_alignment: bool = True,\n      fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD,\n      accept_match_lesser: bool = True,\n      **kwargs,\n  ) -> Iterator[data.Extraction]:\n    \"\"\"Aligns extractions with source text, setting token/char intervals and alignment status.\n\n    Uses exact matching first (difflib), then fuzzy alignment fallback if\n    enabled.\n\n    Alignment Status Results:\n    - MATCH_EXACT: Perfect token-level match\n    - MATCH_LESSER: Partial exact match (extraction longer than matched text)\n    - MATCH_FUZZY: Best overlap window ratio ≥ fuzzy_alignment_threshold\n    - None: No alignment found\n\n    Args:\n      extractions: Annotated extractions to align with the source text.\n      source_text: The text in which to align the extractions.\n      token_offset: The token_offset corresponding to the starting token index\n        of the chunk.\n      char_offset: The char_offset corresponding to the starting character index\n        of the chunk.\n      enable_fuzzy_alignment: Whether to use fuzzy alignment when exact matching\n        fails.\n      fuzzy_alignment_threshold: Minimum token overlap ratio for fuzzy alignment\n        (0-1).\n      accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER\n        status).\n      **kwargs: Additional keyword arguments for provider-specific alignment.\n\n    Yields:\n      Aligned extractions with updated token intervals and alignment status.\n    \"\"\"\n\n\nclass ResolverParsingError(exceptions.LangExtractError):\n  \"\"\"Error raised when content cannot be parsed as the given format.\"\"\"\n\n\nclass Resolver(AbstractResolver):\n  \"\"\"Resolver for YAML/JSON-based information extraction.\n\n  By default, extractions are returned in the 
order they appear in the model\n  output. To enable index-based sorting, set extraction_index_suffix to a\n  value like \"_index\" (the DEFAULT_INDEX_SUFFIX constant). This will sort\n  extractions by fields ending with that suffix (e.g., \"entity_index\").\n\n  Uses FormatHandler for parsing model output into extractions.\n  \"\"\"\n\n  def __init__(\n      self,\n      format_handler: fh.FormatHandler | None = None,\n      extraction_index_suffix: str | None = None,\n      **kwargs,  # Collect legacy parameters\n  ):\n    \"\"\"Constructor.\n\n    Args:\n      format_handler: The format handler that knows how to parse output.\n      extraction_index_suffix: Suffix identifying index keys that determine the\n        ordering of extractions.\n      **kwargs: Legacy parameters (fence_output, format_type, etc.) for backward\n        compatibility. These will be used to create a FormatHandler if one is not\n        provided. Support for these parameters will be removed in v2.0.0.\n    \"\"\"\n    constraint = kwargs.pop(\"constraint\", None)\n    extraction_attributes_suffix = kwargs.pop(\n        \"extraction_attributes_suffix\", None\n    )\n\n    if format_handler is None:\n      if kwargs or extraction_attributes_suffix is not None:\n        handler_kwargs = dict(kwargs)\n        if extraction_attributes_suffix is not None:\n          handler_kwargs[\"attribute_suffix\"] = extraction_attributes_suffix\n        format_handler = fh.FormatHandler.from_kwargs(**handler_kwargs)\n        for param in [\n            \"fence_output\",\n            \"format_type\",\n            \"strict_fences\",\n            \"require_extractions_key\",\n            \"attribute_suffix\",\n        ]:\n          kwargs.pop(param, None)\n      else:\n        format_handler = fh.FormatHandler()\n\n    if kwargs:\n      raise TypeError(\n          f\"got an unexpected keyword argument '{list(kwargs.keys())[0]}'\"\n      )\n\n    constraint = constraint or schema.Constraint()\n    
super().__init__(\n        fence_output=format_handler.use_fences,\n        format_type=format_handler.format_type,\n        constraint=constraint,\n    )\n    self.format_handler = format_handler\n    self.extraction_index_suffix = extraction_index_suffix\n    self._constraint = constraint\n\n  def resolve(\n      self,\n      input_text: str,\n      suppress_parse_errors: bool = False,\n      **kwargs,\n  ) -> Sequence[data.Extraction]:\n    \"\"\"Runs resolve function on text with YAML/JSON extraction data.\n\n    Args:\n        input_text: The input text to be processed.\n        suppress_parse_errors: Log errors and continue pipeline.\n        **kwargs: Additional keyword arguments.\n\n    Returns:\n        Annotated text in the form of a sequence of data.Extraction objects.\n\n    Raises:\n        ResolverParsingError: If the content within the string cannot be parsed\n        due to formatting errors, or if the parsed content is not as expected.\n    \"\"\"\n    logging.debug(\"Starting resolver process for input text.\")\n    logging.debug(\"Input Text: %s\", input_text)\n\n    try:\n      constraint = getattr(self, \"_constraint\", schema.Constraint())\n      strict = getattr(constraint, \"strict\", False)\n      extraction_data = self.format_handler.parse_output(\n          input_text, strict=strict\n      )\n      logging.debug(\"Parsed content: %s\", extraction_data)\n\n    except exceptions.FormatError as e:\n      if suppress_parse_errors:\n        logging.exception(\n            \"Failed to parse input_text: %s, error: %s\", input_text, e\n        )\n        return []\n      raise ResolverParsingError(str(e)) from e\n\n    processed_extractions = self.extract_ordered_extractions(extraction_data)\n\n    logging.debug(\"Completed the resolver process.\")\n\n    return processed_extractions\n\n  def align(\n      self,\n      extractions: Sequence[data.Extraction],\n      source_text: str,\n      token_offset: int,\n      char_offset: int | None = 
None,\n      enable_fuzzy_alignment: bool = True,\n      fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD,\n      accept_match_lesser: bool = True,\n      tokenizer_inst: tokenizer_lib.Tokenizer | None = None,\n      **kwargs,\n  ) -> Iterator[data.Extraction]:\n    \"\"\"Aligns annotated extractions with source text.\n\n    This uses WordAligner which is based on Python's difflib SequenceMatcher to\n    match tokens in the source text with tokens from the annotated extractions.\n    If\n    the extraction order is significantly different from the source text order,\n    difflib may skip some matches, leaving certain extractions unmatched.\n\n    Args:\n      extractions: Annotated extractions.\n      source_text: The text chunk in which to align the extractions.\n      token_offset: The starting token index of the chunk.\n      char_offset: The starting character index of the chunk.\n      enable_fuzzy_alignment: Whether to enable fuzzy alignment fallback.\n      fuzzy_alignment_threshold: Minimum overlap ratio required for fuzzy\n        alignment.\n      accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER\n        status).\n      tokenizer_inst: Optional tokenizer instance.\n      **kwargs: Additional parameters.\n\n    Yields:\n        Iterator on aligned extractions.\n    \"\"\"\n    logging.debug(\"Starting alignment process for provided chunk text.\")\n\n    if not extractions:\n      logging.debug(\n          \"No extractions found in the annotated text; exiting alignment\"\n          \" process.\"\n      )\n      return\n    else:\n      extractions_group = [extractions]\n\n    aligner = WordAligner()\n    aligned_yaml_extractions = aligner.align_extractions(\n        extractions_group,\n        source_text,\n        token_offset,\n        char_offset or 0,\n        enable_fuzzy_alignment=enable_fuzzy_alignment,\n        fuzzy_alignment_threshold=fuzzy_alignment_threshold,\n        
accept_match_lesser=accept_match_lesser,\n        tokenizer_impl=tokenizer_inst,\n    )\n    logging.debug(\n        \"Aligned extractions count: %d\",\n        sum(len(group) for group in aligned_yaml_extractions),\n    )\n\n    for extraction in itertools.chain(*aligned_yaml_extractions):\n      logging.debug(\"Yielding aligned extraction: %s\", extraction)\n      yield extraction\n\n    logging.debug(\"Completed alignment process for the provided source_text.\")\n\n  def string_to_extraction_data(\n      self,\n      input_string: str,\n  ) -> Sequence[Mapping[str, fh.ExtractionValueType]]:\n    \"\"\"Parses a YAML or JSON-formatted string into extraction data.\n\n    This method is kept for backward compatibility with tests.\n    It delegates to the FormatHandler for actual parsing.\n\n    Args:\n        input_string: A string containing YAML or JSON content.\n\n    Returns:\n        Sequence[Mapping[str, fh.ExtractionValueType]]: A sequence of parsed objects.\n\n    Raises:\n        ResolverParsingError: If the content within the string cannot be parsed.\n        ValueError: If the input is not a non-empty string.\n    \"\"\"\n    if not input_string or not isinstance(input_string, str):\n      logging.error(\"Input string must be a non-empty string.\")\n      raise ValueError(\"Input string must be a non-empty string.\")\n\n    try:\n      constraint = getattr(self, \"_constraint\", schema.Constraint())\n      strict = getattr(constraint, \"strict\", False)\n      return self.format_handler.parse_output(input_string, strict=strict)\n\n    except exceptions.FormatError as e:\n      raise ResolverParsingError(str(e)) from e\n\n    except Exception as e:\n      logging.exception(\"Failed to parse content.\")\n      raise ResolverParsingError(\"Failed to parse content.\") from e\n\n  def extract_ordered_extractions(\n      self,\n      extraction_data: Sequence[Mapping[str, fh.ExtractionValueType]],\n  ) -> Sequence[data.Extraction]:\n    
\"\"\"Extracts and orders extraction data based on their associated indexes.\n\n    This function processes a list of dictionaries, each containing pairs of\n    extraction class keys and their corresponding values, along with optional\n    index keys (identified by extraction_index_suffix) and attributes keys.\n    When extraction_index_suffix is set, extractions are sorted by their index\n    values in ascending order and extractions without an index key are\n    skipped; otherwise extractions keep their order of appearance.\n\n    Args:\n        extraction_data: A list of dictionaries. Each dictionary contains pairs\n          of extraction class keys and their values, along with optional index\n          and attributes keys.\n\n    Returns:\n        A list of data.Extraction objects sorted by the index attribute or by\n        order of appearance. If two extractions have the same index, their\n        group order dictates the sorting order.\n\n    Raises:\n        ValueError: If the extraction text is not a string, integer, or float,\n        if an index is not an integer, or if an attributes value is not a dict\n        or None.\n    \"\"\"\n    logging.debug(\"Starting to extract and order extractions from data.\")\n\n    if not extraction_data:\n      logging.debug(\"Received empty extraction data.\")\n\n    processed_extractions = []\n    extraction_index = 0\n    index_suffix = self.extraction_index_suffix\n    attributes_suffix = self.format_handler.attribute_suffix\n\n    for group_index, group in enumerate(extraction_data):\n      for extraction_class, extraction_value in group.items():\n        if index_suffix and extraction_class.endswith(index_suffix):\n          if not isinstance(extraction_value, int):\n            logging.error(\n                \"Index must be an integer. 
Found: %s\",\n                type(extraction_value),\n            )\n            raise ValueError(\"Index must be an integer.\")\n          continue\n\n        if attributes_suffix and extraction_class.endswith(attributes_suffix):\n          if not isinstance(extraction_value, (dict, type(None))):\n            logging.error(\n                \"Attributes must be a dict or None. Found: %s\",\n                type(extraction_value),\n            )\n            raise ValueError(\n                \"Extraction value must be a dict or None for attributes.\"\n            )\n          continue\n\n        if not isinstance(extraction_value, (str, int, float)):\n          logging.error(\n              \"Extraction text must be a string, integer, or float. Found: %s\",\n              type(extraction_value),\n          )\n          raise ValueError(\n              \"Extraction text must be a string, integer, or float.\"\n          )\n\n        if not isinstance(extraction_value, str):\n          extraction_value = str(extraction_value)\n\n        if index_suffix:\n          index_key = extraction_class + index_suffix\n          extraction_index = group.get(index_key, None)\n          if extraction_index is None:\n            logging.debug(\n                \"No index value for %s. 
Skipping extraction.\", extraction_class\n            )\n            continue\n        else:\n          extraction_index += 1\n\n        attributes = None\n        if attributes_suffix:\n          attributes_key = extraction_class + attributes_suffix\n          attributes = group.get(attributes_key, None)\n\n        processed_extractions.append(\n            data.Extraction(\n                extraction_class=extraction_class,\n                extraction_text=extraction_value,\n                extraction_index=extraction_index,\n                group_index=group_index,\n                attributes=attributes,\n            )\n        )\n\n    processed_extractions.sort(key=operator.attrgetter(\"extraction_index\"))\n    logging.debug(\"Completed extraction and ordering of extractions.\")\n    return processed_extractions\n\n\nclass WordAligner:\n  \"\"\"Aligns words between two sequences of tokens using Python's difflib.\"\"\"\n\n  def __init__(self):\n    \"\"\"Initialize the WordAligner with difflib SequenceMatcher.\"\"\"\n    self.matcher = difflib.SequenceMatcher(autojunk=False)\n    self.source_tokens: Sequence[str] | None = None\n    self.extraction_tokens: Sequence[str] | None = None\n\n  def _set_seqs(\n      self,\n      source_tokens: Sequence[str] | Iterator[str],\n      extraction_tokens: Sequence[str] | Iterator[str],\n  ):\n    \"\"\"Sets the source and extraction tokens for alignment.\n\n    Args:\n      source_tokens: A nonempty sequence or iterator of word-level tokens from\n        source text.\n      extraction_tokens: A nonempty sequence or iterator of extraction tokens in\n        order for matching to the source.\n    \"\"\"\n\n    if isinstance(source_tokens, Iterator):\n      source_tokens = list(source_tokens)\n    if isinstance(extraction_tokens, Iterator):\n      extraction_tokens = list(extraction_tokens)\n\n    if not source_tokens or not extraction_tokens:\n      raise ValueError(\"Source tokens and extraction tokens cannot be 
empty.\")\n\n    self.source_tokens = source_tokens\n    self.extraction_tokens = extraction_tokens\n    self.matcher.set_seqs(a=source_tokens, b=extraction_tokens)\n\n  def _get_matching_blocks(self) -> Sequence[tuple[int, int, int]]:\n    \"\"\"Utilizes difflib SequenceMatcher and returns matching blocks of tokens.\n\n    Returns:\n      Sequence of matching blocks between source_tokens (S) and\n      extraction_tokens\n      (E). Each block (i, j, n) conforms to: S[i:i+n] == E[j:j+n], guaranteed to\n      be monotonically increasing in j. Final entry is a dummy with value\n      (len(S), len(E), 0).\n    \"\"\"\n    if self.source_tokens is None or self.extraction_tokens is None:\n      raise ValueError(\n          \"Source tokens and extraction tokens must be set before getting\"\n          \" matching blocks.\"\n      )\n    return self.matcher.get_matching_blocks()\n\n  def _fuzzy_align_extraction(\n      self,\n      extraction: data.Extraction,\n      source_tokens: list[str],\n      tokenized_text: tokenizer_lib.TokenizedText,\n      token_offset: int,\n      char_offset: int,\n      fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD,\n      tokenizer_impl: tokenizer_lib.Tokenizer | None = None,\n  ) -> data.Extraction | None:\n    \"\"\"Fuzzy-align an extraction using difflib.SequenceMatcher on tokens.\n\n    The algorithm scans every candidate window in `source_tokens` and selects\n    the window with the highest SequenceMatcher `ratio`. It uses an efficient\n    token-count intersection as a fast pre-check to discard windows that cannot\n    meet the alignment threshold. A match is accepted when the ratio is ≥\n    `fuzzy_alignment_threshold`. 
This only runs on unmatched extractions, which\n    is usually a small subset of the total extractions.\n\n    Args:\n      extraction: The extraction to align.\n      source_tokens: The tokens from the source text.\n      tokenized_text: The tokenized source text.\n      token_offset: The token offset of the current chunk.\n      char_offset: The character offset of the current chunk.\n      fuzzy_alignment_threshold: The minimum ratio for a fuzzy match.\n      tokenizer_impl: Optional tokenizer instance.\n\n    Returns:\n      The aligned data.Extraction if successful, None otherwise.\n    \"\"\"\n\n    extraction_tokens = list(\n        _tokenize_with_lowercase(\n            extraction.extraction_text, tokenizer_inst=tokenizer_impl\n        )\n    )\n    # Work with lightly stemmed tokens so pluralisation doesn't block alignment\n    extraction_tokens_norm = [_normalize_token(t) for t in extraction_tokens]\n\n    if not extraction_tokens:\n      return None\n\n    logging.debug(\n        \"Fuzzy aligning %r (%d tokens)\",\n        extraction.extraction_text,\n        len(extraction_tokens),\n    )\n\n    best_ratio = 0.0\n    best_span: tuple[int, int] | None = None  # (start_idx, window_size)\n\n    len_e = len(extraction_tokens)\n    max_window = len(source_tokens)\n\n    extraction_counts = collections.Counter(extraction_tokens_norm)\n    min_overlap = int(len_e * fuzzy_alignment_threshold)\n\n    matcher = difflib.SequenceMatcher(autojunk=False, b=extraction_tokens_norm)\n\n    for window_size in range(len_e, max_window + 1):\n      if window_size > len(source_tokens):\n        break\n\n      # Initialize for sliding window\n      window_deque = collections.deque(source_tokens[0:window_size])\n      window_counts = collections.Counter(\n          [_normalize_token(t) for t in window_deque]\n      )\n\n      for start_idx in range(len(source_tokens) - window_size + 1):\n        # Optimization: check if enough overlapping tokens exist before expensive\n        
# sequence matching. This is an upper bound on the match count.\n        if (extraction_counts & window_counts).total() >= min_overlap:\n          window_tokens_norm = [_normalize_token(t) for t in window_deque]\n          matcher.set_seq1(window_tokens_norm)\n          matches = sum(size for _, _, size in matcher.get_matching_blocks())\n          if len_e > 0:\n            ratio = matches / len_e\n          else:\n            ratio = 0.0\n          if ratio > best_ratio:\n            best_ratio = ratio\n            best_span = (start_idx, window_size)\n\n        # Slide the window to the right\n        if start_idx + window_size < len(source_tokens):\n          # Remove the leftmost token from the count\n          old_token = window_deque.popleft()\n          old_token_norm = _normalize_token(old_token)\n          window_counts[old_token_norm] -= 1\n          if window_counts[old_token_norm] == 0:\n            del window_counts[old_token_norm]\n\n          # Add the new rightmost token to the deque and count\n          new_token = source_tokens[start_idx + window_size]\n          window_deque.append(new_token)\n          new_token_norm = _normalize_token(new_token)\n          window_counts[new_token_norm] += 1\n\n    if best_span and best_ratio >= fuzzy_alignment_threshold:\n      start_idx, window_size = best_span\n\n      try:\n        extraction.token_interval = tokenizer_lib.TokenInterval(\n            start_index=start_idx + token_offset,\n            end_index=start_idx + window_size + token_offset,\n        )\n\n        start_token = tokenized_text.tokens[start_idx]\n        end_token = tokenized_text.tokens[start_idx + window_size - 1]\n        extraction.char_interval = data.CharInterval(\n            start_pos=char_offset + start_token.char_interval.start_pos,\n            end_pos=char_offset + end_token.char_interval.end_pos,\n        )\n\n        extraction.alignment_status = data.AlignmentStatus.MATCH_FUZZY\n        return extraction\n      except 
IndexError:\n        logging.exception(\n            \"Index error while setting intervals during fuzzy alignment.\"\n        )\n        return None\n\n    return None\n\n  def align_extractions(\n      self,\n      extraction_groups: Sequence[Sequence[data.Extraction]],\n      source_text: str,\n      token_offset: int = 0,\n      char_offset: int = 0,\n      delim: str = \"\\u241F\",  # Unicode Symbol for unit separator\n      enable_fuzzy_alignment: bool = True,\n      fuzzy_alignment_threshold: float = _FUZZY_ALIGNMENT_MIN_THRESHOLD,\n      accept_match_lesser: bool = True,\n      tokenizer_impl: tokenizer_lib.Tokenizer | None = None,\n  ) -> Sequence[Sequence[data.Extraction]]:\n    \"\"\"Aligns extractions with their positions in the source text.\n\n    This method takes a sequence of extractions and the source text, aligning\n    each extraction with its corresponding position in the source text. It\n    returns a sequence of extractions along with token intervals indicating the\n    start and end positions of each extraction in the source text. 
If an\n    extraction cannot be aligned, its token interval is set to None.\n\n    Args:\n      extraction_groups: A sequence of sequences, where each inner sequence\n        contains Extraction objects.\n      source_text: The source text against which extractions are to be aligned.\n      token_offset: The offset to add to the start and end indices of the token\n        intervals.\n      char_offset: The offset to add to the start and end positions of the\n        character intervals.\n      delim: Token inserted between concatenated extractions during matching.\n        Must tokenize to a single token and must not appear in any extraction\n        text.\n      enable_fuzzy_alignment: Whether to use fuzzy alignment when exact matching\n        fails.\n      fuzzy_alignment_threshold: Minimum token overlap ratio for fuzzy alignment\n        (0-1).\n      accept_match_lesser: Whether to accept partial exact matches (MATCH_LESSER\n        status).\n      tokenizer_impl: Optional tokenizer instance.\n\n    Returns:\n      A sequence of extractions aligned with the source text, including token\n      intervals.\n    \"\"\"\n    logging.debug(\n        \"WordAligner: Starting alignment of extractions with the source text.\"\n        \" Extraction groups to align: %s\",\n        extraction_groups,\n    )\n    if not extraction_groups:\n      logging.info(\"No extraction groups provided; returning empty list.\")\n      return []\n\n    source_tokens = list(\n        _tokenize_with_lowercase(source_text, tokenizer_inst=tokenizer_impl)\n    )\n\n    delim_len = len(\n        list(_tokenize_with_lowercase(delim, tokenizer_inst=tokenizer_impl))\n    )\n    if delim_len != 1:\n      raise ValueError(f\"Delimiter {delim!r} must be a single token.\")\n\n    logging.debug(\"Using delimiter %r for extraction alignment\", delim)\n\n    extraction_tokens = list(\n        _tokenize_with_lowercase(\n            f\" {delim} \".join(\n                extraction.extraction_text\n                for extraction in 
itertools.chain(*extraction_groups)\n            ),\n            
tokenizer_inst=tokenizer_impl,\n        )\n    )\n\n    self._set_seqs(source_tokens, extraction_tokens)\n\n    index_to_extraction_group = {}\n    extraction_index = 0\n    for group_index, group in enumerate(extraction_groups):\n      logging.debug(\n          \"Processing extraction group %d with %d extractions.\",\n          group_index,\n          len(group),\n      )\n      for extraction in group:\n        # Validate delimiter doesn't appear in extraction text\n        if delim in extraction.extraction_text:\n          raise ValueError(\n              f\"Delimiter {delim!r} appears inside extraction text\"\n              f\" {extraction.extraction_text!r}. This would corrupt alignment\"\n              \" mapping.\"\n          )\n\n        index_to_extraction_group[extraction_index] = (extraction, group_index)\n        extraction_text_tokens = list(\n            _tokenize_with_lowercase(\n                extraction.extraction_text, tokenizer_inst=tokenizer_impl\n            )\n        )\n        extraction_index += len(extraction_text_tokens) + delim_len\n\n    aligned_extraction_groups: list[list[data.Extraction]] = [\n        [] for _ in extraction_groups\n    ]\n    tokenized_text = (\n        tokenizer_impl.tokenize(source_text)\n        if tokenizer_impl\n        else tokenizer_lib.tokenize(source_text)\n    )\n\n    # Track which extractions were aligned in the exact matching phase\n    aligned_extractions = []\n    exact_matches = 0\n    lesser_matches = 0\n\n    # Exact matching phase\n    for i, j, n in self._get_matching_blocks()[:-1]:\n      extraction, _ = index_to_extraction_group.get(j, (None, None))\n      if extraction is None:\n        logging.debug(\n            \"No clean start index found for extraction index=%d iterating\"\n            \" Difflib matching_blocks\",\n            j,\n        )\n        continue\n\n      extraction.token_interval = tokenizer_lib.TokenInterval(\n          start_index=i + token_offset,\n          end_index=i + 
n + token_offset,\n      )\n\n      try:\n        start_token = tokenized_text.tokens[i]\n        end_token = tokenized_text.tokens[i + n - 1]\n        extraction.char_interval = data.CharInterval(\n            start_pos=char_offset + start_token.char_interval.start_pos,\n            end_pos=char_offset + end_token.char_interval.end_pos,\n        )\n      except IndexError as e:\n        raise IndexError(\n            \"Failed to align extraction with source text. Extraction token\"\n            f\" interval {extraction.token_interval} does not match source text\"\n            f\" tokens {tokenized_text.tokens}.\"\n        ) from e\n\n      extraction_text_len = len(\n          list(\n              _tokenize_with_lowercase(\n                  extraction.extraction_text, tokenizer_inst=tokenizer_impl\n              )\n          )\n      )\n      if extraction_text_len < n:\n        raise ValueError(\n            \"Delimiter prevents blocks greater than extraction length: \"\n            f\"extraction_text_len={extraction_text_len}, block_size={n}\"\n        )\n      if extraction_text_len == n:\n        extraction.alignment_status = data.AlignmentStatus.MATCH_EXACT\n        exact_matches += 1\n        aligned_extractions.append(extraction)\n      else:\n        # Partial match (extraction longer than matched text)\n        if accept_match_lesser:\n          extraction.alignment_status = data.AlignmentStatus.MATCH_LESSER\n          lesser_matches += 1\n          aligned_extractions.append(extraction)\n        else:\n          # Reset intervals when not accepting lesser matches\n          extraction.token_interval = None\n          extraction.char_interval = None\n          extraction.alignment_status = None\n\n    # Collect unaligned extractions\n    unaligned_extractions = []\n    for extraction, _ in index_to_extraction_group.values():\n      if extraction not in aligned_extractions:\n        unaligned_extractions.append(extraction)\n\n    # Apply fuzzy alignment 
to remaining extractions\n    if enable_fuzzy_alignment and unaligned_extractions:\n      logging.debug(\n          \"Starting fuzzy alignment for %d unaligned extractions\",\n          len(unaligned_extractions),\n      )\n      for extraction in unaligned_extractions:\n        aligned_extraction = self._fuzzy_align_extraction(\n            extraction,\n            source_tokens,\n            tokenized_text,\n            token_offset,\n            char_offset,\n            fuzzy_alignment_threshold,\n            tokenizer_impl=tokenizer_impl,\n        )\n        if aligned_extraction:\n          aligned_extractions.append(aligned_extraction)\n          logging.debug(\n              \"Fuzzy alignment successful for extraction: %s\",\n              extraction.extraction_text,\n          )\n\n    for extraction, group_index in index_to_extraction_group.values():\n      aligned_extraction_groups[group_index].append(extraction)\n\n    logging.debug(\n        \"Final aligned extraction groups: %s\", aligned_extraction_groups\n    )\n    return aligned_extraction_groups\n\n\ndef _tokenize_with_lowercase(\n    text: str,\n    tokenizer_inst: tokenizer_lib.Tokenizer | None = None,\n) -> Iterator[str]:\n  \"\"\"Extract and lowercase tokens from the input text into words.\n\n  This function utilizes the tokenizer module to tokenize text and yields\n  lowercased words.\n\n  Args:\n    text (str): The text to be tokenized.\n    tokenizer_inst: Optional tokenizer instance.\n\n  Yields:\n    Iterator[str]: An iterator over tokenized words.\n  \"\"\"\n  if tokenizer_inst is not None:\n    tokenized_pb2 = tokenizer_inst.tokenize(text)\n  else:\n    tokenized_pb2 = tokenizer_lib.tokenize(text)\n  original_text = tokenized_pb2.text\n  for token in tokenized_pb2.tokens:\n    start = token.char_interval.start_pos\n    end = token.char_interval.end_pos\n    token_str = original_text[start:end]\n    token_str = token_str.lower()\n    yield 
token_str\n\n\n@functools.lru_cache(maxsize=10000)\ndef _normalize_token(token: str) -> str:\n  \"\"\"Lowercases and applies light pluralisation stemming.\"\"\"\n  token = token.lower()\n  if len(token) > 3 and token.endswith(\"s\") and not token.endswith(\"ss\"):\n    token = token[:-1]\n  return token\n"
  },
  {
    "path": "langextract/schema.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Schema compatibility layer.\n\nThis module provides backward compatibility for the schema module.\nNew code should import from langextract.core.schema instead.\n\"\"\"\n\nfrom __future__ import annotations\n\n# Re-export core schema items with deprecation warnings\nimport warnings\n\nfrom langextract._compat import schema\n\n\ndef __getattr__(name: str):\n  \"\"\"Handle imports with appropriate warnings.\"\"\"\n  core_items = {\n      \"BaseSchema\": (\"langextract.core.schema\", \"BaseSchema\"),\n      \"Constraint\": (\"langextract.core.schema\", \"Constraint\"),\n      \"ConstraintType\": (\"langextract.core.schema\", \"ConstraintType\"),\n      \"EXTRACTIONS_KEY\": (\"langextract.core.data\", \"EXTRACTIONS_KEY\"),\n      \"ATTRIBUTE_SUFFIX\": (\"langextract.core.data\", \"ATTRIBUTE_SUFFIX\"),\n      \"FormatModeSchema\": (\"langextract.core.schema\", \"FormatModeSchema\"),\n  }\n\n  if name in core_items:\n    mod, attr = core_items[name]\n    warnings.warn(\n        f\"`langextract.schema.{name}` has moved to `{mod}.{attr}`. Please\"\n        \" update your imports. 
This compatibility layer will be removed in\"\n        \" v2.0.0.\",\n        FutureWarning,\n        stacklevel=2,\n    )\n    module = __import__(mod, fromlist=[attr])\n    return getattr(module, attr)\n  elif name == \"GeminiSchema\":\n    return schema.__getattr__(name)\n\n  raise AttributeError(f\"module 'langextract.schema' has no attribute '{name}'\")\n"
  },
  {
    "path": "langextract/tokenizer.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Compatibility shim for langextract.tokenizer imports.\n\nThis module provides backward compatibility for code that imports from\nlangextract.tokenizer. All functionality has moved to langextract.core.tokenizer.\n\"\"\"\n\nfrom __future__ import annotations\n\n# Re-export everything from core.tokenizer for backward compatibility\n# pylint: disable=unused-wildcard-import\nfrom langextract.core.tokenizer import *\n"
  },
  {
    "path": "langextract/visualization.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utility functions for visualizing LangExtract extractions in notebooks.\n\nExample\n-------\n>>> import langextract as lx\n>>> doc = lx.extract(...)\n>>> lx.visualize(doc)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport dataclasses\nimport enum\nimport html\nimport itertools\nimport json\nimport pathlib\nimport textwrap\n\nfrom langextract import io\nfrom langextract.core import data\n\n# Fallback if IPython is not present\ntry:\n  from IPython import get_ipython  # type: ignore[import-not-found]\n  from IPython.display import HTML  # type: ignore[import-not-found]\nexcept ImportError:\n\n  def get_ipython():  # type: ignore[no-redef]\n    return None\n\n  HTML = None  # pytype: disable=annotation-type-mismatch\n\n\ndef _is_jupyter() -> bool:\n  \"\"\"Check if we're in a Jupyter/IPython environment that can display HTML.\"\"\"\n  try:\n    if get_ipython is None:\n      return False\n    ip = get_ipython()\n    if ip is None:\n      return False\n    # Simple check: if we're in IPython and NOT in a plain terminal\n    return ip.__class__.__name__ != 'TerminalInteractiveShell'\n  except Exception:\n    return False\n\n\n_PALETTE: list[str] = [\n    '#D2E3FC',  # Light Blue (Primary Container)\n    '#C8E6C9',  # Light Green (Tertiary Container)\n    '#FEF0C3',  # Light Yellow (Primary Color)\n    '#F9DEDC',  # Light Red (Error Container)\n    '#FFDDBE',  
# Light Orange (Tertiary Container)\n    '#EADDFF',  # Light Purple (Secondary/Tertiary Container)\n    '#C4E9E4',  # Light Teal (Teal Container)\n    '#FCE4EC',  # Light Pink (Pink Container)\n    '#E8EAED',  # Very Light Grey (Neutral Highlight)\n    '#DDE8E8',  # Pale Cyan (Cyan Container)\n]\n\n_VISUALIZATION_CSS = textwrap.dedent(\"\"\"\\\n    <style>\n    .lx-highlight { position: relative; border-radius:3px; padding:1px 2px;}\n    .lx-highlight .lx-tooltip {\n      visibility: hidden;\n      opacity: 0;\n      transition: opacity 0.2s ease-in-out;\n      background: #333;\n      color: #fff;\n      text-align: left;\n      border-radius: 4px;\n      padding: 6px 8px;\n      position: absolute;\n      z-index: 1000;\n      bottom: 125%;\n      left: 50%;\n      transform: translateX(-50%);\n      font-size: 12px;\n      max-width: 240px;\n      white-space: normal;\n      box-shadow: 0 2px 6px rgba(0,0,0,0.3);\n    }\n    .lx-highlight:hover .lx-tooltip { visibility: visible; opacity:1; }\n    .lx-animated-wrapper { max-width: 100%; font-family: Arial, sans-serif; }\n    .lx-controls {\n      background: #fafafa; border: 1px solid #90caf9; border-radius: 8px;\n      padding: 12px; margin-bottom: 16px;\n    }\n    .lx-button-row {\n      display: flex; justify-content: center; gap: 8px; margin-bottom: 12px;\n    }\n    .lx-control-btn {\n      background: #4285f4; color: white; border: none; border-radius: 4px;\n      padding: 8px 16px; cursor: pointer; font-size: 13px; font-weight: 500;\n      transition: background-color 0.2s;\n    }\n    .lx-control-btn:hover { background: #3367d6; }\n    .lx-progress-container {\n      margin-bottom: 8px;\n    }\n    .lx-progress-slider {\n      width: 100%; margin: 0; appearance: none; height: 6px;\n      background: #ddd; border-radius: 3px; outline: none;\n    }\n    .lx-progress-slider::-webkit-slider-thumb {\n      appearance: none; width: 18px; height: 18px; background: #4285f4;\n      border-radius: 50%; cursor: 
pointer;\n    }\n    .lx-progress-slider::-moz-range-thumb {\n      width: 18px; height: 18px; background: #4285f4; border-radius: 50%;\n      cursor: pointer; border: none;\n    }\n    .lx-status-text {\n      text-align: center; font-size: 12px; color: #666; margin-top: 4px;\n    }\n    .lx-text-window {\n      font-family: monospace; white-space: pre-wrap; border: 1px solid #90caf9;\n      padding: 12px; max-height: 260px; overflow-y: auto; margin-bottom: 12px;\n      line-height: 1.6;\n    }\n    .lx-attributes-panel {\n      background: #fafafa; border: 1px solid #90caf9; border-radius: 6px;\n      padding: 8px 10px; margin-top: 8px; font-size: 13px;\n    }\n    .lx-current-highlight {\n      border-bottom: 4px solid #ff4444;\n      font-weight: bold;\n      animation: lx-pulse 1s ease-in-out;\n    }\n    @keyframes lx-pulse {\n      0% { text-decoration-color: #ff4444; }\n      50% { text-decoration-color: #ff0000; }\n      100% { text-decoration-color: #ff4444; }\n    }\n    .lx-legend {\n      font-size: 12px; margin-bottom: 8px;\n      padding-bottom: 8px; border-bottom: 1px solid #e0e0e0;\n    }\n    .lx-label {\n      display: inline-block;\n      padding: 2px 4px;\n      border-radius: 3px;\n      margin-right: 4px;\n      color: #000;\n    }\n    .lx-attr-key {\n      font-weight: 600;\n      color: #1565c0;\n      letter-spacing: 0.3px;\n    }\n    .lx-attr-value {\n      font-weight: 400;\n      opacity: 0.85;\n      letter-spacing: 0.2px;\n    }\n\n    /* Add optimizations with larger fonts and better readability for GIFs */\n    .lx-gif-optimized .lx-text-window { font-size: 16px; line-height: 1.8; }\n    .lx-gif-optimized .lx-attributes-panel { font-size: 15px; }\n    .lx-gif-optimized .lx-current-highlight { text-decoration-thickness: 4px; }\n    </style>\"\"\")\n\n\ndef _assign_colors(extractions: list[data.Extraction]) -> dict[str, str]:\n  \"\"\"Assigns a background colour to each extraction class.\n\n  Args:\n    extractions: list of 
extractions.\n\n  Returns:\n    Mapping from extraction_class to a hex colour string.\n  \"\"\"\n  classes = {e.extraction_class for e in extractions if e.char_interval}\n  color_map: dict[str, str] = {}\n  palette_cycle = itertools.cycle(_PALETTE)\n  for cls in sorted(classes):\n    color_map[cls] = next(palette_cycle)\n  return color_map\n\n\ndef _filter_valid_extractions(\n    extractions: list[data.Extraction],\n) -> list[data.Extraction]:\n  \"\"\"Filters extractions to only include those with valid char intervals.\"\"\"\n  return [\n      e\n      for e in extractions\n      if (\n          e.char_interval\n          and e.char_interval.start_pos is not None\n          and e.char_interval.end_pos is not None\n      )\n  ]\n\n\nclass TagType(enum.Enum):\n  \"\"\"Enum for span boundary tag types.\"\"\"\n\n  START = 'start'\n  END = 'end'\n\n\n@dataclasses.dataclass(frozen=True)\nclass SpanPoint:\n  \"\"\"Represents a span boundary point for HTML generation.\n\n  Attributes:\n    position: Character position in the text.\n    tag_type: Type of span boundary (START or END).\n    span_idx: Index of the span for HTML data-idx attribute.\n    extraction: The extraction data associated with this span.\n  \"\"\"\n\n  position: int\n  tag_type: TagType\n  span_idx: int\n  extraction: data.Extraction\n\n\ndef _build_highlighted_text(\n    text: str,\n    extractions: list[data.Extraction],\n    color_map: dict[str, str],\n) -> str:\n  \"\"\"Returns text with <span> highlights inserted, supporting nesting.\n\n  Args:\n    text: Original document text.\n    extractions: List of extraction objects with char_intervals.\n    color_map: Mapping of extraction_class to colour.\n  \"\"\"\n  points = []\n  span_lengths = {}\n  for index, extraction in enumerate(extractions):\n    if (\n        not extraction.char_interval\n        or extraction.char_interval.start_pos is None\n        or extraction.char_interval.end_pos is None\n        or extraction.char_interval.start_pos\n     
   >= extraction.char_interval.end_pos\n    ):\n      continue\n\n    start_pos = extraction.char_interval.start_pos\n    end_pos = extraction.char_interval.end_pos\n    points.append(SpanPoint(start_pos, TagType.START, index, extraction))\n    points.append(SpanPoint(end_pos, TagType.END, index, extraction))\n    span_lengths[index] = end_pos - start_pos\n\n  def sort_key(point: SpanPoint):\n    \"\"\"Sorts span boundary points for proper HTML nesting.\n\n    Sorts by position first, then handles ties at the same position to ensure\n    proper HTML nesting. At the same position:\n    1. End tags come before start tags (to close before opening)\n    2. Among end tags: shorter spans close first\n    3. Among start tags: longer spans open first\n\n    Args:\n      point: SpanPoint containing position, tag_type, span_idx, and extraction.\n\n    Returns:\n      Sort key tuple ensuring proper nesting order.\n    \"\"\"\n    span_length = span_lengths.get(point.span_idx, 0)\n\n    if point.tag_type == TagType.END:\n      return (point.position, 0, span_length)\n    else:  # point.tag_type == TagType.START\n      return (point.position, 1, -span_length)\n\n  points.sort(key=sort_key)\n\n  html_parts: list[str] = []\n  cursor = 0\n  for point in points:\n    if point.position > cursor:\n      html_parts.append(html.escape(text[cursor : point.position]))\n\n    if point.tag_type == TagType.START:\n      colour = color_map.get(point.extraction.extraction_class, '#ffff8d')\n      highlight_class = ' lx-current-highlight' if point.span_idx == 0 else ''\n\n      span_html = (\n          f'<span class=\"lx-highlight{highlight_class}\"'\n          f' data-idx=\"{point.span_idx}\" style=\"background-color:{colour};\">'\n      )\n      html_parts.append(span_html)\n    else:  # point.tag_type == TagType.END\n      html_parts.append('</span>')\n\n    cursor = point.position\n\n  if cursor < len(text):\n    html_parts.append(html.escape(text[cursor:]))\n  return 
''.join(html_parts)\n\n\ndef _build_legend_html(color_map: dict[str, str]) -> str:\n  \"\"\"Builds legend HTML showing extraction classes and their colors.\"\"\"\n  if not color_map:\n    return ''\n\n  legend_items = []\n  for extraction_class, colour in color_map.items():\n    legend_items.append(\n        '<span class=\"lx-label\"'\n        f' style=\"background-color:{colour};\">{html.escape(extraction_class)}</span>'\n    )\n  return (\n      '<div class=\"lx-legend\">Highlights Legend:'\n      f' {\" \".join(legend_items)}</div>'\n  )\n\n\ndef _format_attributes(attributes: dict | None) -> str:\n  \"\"\"Formats attributes as a single-line string.\"\"\"\n  if not attributes:\n    return '{}'\n\n  valid_attrs = {\n      key: value\n      for key, value in attributes.items()\n      if value not in (None, '', 'null')\n  }\n\n  if not valid_attrs:\n    return '{}'\n\n  attrs_parts = []\n  for key, value in valid_attrs.items():\n    # Clean up array formatting for better readability\n    if isinstance(value, list):\n      value_str = ', '.join(str(v) for v in value)\n    else:\n      value_str = str(value)\n    attrs_parts.append(\n        f'<span class=\"lx-attr-key\">{html.escape(str(key))}</span>: <span'\n        f' class=\"lx-attr-value\">{html.escape(value_str)}</span>'\n    )\n  return '{' + ', '.join(attrs_parts) + '}'\n\n\ndef _prepare_extraction_data(\n    text: str,\n    extractions: list[data.Extraction],\n    color_map: dict[str, str],\n    context_chars: int = 150,\n) -> list[dict]:\n  \"\"\"Prepares JavaScript data for extractions.\"\"\"\n  extraction_data = []\n  for i, extraction in enumerate(extractions):\n    # Assertions to inform pytype about the invariants guaranteed by _filter_valid_extractions\n    assert (\n        extraction.char_interval is not None\n    ), 'char_interval must be non-None for valid extractions'\n    assert (\n        extraction.char_interval.start_pos is not None\n    ), 'start_pos must be non-None for valid extractions'\n 
   assert (\n        extraction.char_interval.end_pos is not None\n    ), 'end_pos must be non-None for valid extractions'\n\n    start_pos = extraction.char_interval.start_pos\n    end_pos = extraction.char_interval.end_pos\n\n    context_start = max(0, start_pos - context_chars)\n    context_end = min(len(text), end_pos + context_chars)\n\n    before_text = text[context_start:start_pos]\n    extraction_text = text[start_pos:end_pos]\n    after_text = text[end_pos:context_end]\n\n    colour = color_map.get(extraction.extraction_class, '#ffff8d')\n\n    # Build attributes display\n    attributes_html = (\n        '<div><strong>class:</strong>'\n        f' {html.escape(extraction.extraction_class)}</div>'\n    )\n    attributes_html += (\n        '<div><strong>attributes:</strong>'\n        f' {_format_attributes(extraction.attributes)}</div>'\n    )\n\n    extraction_data.append({\n        'index': i,\n        'class': extraction.extraction_class,\n        'text': extraction.extraction_text,\n        'color': colour,\n        'startPos': start_pos,\n        'endPos': end_pos,\n        'beforeText': html.escape(before_text),\n        'extractionText': html.escape(extraction_text),\n        'afterText': html.escape(after_text),\n        'attributesHtml': attributes_html,\n    })\n\n  return extraction_data\n\n\ndef _build_visualization_html(\n    text: str,\n    extractions: list[data.Extraction],\n    color_map: dict[str, str],\n    animation_speed: float = 1.0,\n    show_legend: bool = True,\n) -> str:\n  \"\"\"Builds the complete visualization HTML.\"\"\"\n  if not extractions:\n    return (\n        '<div class=\"lx-animated-wrapper\"><p>No extractions to'\n        ' animate.</p></div>'\n    )\n\n  # Sort extractions by position for proper HTML nesting.\n  def _extraction_sort_key(extraction):\n    \"\"\"Sort by position, then by span length descending for proper nesting.\"\"\"\n    start = extraction.char_interval.start_pos\n    end = 
extraction.char_interval.end_pos\n    span_length = end - start\n    return (start, -span_length)  # longer spans first\n\n  sorted_extractions = sorted(extractions, key=_extraction_sort_key)\n\n  highlighted_text = _build_highlighted_text(\n      text, sorted_extractions, color_map\n  )\n  extraction_data = _prepare_extraction_data(\n      text, sorted_extractions, color_map\n  )\n  legend_html = _build_legend_html(color_map) if show_legend else ''\n\n  js_data = json.dumps(extraction_data)\n\n  # Prepare pos_info_str safely for pytype for the f-string below\n  first_extraction = extractions[0]\n  assert (\n      first_extraction.char_interval\n      and first_extraction.char_interval.start_pos is not None\n      and first_extraction.char_interval.end_pos is not None\n  ), 'first extraction must have valid char_interval with start_pos and end_pos'\n  pos_info_str = f'[{first_extraction.char_interval.start_pos}-{first_extraction.char_interval.end_pos}]'\n\n  html_content = textwrap.dedent(f\"\"\"\n    <div class=\"lx-animated-wrapper\">\n      <div class=\"lx-attributes-panel\">\n        {legend_html}\n        <div id=\"attributesContainer\"></div>\n      </div>\n      <div class=\"lx-text-window\" id=\"textWindow\">\n        {highlighted_text}\n      </div>\n      <div class=\"lx-controls\">\n        <div class=\"lx-button-row\">\n          <button class=\"lx-control-btn\" onclick=\"playPause()\">▶️ Play</button>\n          <button class=\"lx-control-btn\" onclick=\"prevExtraction()\">⏮ Previous</button>\n          <button class=\"lx-control-btn\" onclick=\"nextExtraction()\">⏭ Next</button>\n        </div>\n        <div class=\"lx-progress-container\">\n          <input type=\"range\" id=\"progressSlider\" class=\"lx-progress-slider\"\n                 min=\"0\" max=\"{len(extractions)-1}\" value=\"0\"\n                 onchange=\"jumpToExtraction(this.value)\">\n        </div>\n        <div class=\"lx-status-text\">\n          Entity <span 
id=\"entityInfo\">1/{len(extractions)}</span> |\n          Pos <span id=\"posInfo\">{pos_info_str}</span>\n        </div>\n      </div>\n    </div>\n\n    <script>\n      (function() {{\n        const extractions = {js_data};\n        let currentIndex = 0;\n        let isPlaying = false;\n        let animationInterval = null;\n        let animationSpeed = {animation_speed};\n\n        function updateDisplay() {{\n          const extraction = extractions[currentIndex];\n          if (!extraction) return;\n\n          document.getElementById('attributesContainer').innerHTML = extraction.attributesHtml;\n          document.getElementById('entityInfo').textContent = (currentIndex + 1) + '/' + extractions.length;\n          document.getElementById('posInfo').textContent = '[' + extraction.startPos + '-' + extraction.endPos + ']';\n          document.getElementById('progressSlider').value = currentIndex;\n\n          const playBtn = document.querySelector('.lx-control-btn');\n          if (playBtn) playBtn.textContent = isPlaying ? 
'⏸ Pause' : '▶️ Play';\n\n          const prevHighlight = document.querySelector('.lx-text-window .lx-current-highlight');\n          if (prevHighlight) prevHighlight.classList.remove('lx-current-highlight');\n          const currentSpan = document.querySelector('.lx-text-window span[data-idx=\"' + currentIndex + '\"]');\n          if (currentSpan) {{\n            currentSpan.classList.add('lx-current-highlight');\n            currentSpan.scrollIntoView({{block: 'center', behavior: 'smooth'}});\n          }}\n        }}\n\n        function nextExtraction() {{\n          currentIndex = (currentIndex + 1) % extractions.length;\n          updateDisplay();\n        }}\n\n        function prevExtraction() {{\n          currentIndex = (currentIndex - 1 + extractions.length) % extractions.length;\n          updateDisplay();\n        }}\n\n        function jumpToExtraction(index) {{\n          currentIndex = parseInt(index);\n          updateDisplay();\n        }}\n\n        function playPause() {{\n          if (isPlaying) {{\n            clearInterval(animationInterval);\n            isPlaying = false;\n          }} else {{\n            animationInterval = setInterval(nextExtraction, animationSpeed * 1000);\n            isPlaying = true;\n          }}\n          updateDisplay();\n        }}\n\n        window.playPause = playPause;\n        window.nextExtraction = nextExtraction;\n        window.prevExtraction = prevExtraction;\n        window.jumpToExtraction = jumpToExtraction;\n\n        updateDisplay();\n      }})();\n    </script>\"\"\")\n\n  return html_content\n\n\ndef visualize(\n    data_source: data.AnnotatedDocument | str | pathlib.Path,\n    *,\n    animation_speed: float = 1.0,\n    show_legend: bool = True,\n    gif_optimized: bool = True,\n) -> HTML | str:\n  \"\"\"Visualises extraction data as animated highlighted HTML.\n\n  Args:\n    data_source: Either an AnnotatedDocument or path to a JSONL file.\n    animation_speed: Animation speed in seconds between 
extractions.\n    show_legend: If ``True``, includes a colour legend mapping extraction\n      classes to colours.\n    gif_optimized: If ``True``, applies GIF-optimized styling (larger fonts\n      and line spacing) for video capture.\n\n  Returns:\n    An :class:`IPython.display.HTML` object when running in a Jupyter/IPython\n    environment that can display HTML, otherwise the generated HTML string.\n  \"\"\"\n  # Load document if it's a file path\n  if isinstance(data_source, (str, pathlib.Path)):\n    file_path = pathlib.Path(data_source)\n    if not file_path.exists():\n      raise FileNotFoundError(f'JSONL file not found: {file_path}')\n\n    documents = list(io.load_annotated_documents_jsonl(file_path))\n    if not documents:\n      raise ValueError(f'No documents found in JSONL file: {file_path}')\n\n    annotated_doc = documents[0]  # Use first document\n  else:\n    annotated_doc = data_source\n\n  if not annotated_doc or annotated_doc.text is None:\n    raise ValueError('annotated_doc must contain text to visualise.')\n\n  if annotated_doc.extractions is None:\n    raise ValueError('annotated_doc must contain extractions to visualise.')\n\n  # Filter valid extractions - show ALL of them\n  valid_extractions = _filter_valid_extractions(annotated_doc.extractions)\n\n  if not valid_extractions:\n    empty_html = (\n        '<div class=\"lx-animated-wrapper\"><p>No valid extractions to'\n        ' animate.</p></div>'\n    )\n    full_html = _VISUALIZATION_CSS + empty_html\n    if HTML is not None and _is_jupyter():\n      return HTML(full_html)\n    return full_html\n\n  color_map = _assign_colors(valid_extractions)\n\n  visualization_html = _build_visualization_html(\n      annotated_doc.text,\n      valid_extractions,\n      color_map,\n      animation_speed,\n      show_legend,\n  )\n\n  full_html = _VISUALIZATION_CSS + visualization_html\n\n  # Apply GIF optimizations if requested\n  if 
gif_optimized:\n    full_html = full_html.replace(\n        
'class=\"lx-animated-wrapper\"',\n        'class=\"lx-animated-wrapper lx-gif-optimized\"',\n    )\n\n  if HTML is not None and _is_jupyter():\n    return HTML(full_html)\n  return full_html\n"
  },
  {
    "path": "pyproject.toml",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n[build-system]\nrequires = [\"setuptools>=67.0.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n\n[project]\nname = \"langextract\"\nversion = \"1.1.1\"\ndescription = \"LangExtract: A library for extracting structured data from language models\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = \"Apache-2.0\"\nauthors = [\n    {name = \"Akshay Goel\", email = \"goelak@google.com\"}\n]\ndependencies = [\n    \"absl-py>=1.0.0\",\n    \"aiohttp>=3.8.0\",\n    \"async_timeout>=4.0.0\",\n    \"exceptiongroup>=1.1.0\",\n    \"google-genai>=1.39.0\",\n    \"google-cloud-storage>=2.14.0\",\n    \"ml-collections>=0.1.0\",\n    \"more-itertools>=8.0.0\",\n    \"numpy>=1.20.0\",\n    \"pandas>=1.3.0\",\n    \"pydantic>=1.8.0\",\n    \"python-dotenv>=0.19.0\",\n    \"PyYAML>=6.0\",\n    \"regex>=2023.0.0\",\n    \"requests>=2.25.0\",\n    \"tqdm>=4.64.0\",\n    \"typing-extensions>=4.0.0\"\n]\n\n[project.urls]\n\"Homepage\" = \"https://github.com/google/langextract\"\n\"Repository\" = \"https://github.com/google/langextract\"\n\"Documentation\" = \"https://github.com/google/langextract/blob/main/README.md\"\n\"Bug Tracker\" = \"https://github.com/google/langextract/issues\"\n\"Changelog\" = \"https://github.com/google/langextract/releases\"\n\"DOI\" = \"https://doi.org/10.5281/zenodo.17015089\"\n\n[project.optional-dependencies]\nopenai = 
[\"openai>=1.50.0\"]\nall = [\"openai>=1.50.0\"]\ndev = [\n    \"pyink~=24.3.0\",\n    \"isort>=5.13.0\",\n    \"pylint>=3.0.0\",\n    \"pytype>=2024.10.11\",\n    \"tox>=4.0.0\",\n    \"import-linter>=2.0\",\n    \"pre-commit>=3.5.0\",\n    \"types-regex>=2023.0.0\"\n]\ntest = [\n    \"pytest>=7.4.0\",\n    \"tomli>=2.0.0\"\n]\nnotebook = [\n    \"ipython>=7.0.0\",\n    \"notebook>=6.0.0\"\n]\n\n[tool.setuptools]\npackages = [\n    \"langextract\",\n    \"langextract._compat\",\n    \"langextract.core\",\n    \"langextract.providers\",\n    \"langextract.providers.schemas\"\n]\ninclude-package-data = true\n\n[tool.setuptools.package-data]\nlangextract = [\"py.typed\"]\n\n# Provider discovery mechanism for built-in and third-party providers\n[project.entry-points.\"langextract.providers\"]\ngemini = \"langextract.providers.gemini:GeminiLanguageModel\"\nollama = \"langextract.providers.ollama:OllamaLanguageModel\"\nopenai = \"langextract.providers.openai:OpenAILanguageModel\"\n\n[tool.setuptools.exclude-package-data]\n\"*\" = [\n    \"docs*\",\n    \"tests*\",\n    \"kokoro*\",\n    \"*.gif\",\n    \"*.svg\",\n]\n\n[tool.pytest.ini_options]\ntestpaths = [\"tests\"]\npython_files = \"*_test.py\"\npython_classes = \"Test*\"\npython_functions = \"test_*\"\n# Show extra test summary info\naddopts = \"-ra\"\nmarkers = [\n    \"live_api: marks tests as requiring live API access\",\n    \"requires_pip: marks tests that perform pip install/uninstall operations\",\n    \"integration: marks integration tests that test multiple components together\",\n]\n\n[tool.pyink]\n# Configuration for Google's style guide\nline-length = 80\nunstable = true\npyink-indentation = 2\npyink-use-majority-quotes = true\n\n[tool.isort]\n# Configuration for Google's style guide\nprofile = \"google\"\nline_length = 80\nforce_sort_within_sections = true\n# Allow multiple imports on one line for these modules\nsingle_line_exclusions = [\"typing\", \"typing_extensions\", 
\"collections.abc\"]\n\n[tool.importlinter]\nroot_package = \"langextract\"\n\n\n[[tool.importlinter.contracts]]\nname = \"Providers must not import inference\"\ntype = \"forbidden\"\nsource_modules = [\"langextract.providers\"]\nforbidden_modules = [\"langextract.inference\"]\n\n[[tool.importlinter.contracts]]\nname = \"Core must not import providers\"\ntype = \"forbidden\"\nsource_modules = [\"langextract.core\"]\nforbidden_modules = [\"langextract.providers\"]\n\n[[tool.importlinter.contracts]]\nname = \"Core must not import high-level modules\"\ntype = \"forbidden\"\nsource_modules = [\"langextract.core\"]\nforbidden_modules = [\n  \"langextract.annotation\",\n  \"langextract.chunking\",\n  \"langextract.prompting\",\n  \"langextract.resolver\",\n]\n"
  },
  {
    "path": "scripts/create_provider_plugin.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Create a new LangExtract provider plugin with all boilerplate code.\n\nThis script automates steps 1-6 of the provider creation checklist:\n1. Setup Package Structure\n2. Configure Entry Point\n3. Implement Provider\n4. Add Schema Support (optional)\n5. Create and run tests\n6. Generate documentation\n\nFor detailed documentation, see:\nhttps://github.com/google/langextract/blob/main/langextract/providers/README.md\n\nUsage:\n    python create_provider_plugin.py MyProvider\n    python create_provider_plugin.py MyProvider --with-schema\n    python create_provider_plugin.py MyProvider --patterns \"^mymodel\" \"^custom\"\n\"\"\"\n\nimport argparse\nimport os\nfrom pathlib import Path\nimport re\nimport subprocess\nimport sys\nimport textwrap\n\n\ndef create_directory_structure(package_name: str, force: bool = False) -> Path:\n  \"\"\"Step 1: Setup Package Structure.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"STEP 1: Setup Package Structure\")\n  print(\"=\" * 60)\n\n  base_dir = Path(f\"langextract-{package_name}\")\n  package_dir = base_dir / f\"langextract_{package_name}\"\n\n  if base_dir.exists() and any(base_dir.iterdir()) and not force:\n    print(f\"ERROR: {base_dir} already exists and is not empty.\")\n    print(\"Use --force to overwrite or choose a different package name.\")\n    sys.exit(1)\n\n  base_dir.mkdir(parents=True, 
exist_ok=True)\n  package_dir.mkdir(parents=True, exist_ok=True)\n\n  print(f\"✓ Created directory: {base_dir}/\")\n  print(f\"✓ Created package: {package_dir}/\")\n  print(\"✅ Step 1 complete: Package structure created\")\n\n  return base_dir\n\n\ndef create_pyproject_toml(\n    base_dir: Path, provider_name: str, package_name: str\n) -> None:\n  \"\"\"Step 2: Configure Entry Point.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"STEP 2: Configure Entry Point\")\n  print(\"=\" * 60)\n\n  content = textwrap.dedent(f\"\"\"\\\n        [build-system]\n        requires = [\"setuptools>=61.0\", \"wheel\"]\n        build-backend = \"setuptools.build_meta\"\n\n        [project]\n        name = \"langextract-{package_name}\"\n        version = \"0.1.0\"\n        description = \"LangExtract provider plugin for {provider_name}\"\n        readme = \"README.md\"\n        requires-python = \">=3.10\"\n        license = {{text = \"Apache-2.0\"}}\n        dependencies = [\n            \"langextract>=1.0.0\",\n            # Add your provider's SDK dependencies here\n        ]\n\n        [project.entry-points.\"langextract.providers\"]\n        {package_name} = \"langextract_{package_name}.provider:{provider_name}LanguageModel\"\n\n        [tool.setuptools.packages.find]\n        where = [\".\"]\n        include = [\"langextract_{package_name}*\"]\n    \"\"\")\n\n  (base_dir / \"pyproject.toml\").write_text(content, encoding=\"utf-8\")\n  print(\"✓ Created pyproject.toml with entry point configuration\")\n  print(\"✅ Step 2 complete: Entry point configured\")\n\n\ndef create_provider(\n    base_dir: Path,\n    provider_name: str,\n    package_name: str,\n    patterns: list[str],\n    with_schema: bool,\n) -> None:\n  \"\"\"Step 3: Implement Provider.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"STEP 3: Implement Provider\")\n  print(\"=\" * 60)\n\n  package_dir = base_dir / f\"langextract_{package_name}\"\n\n  patterns_str = \", \".join(f\"r'{p}'\" for p in patterns)\n  
env_var_safe = re.sub(r\"[^A-Z0-9]+\", \"_\", package_name.upper()) + \"_API_KEY\"\n\n  schema_imports = (\n      f\"\"\"\nfrom langextract_{package_name}.schema import {provider_name}Schema\"\"\"\n      if with_schema\n      else \"\"\n  )\n\n  schema_init = (\n      \"\"\"\n                self.response_schema = kwargs.get('response_schema')\n                self.structured_output = kwargs.get('structured_output', False)\"\"\"\n      if with_schema\n      else \"\"\n  )\n\n  schema_methods = f\"\"\"\n\n            @classmethod\n            def get_schema_class(cls):\n                \\\"\\\"\\\"Tell LangExtract about our schema support.\\\"\\\"\\\"\n                from langextract_{package_name}.schema import {provider_name}Schema\n                return {provider_name}Schema\n\n            def apply_schema(self, schema_instance):\n                \\\"\\\"\\\"Apply or clear schema configuration.\\\"\\\"\\\"\n                super().apply_schema(schema_instance)\n                if schema_instance:\n                    config = schema_instance.to_provider_config()\n                    self.response_schema = config.get('response_schema')\n                    self.structured_output = config.get('structured_output', False)\n                else:\n                    self.response_schema = None\n                    self.structured_output = False\"\"\" if with_schema else \"\"\n\n  schema_infer = (\n      \"\"\"\n                    api_params = {}\n                    if self.response_schema:\n                        api_params['response_schema'] = self.response_schema\n                    # result = self.client.generate(prompt, **api_params)\"\"\"\n      if with_schema\n      else \"\"\"\n                    # result = self.client.generate(prompt, **kwargs)\"\"\"\n  )\n\n  provider_content = textwrap.dedent(f'''\\\n        \"\"\"Provider implementation for {provider_name}.\"\"\"\n\n        import os\n        import langextract as lx{schema_imports}\n\n\n        
@lx.providers.registry.register({patterns_str}, priority=10)\n        class {provider_name}LanguageModel(lx.inference.BaseLanguageModel):\n            \"\"\"LangExtract provider for {provider_name}.\n\n            This provider handles model IDs matching: {patterns}\n            \"\"\"\n\n            def __init__(self, model_id: str, api_key: str = None, **kwargs):\n                \"\"\"Initialize the {provider_name} provider.\n\n                Args:\n                    model_id: The model identifier.\n                    api_key: API key for authentication.\n                    **kwargs: Additional provider-specific parameters.\n                \"\"\"\n                super().__init__()\n                self.model_id = model_id\n                self.api_key = api_key or os.environ.get('{env_var_safe}'){schema_init}\n\n                # self.client = YourClient(api_key=self.api_key)\n                self._extra_kwargs = kwargs{schema_methods}\n\n            def infer(self, batch_prompts, **kwargs):\n                \"\"\"Run inference on a batch of prompts.\n\n                Args:\n                    batch_prompts: List of prompts to process.\n                    **kwargs: Additional inference parameters.\n\n                Yields:\n                    Lists of ScoredOutput objects, one per prompt.\n                \"\"\"\n                for prompt in batch_prompts:{schema_infer}\n                    result = f\"Mock response for: {{prompt[:50]}}...\"\n                    yield [lx.inference.ScoredOutput(score=1.0, output=result)]\n    ''')\n\n  (package_dir / \"provider.py\").write_text(provider_content, encoding=\"utf-8\")\n  print(\"✓ Created provider.py with mock implementation\")\n\n  # Create __init__.py\n  init_content = textwrap.dedent(f'''\\\n        \"\"\"LangExtract provider plugin for {provider_name}.\"\"\"\n\n        from langextract_{package_name}.provider import {provider_name}LanguageModel\n\n        __all__ = 
['{provider_name}LanguageModel']\n        __version__ = \"0.1.0\"\n    ''')\n\n  (package_dir / \"__init__.py\").write_text(init_content, encoding=\"utf-8\")\n  print(\"✓ Created __init__.py with exports\")\n  print(\"✅ Step 3 complete: Provider implementation created\")\n\n\ndef create_schema(\n    base_dir: Path, provider_name: str, package_name: str\n) -> None:\n  \"\"\"Step 4: Add Schema Support.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"STEP 4: Add Schema Support (Optional)\")\n  print(\"=\" * 60)\n\n  package_dir = base_dir / f\"langextract_{package_name}\"\n\n  schema_content = textwrap.dedent(f'''\\\n        \"\"\"Schema implementation for {provider_name} provider.\"\"\"\n\n        import langextract as lx\n        from langextract import schema\n\n\n        class {provider_name}Schema(lx.schema.BaseSchema):\n            \"\"\"Schema implementation for {provider_name} structured output.\"\"\"\n\n            def __init__(self, schema_dict: dict):\n                \"\"\"Initialize the schema with a dictionary.\"\"\"\n                self._schema_dict = schema_dict\n\n            @property\n            def schema_dict(self) -> dict:\n                \"\"\"Return the schema dictionary.\"\"\"\n                return self._schema_dict\n\n            @classmethod\n            def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n                \"\"\"Build schema from example extractions.\n\n                Args:\n                    examples_data: Sequence of ExampleData objects.\n                    attribute_suffix: Suffix for attribute fields.\n\n                Returns:\n                    A configured {provider_name}Schema instance.\n                \"\"\"\n                extraction_types = {{}}\n                for example in examples_data:\n                    for extraction in example.extractions:\n                        class_name = extraction.extraction_class\n                        if class_name not in extraction_types:\n  
                          extraction_types[class_name] = set()\n                        if extraction.attributes:\n                            extraction_types[class_name].update(extraction.attributes.keys())\n\n                schema_dict = {{\n                    \"type\": \"object\",\n                    \"properties\": {{\n                        \"extractions\": {{\n                            \"type\": \"array\",\n                            \"items\": {{\"type\": \"object\"}}\n                        }}\n                    }},\n                    \"required\": [\"extractions\"]\n                }}\n\n                return cls(schema_dict)\n\n            def to_provider_config(self) -> dict:\n                \"\"\"Convert to provider-specific configuration.\n\n                Returns:\n                    Dictionary of provider-specific configuration.\n                \"\"\"\n                return {{\n                    \"response_schema\": self._schema_dict,\n                    \"structured_output\": True\n                }}\n\n            @property\n            def supports_strict_mode(self) -> bool:\n                \"\"\"Whether this schema guarantees valid structured output.\n\n                Returns:\n                    True if the provider enforces valid JSON output.\n                \"\"\"\n                return False  # Set to True only if your provider guarantees valid JSON\n    ''')\n\n  (package_dir / \"schema.py\").write_text(schema_content, encoding=\"utf-8\")\n  print(\"✓ Created schema.py with BaseSchema implementation\")\n  print(\"✅ Step 4 complete: Schema support added\")\n\n\ndef create_test_script(\n    base_dir: Path,\n    provider_name: str,\n    package_name: str,\n    patterns: list[str],\n    with_schema: bool,\n) -> None:\n  \"\"\"Step 5: Create and run tests.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"STEP 5: Create Tests\")\n  print(\"=\" * 60)\n\n  patterns_literal = \"[\" + \", \".join(repr(p) for p in patterns) + 
\"]\"\n  provider_cls_name = f\"{provider_name}LanguageModel\"\n\n  test_content = textwrap.dedent(f'''\\\n        #!/usr/bin/env python3\n        \"\"\"Test script for {provider_name} provider (Step 5 checklist).\"\"\"\n\n        import re\n        import sys\n        import langextract as lx\n        from langextract.providers import registry\n\n        try:\n            from langextract_{package_name} import {provider_cls_name}\n        except ImportError:\n            print(\"ERROR: Plugin not installed. Run: pip install -e .\")\n            sys.exit(1)\n\n        lx.providers.load_plugins_once()\n\n        PROVIDER_CLS_NAME = \"{provider_cls_name}\"\n        PATTERNS = {patterns_literal}\n\n        def _example_id(pattern: str) -> str:\n            \\\"\\\"\\\"Generate test model ID from pattern.\\\"\\\"\\\"\n            base = re.sub(r'^\\\\^', '', pattern)\n            m = re.match(r\"[A-Za-z0-9._-]+\", base)\n            base = m.group(0) if m else (base or \"model\")\n            return f\"{{base}}-test\"\n\n        sample_ids = [_example_id(p) for p in PATTERNS]\n        sample_ids.append(\"unknown-model\")\n\n        print(\"Testing {provider_name} Provider - Step 5 Checklist:\")\n        print(\"-\" * 50)\n\n        # 1 & 2. Provider registration + pattern matching via resolve()\n        print(\"1–2. 
Provider registration & pattern matching\")\n        for model_id in sample_ids:\n            try:\n                provider_class = registry.resolve(model_id)\n                ok = provider_class.__name__ == PROVIDER_CLS_NAME\n                status = \"✓\" if (ok or model_id == \"unknown-model\") else \"✗\"\n                note = \"expected\" if ok else (\"expected (no provider)\" if model_id == \"unknown-model\" else \"unexpected provider\")\n                print(f\"   {{status}} {{model_id}} -> {{provider_class.__name__ if ok else 'resolved'}} {{note}}\")\n            except Exception as e:\n                if model_id == \"unknown-model\":\n                    print(f\"   ✓ {{model_id}}: No provider found (expected)\")\n                else:\n                    print(f\"   ✗ {{model_id}}: resolve() failed: {{e}}\")\n\n        # 3. Inference sanity check\n        print(\"\\\\n3. Test inference with sample prompts\")\n        try:\n            model_id = sample_ids[0] if sample_ids[0] != \"unknown-model\" else (_example_id(PATTERNS[0]) if PATTERNS else \"test-model\")\n            provider = {provider_cls_name}(model_id=model_id)\n            prompts = [\"Test prompt 1\", \"Test prompt 2\"]\n            results = list(provider.infer(prompts))\n            print(f\"   ✓ Inference returned {{len(results)}} results\")\n            for i, result in enumerate(results):\n                try:\n                    out = result[0].output if result and result[0] else None\n                    print(f\"   ✓ Result {{i+1}}: {{(out or '')[:60]}}...\")\n                except Exception:\n                    print(f\"   ✗ Result {{i+1}}: Unexpected result shape: {{result}}\")\n        except Exception as e:\n            print(f\"   ✗ ERROR: {{e}}\")\n    ''')\n\n  if with_schema:\n    test_content += textwrap.dedent(f\"\"\"\n        # 4. Test schema creation and application\n        print(\"\\\\n4. 
Test schema creation and application\")\n        try:\n            from langextract_{package_name}.schema import {provider_name}Schema\n            from langextract import data\n\n            examples = [\n                data.ExampleData(\n                    text=\"Test text\",\n                    extractions=[\n                        data.Extraction(\n                            extraction_class=\"entity\",\n                            extraction_text=\"test\",\n                            attributes={{\"type\": \"example\"}}\n                        )\n                    ]\n                )\n            ]\n\n            schema = {provider_name}Schema.from_examples(examples)\n            print(f\"   ✓ Schema created (keys={{list(schema.schema_dict.keys())}})\")\n\n            schema_class = {provider_cls_name}.get_schema_class()\n            print(f\"   ✓ Provider schema class: {{schema_class.__name__}}\")\n\n            provider = {provider_cls_name}(model_id=_example_id(PATTERNS[0]) if PATTERNS else \"test-model\")\n            provider.apply_schema(schema)\n            print(f\"   ✓ Schema applied: response_schema={{provider.response_schema is not None}} structured={{getattr(provider, 'structured_output', False)}}\")\n        except Exception as e:\n            print(f\"   ✗ ERROR: {{e}}\")\n        \"\"\")\n\n  test_content += textwrap.dedent(f\"\"\"\n        # 5. Test factory integration\n        print(\"\\\\n5. 
Test factory integration\")\n        try:\n            from langextract import factory\n            config = factory.ModelConfig(\n                model_id=_example_id(PATTERNS[0]) if PATTERNS else \"test-model\",\n                provider=\"{provider_cls_name}\"\n            )\n            model = factory.create_model(config)\n            print(f\"   ✓ Factory created: {{type(model).__name__}}\")\n        except Exception as e:\n            print(f\"   ✗ ERROR: {{e}}\")\n\n        print(\"\\\\n\" + \"-\" * 50)\n        print(\"✅ Testing complete!\")\n        \"\"\")\n\n  (base_dir / \"test_plugin.py\").write_text(test_content, encoding=\"utf-8\")\n  print(\"✓ Created test_plugin.py with comprehensive tests\")\n  print(\"✅ Step 5 complete: Test suite created\")\n\n\ndef create_readme(\n    base_dir: Path, provider_name: str, package_name: str, patterns: list[str]\n) -> None:\n  \"\"\"Create README documentation.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"STEP 6: Documentation\")\n  print(\"=\" * 60)\n\n  def _display(p: str) -> str:\n    \"\"\"Strip leading ^ from pattern for display.\"\"\"\n    return p[1:] if p.startswith(\"^\") else p\n\n  env_var_safe = re.sub(r\"[^A-Z0-9]+\", \"_\", package_name.upper()) + \"_API_KEY\"\n\n  supported = \"\\n\".join(\n      f\"- `{_display(p)}*`: Models matching pattern {p}\" for p in patterns\n  )\n\n  readme_content = textwrap.dedent(f\"\"\"\\\n# LangExtract {provider_name} Provider\n\nA provider plugin for LangExtract that supports {provider_name} models.\n\n## Installation\n\n```bash\npip install -e .\n```\n\n## Supported Model IDs\n\n{supported}\n\n## Environment Variables\n\n- `{env_var_safe}`: API key for authentication\n\n## Usage\n\n```python\nimport langextract as lx\n\nresult = lx.extract(\n    text=\"Your document here\",\n    model_id=\"{_display(patterns[0]) if patterns else package_name}-model\",\n    prompt_description=\"Extract entities\",\n    examples=[...]\n)\n```\n\n## Development\n\n1. 
Install in development mode: `pip install -e .`\n2. Run tests: `python test_plugin.py`\n3. Build package: `python -m build`\n4. Publish to PyPI: `twine upload dist/*`\n\n## License\n\nApache License 2.0\n    \"\"\")\n\n  (base_dir / \"README.md\").write_text(readme_content, encoding=\"utf-8\")\n  print(\"✓ Created README.md with usage examples\")\n\n\ndef create_gitignore(base_dir: Path) -> None:\n  \"\"\"Create .gitignore file with Python-specific entries.\"\"\"\n  gitignore_content = textwrap.dedent(\"\"\"\\\n        # Python\n        __pycache__/\n        *.py[cod]\n        *$py.class\n        *.so\n\n        # Distribution / packaging\n        build/\n        dist/\n        *.egg-info/\n        .eggs/\n        *.egg\n\n        # Virtual environments\n        .env\n        .venv\n        env/\n        venv/\n        ENV/\n\n        # Testing & coverage\n        .pytest_cache/\n        .tox/\n        htmlcov/\n        .coverage\n        .coverage.*\n\n        # Type checking\n        .mypy_cache/\n        .dmypy.json\n        dmypy.json\n        .pytype/\n\n        # IDEs\n        .idea/\n        .vscode/\n        *.swp\n        *.swo\n\n        # OS-specific\n        .DS_Store\n        Thumbs.db\n\n        # Logs\n        *.log\n\n        # Temp files\n        *.tmp\n        *.bak\n        *.backup\n    \"\"\")\n\n  (base_dir / \".gitignore\").write_text(gitignore_content, encoding=\"utf-8\")\n  print(\"✓ Created .gitignore file with Python-specific entries\")\n\n\ndef create_license(base_dir: Path) -> None:\n  \"\"\"Create LICENSE file.\"\"\"\n  license_content = textwrap.dedent(\"\"\"\\\n        # LICENSE\n\n        TODO: Add your license here.\n\n        This is a placeholder license file for your provider plugin.\n        Please replace this with your actual license before distribution.\n\n        Common options include:\n        - Apache License 2.0\n        - MIT License\n        - BSD License\n        - GPL License\n        - Proprietary/Commercial 
License\n    \"\"\")\n\n  (base_dir / \"LICENSE\").write_text(license_content, encoding=\"utf-8\")\n  print(\"✓ Created LICENSE file\")\n  print(\"✅ Step 6 complete: Documentation created\")\n\n\ndef install_and_test(base_dir: Path) -> bool:\n  \"\"\"Install the plugin and run tests.\"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"Installing and testing the plugin...\")\n  print(\"=\" * 60)\n\n  os.chdir(base_dir)\n  print(\"\\nInstalling plugin...\")\n  result = subprocess.run(\n      [sys.executable, \"-m\", \"pip\", \"install\", \"-e\", \".\"],\n      capture_output=True,\n      text=True,\n      check=False,\n  )\n  if result.returncode:\n    print(f\"Installation failed: {result.stderr}\")\n    return False\n  print(\"✓ Plugin installed successfully\")\n\n  print(\"\\nRunning tests...\")\n  result = subprocess.run(\n      [sys.executable, \"test_plugin.py\"],\n      capture_output=True,\n      text=True,\n      check=False,\n  )\n  print(result.stdout)\n  if result.returncode:\n    print(f\"Tests failed: {result.stderr}\")\n    return False\n\n  return True\n\n\ndef parse_arguments():\n  \"\"\"Parse command line arguments.\n\n  Returns:\n    Parsed arguments from argparse.\n  \"\"\"\n  parser = argparse.ArgumentParser(\n      description=\"Create a new LangExtract provider plugin\",\n      formatter_class=argparse.RawDescriptionHelpFormatter,\n      epilog=textwrap.dedent(\"\"\"\n        Examples:\n            python create_provider_plugin.py MyProvider\n            python create_provider_plugin.py MyProvider --with-schema\n            python create_provider_plugin.py MyProvider --patterns \"^mymodel\" \"^custom\"\n            python create_provider_plugin.py MyProvider --package-name my_custom_name\n        \"\"\"),\n  )\n\n  parser.add_argument(\n      \"provider_name\",\n      help=\"Name of your provider (e.g., MyProvider, CustomLLM)\",\n  )\n\n  parser.add_argument(\n      \"--patterns\",\n      nargs=\"+\",\n      default=None,\n      help=\"Regex 
patterns for model IDs (default: based on provider name)\",\n  )\n\n  parser.add_argument(\n      \"--package-name\",\n      default=None,\n      help=\"Package name (default: lowercase provider name)\",\n  )\n\n  parser.add_argument(\n      \"--with-schema\",\n      action=\"store_true\",\n      help=\"Include schema support (Step 4)\",\n  )\n\n  parser.add_argument(\n      \"--no-install\", action=\"store_true\", help=\"Skip installation and testing\"\n  )\n\n  parser.add_argument(\n      \"--force\",\n      action=\"store_true\",\n      help=\"Overwrite existing plugin directory if it exists\",\n  )\n\n  return parser.parse_args()\n\n\ndef validate_patterns(patterns: list[str]) -> None:\n  \"\"\"Validate regex patterns.\n\n  Args:\n    patterns: List of regex patterns to validate.\n\n  Raises:\n    SystemExit: If any pattern is invalid.\n  \"\"\"\n  for p in patterns:\n    try:\n      re.compile(p)\n    except re.error as e:\n      print(f\"ERROR: Invalid regex pattern '{p}': {e}\")\n      sys.exit(1)\n\n\ndef print_summary(\n    provider_name: str,\n    package_name: str,\n    patterns: list[str],\n    with_schema: bool,\n) -> None:\n  \"\"\"Print configuration summary.\n\n  Args:\n    provider_name: Name of the provider.\n    package_name: Package name.\n    patterns: List of model ID patterns.\n    with_schema: Whether to include schema support.\n  \"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"LANGEXTRACT PROVIDER PLUGIN GENERATOR\")\n  print(\"=\" * 60)\n  print(f\"Provider Name: {provider_name}\")\n  print(f\"Package Name: langextract-{package_name}\")\n  print(f\"Model Patterns: {patterns}\")\n  print(f\"Include Schema: {with_schema}\")\n  print(\"\\nFor documentation, see:\")\n  print(\n      \"https://github.com/google/langextract/blob/main/langextract/providers/README.md\"\n  )\n\n\ndef create_plugin(\n    args: argparse.Namespace, package_name: str, patterns: list[str]\n) -> Path:\n  \"\"\"Create the plugin with all necessary files.\n\n  Args:\n    
args: Parsed command line arguments.\n    package_name: Package name.\n    patterns: List of model ID patterns.\n\n  Returns:\n    Path to the created plugin directory.\n  \"\"\"\n  base_dir = create_directory_structure(package_name, force=args.force)\n  create_pyproject_toml(base_dir, args.provider_name, package_name)\n  create_provider(\n      base_dir, args.provider_name, package_name, patterns, args.with_schema\n  )\n\n  if args.with_schema:\n    create_schema(base_dir, args.provider_name, package_name)\n\n  create_test_script(\n      base_dir, args.provider_name, package_name, patterns, args.with_schema\n  )\n  create_readme(base_dir, args.provider_name, package_name, patterns)\n  create_gitignore(base_dir)\n  create_license(base_dir)\n\n  return base_dir\n\n\ndef print_completion_summary(with_schema: bool) -> None:\n  \"\"\"Print completion summary.\n\n  Args:\n    with_schema: Whether schema support was included.\n  \"\"\"\n  print(\"\\n\" + \"=\" * 60)\n  print(\"SUMMARY: Steps 1-6 Completed\")\n  print(\"=\" * 60)\n  print(\"✅ Package structure created\")\n  print(\"✅ Entry point configured\")\n  print(\"✅ Provider implemented\")\n  if with_schema:\n    print(\"✅ Schema support added\")\n  print(\"✅ Tests created\")\n  print(\"✅ Documentation generated\")\n\n\ndef main():\n  \"\"\"Main entry point for the provider plugin generator.\"\"\"\n  args = parse_arguments()\n\n  package_name = args.package_name or args.provider_name.lower()\n  patterns = args.patterns if args.patterns else [f\"^{package_name}\"]\n\n  validate_patterns(patterns)\n  print_summary(args.provider_name, package_name, patterns, args.with_schema)\n\n  base_dir = create_plugin(args, package_name, patterns)\n  print_completion_summary(args.with_schema)\n\n  if not args.no_install:\n    success = install_and_test(base_dir)\n    if success:\n      print(\"\\n✅ Plugin created, installed, and tested successfully!\")\n      print(f\"\\nYour plugin is ready at: {base_dir.absolute()}\")\n      
print(\"\\nNext steps:\")\n      print(\"  1. Replace mock inference with actual API calls\")\n      print(\"  2. Update documentation with real examples\")\n      print(\"  3. Build package: python -m build\")\n      print(\"  4. Publish to PyPI: twine upload dist/*\")\n    else:\n      print(\n          \"\\n⚠️ Plugin created but tests failed. Please check the\"\n          \" implementation.\"\n      )\n      sys.exit(1)\n  else:\n    print(f\"\\nPlugin created at: {base_dir.absolute()}\")\n    print(\"\\nTo install and test:\")\n    print(f\"  cd {base_dir}\")\n    print(\"  pip install -e .\")\n    print(\"  python test_plugin.py\")\n\n\nif __name__ == \"__main__\":\n  main()\n"
  },
  {
    "path": "scripts/validate_community_providers.py",
    "content": "#!/usr/bin/env python3\n# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Validation for COMMUNITY_PROVIDERS.md plugin registry table.\"\"\"\n\nfrom pathlib import Path\nimport re\nimport re as regex_module\nimport sys\nfrom typing import Dict, List, Tuple\n\nHEADER_ANCHOR = '| Plugin Name | PyPI Package |'\nEND_MARKER = '<!-- ADD NEW PLUGINS ABOVE THIS LINE -->'\n\n# GitHub username/org and repo patterns\nGH_NAME = r'[-a-zA-Z0-9]+'  # usernames/orgs allow hyphens\nGH_REPO = r'[-a-zA-Z0-9._]+'  # repos allow ., _\nGH_USER_LINK = rf'\\[@{GH_NAME}\\]\\(https://github\\.com/{GH_NAME}\\)'\nGH_MULTI_USER = rf'^{GH_USER_LINK}(,\\s*{GH_USER_LINK})*$'\n\n# Markdown link to a GitHub repo\nGH_REPO_LINK = rf'^\\[[^\\]]+\\]\\(https://github\\.com/{GH_NAME}/{GH_REPO}\\)$'\n\n# Issue link must point to LangExtract repository (issues only)\nLANGEXTRACT_ISSUE_LINK = (\n    r'^\\[[^\\]]+\\]\\(https://github\\.com/google/langextract/issues/\\d+\\)$'\n)\n\n# PEP 503-ish normalized name (loose): lowercase letters/digits with - _ . 
separators\nPYPI_NORMALIZED = r'`[a-z0-9]([\\-_.]?[a-z0-9]+)*`'\n\nMIN_DESC_LEN = 10\n\n\ndef normalize_pypi(name: str) -> str:\n  \"\"\"PEP 503 normalization for PyPI package names.\"\"\"\n  return regex_module.sub(r'[-_.]+', '-', name.strip().lower())\n\n\ndef find_table_bounds(lines: List[str]) -> Tuple[int, int]:\n  start = end = -1\n  for i, line in enumerate(lines):\n    if HEADER_ANCHOR in line:\n      start = i\n    elif start >= 0 and END_MARKER in line:\n      end = i\n      break\n  return start, end\n\n\ndef parse_row(line: str) -> List[str]:\n  # assumes caller trimmed line\n  parts = [c.strip() for c in line.split('|')[1:-1]]\n  return parts\n\n\ndef validate(filepath: Path) -> bool:\n  errors: List[str] = []\n  warnings: List[str] = []\n\n  content = filepath.read_text(encoding='utf-8')\n  lines = content.splitlines()\n\n  start, end = find_table_bounds(lines)\n  if start < 0:\n    errors.append('Could not find plugin registry table header.')\n    print_report(errors, warnings)\n    return False\n  if end < 0:\n    errors.append(\n        'Could not find end marker: <!-- ADD NEW PLUGINS ABOVE THIS LINE -->.'\n    )\n    print_report(errors, warnings)\n    return False\n\n  rows: List[Dict] = []\n  seen_names = set()\n  seen_pkgs = set()\n\n  for i in range(start + 2, end):\n    raw = lines[i].strip()\n    if not raw:\n      continue\n\n    if not raw.startswith('|') or not raw.endswith('|'):\n      errors.append(\n          f\"Line {i+1}: Not a valid table row (must start and end with '|').\"\n      )\n      continue\n\n    cols = parse_row(raw)\n    if len(cols) != 6:\n      errors.append(f'Line {i+1}: Expected 6 columns, found {len(cols)}.')\n      continue\n\n    plugin, pypi, maint, repo, desc, issue_link = cols\n\n    # Basic presence checks\n    if not plugin:\n      errors.append(f'Line {i+1}: Plugin Name is required.')\n\n    if not re.fullmatch(PYPI_NORMALIZED, pypi):\n      errors.append(\n          f'Line {i+1}: PyPI package must be 
backticked and normalized (e.g.,'\n          ' `langextract-provider-foo`).'\n      )\n    elif pypi and not pypi.strip('`').lower().startswith('langextract-'):\n      errors.append(\n          f'Line {i+1}: PyPI package should start with `langextract-` for'\n          ' discoverability.'\n      )\n\n    if not re.fullmatch(GH_MULTI_USER, maint):\n      errors.append(\n          f'Line {i+1}: Maintainer must be one or more GitHub handles as links '\n          '(e.g., [@alice](https://github.com/alice) or comma-separated).'\n      )\n\n    if not re.fullmatch(GH_REPO_LINK, repo):\n      errors.append(\n          f'Line {i+1}: GitHub Repo must be a Markdown link to a GitHub'\n          ' repository.'\n      )\n\n    if not desc or len(desc) < MIN_DESC_LEN:\n      errors.append(\n          f'Line {i+1}: Description must be at least {MIN_DESC_LEN} characters.'\n      )\n\n    # Issue link is required and must point to LangExtract repo\n    if not issue_link:\n      errors.append(f'Line {i+1}: Issue Link is required.')\n    elif not re.fullmatch(LANGEXTRACT_ISSUE_LINK, issue_link):\n      errors.append(\n          f'Line {i+1}: Issue Link must point to a LangExtract issue (e.g.,'\n          ' [#123](https://github.com/google/langextract/issues/123)).'\n      )\n\n    rows.append({\n        'line': i + 1,\n        'plugin': plugin,\n        'pypi': pypi.strip('`').lower() if pypi else '',\n    })\n\n  # Duplicate checks (case-insensitive and PEP 503 normalized)\n  for r in rows:\n    pn_key = r['plugin'].strip().casefold()\n    pk_key = normalize_pypi(r['pypi']) if r['pypi'] else None\n\n    if pn_key in seen_names:\n      errors.append(f\"Line {r['line']}: Duplicate Plugin Name '{r['plugin']}'.\")\n    seen_names.add(pn_key)\n\n    if pk_key and pk_key in seen_pkgs:\n      errors.append(f\"Line {r['line']}: Duplicate PyPI Package '{r['pypi']}'.\")\n    if pk_key:\n      seen_pkgs.add(pk_key)\n\n  # Required alphabetical sorting check\n  sorted_by_name = sorted(rows, 
key=lambda r: r['plugin'].casefold())\n  if [r['plugin'] for r in rows] != [r['plugin'] for r in sorted_by_name]:\n    errors.append('Registry rows must be alphabetically sorted by Plugin Name.')\n\n  # Guardrail: discourage leaving only the example entry\n  if len(rows) == 1 and rows[0]['plugin'].lower().startswith('example'):\n    warnings.append(\n        'The registry currently contains only the example row. Add real'\n        ' providers above the marker.'\n    )\n\n  print_report(errors, warnings)\n  return not errors\n\n\ndef print_report(errors: List[str], warnings: List[str]) -> None:\n  if errors:\n    print('❌ Validation failed:')\n    for e in errors:\n      print(f'  • {e}')\n  if warnings:\n    print('⚠️  Warnings:')\n    for w in warnings:\n      print(f'  • {w}')\n  if not errors and not warnings:\n    print('✅ Table format validation passed!')\n\n\nif __name__ == '__main__':\n  path = Path('COMMUNITY_PROVIDERS.md')\n  if len(sys.argv) > 1:\n    path = Path(sys.argv[1])\n  if not path.exists():\n    print(f'❌ Error: File not found: {path}')\n    sys.exit(1)\n  ok = validate(path)\n  sys.exit(0 if ok else 1)\n"
  },
  {
    "path": "tests/.pylintrc",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Test-specific Pylint configuration\n# Inherits from parent ../.pylintrc and adds test-specific relaxations\n\n[MASTER]\n# Plugins are defined in the parent ../.pylintrc and are not repeated here.\n\n[MESSAGES CONTROL]\n# Additional disables for test code only\ndisable=\n    # --- Test-specific relaxations ---\n    duplicate-code,           # Test fixtures often have similar patterns\n    too-many-lines,           # Large test files are common\n    missing-module-docstring, # Tests don't need module docs\n    missing-class-docstring,  # Test classes are self-explanatory\n    missing-function-docstring, # Test method names describe intent\n    line-too-long,            # Golden strings and test data\n    invalid-name,             # setUp, tearDown, maxDiff, etc.\n    protected-access,         # Tests often access private members\n    use-dict-literal,         # Parametrized tests benefit from dict()\n    bad-indentation,          # pyink 2-space style conflicts with pylint\n    unused-argument,          # Mock callbacks often have unused args\n    import-error,             # Test dependencies may not be installed\n    too-many-positional-arguments  # Test methods can have many args\n\n[DESIGN]\n# Relax complexity limits for tests\nmax-args = 10                 # Fixtures often take many params\nmax-locals = 25               # Complex test setups\nmax-statements = 75           # Detailed test 
scenarios\nmax-branches = 15             # Multiple test conditions\n\n[BASIC]\n# Allow common test naming patterns\ngood-names=i,j,k,ex,Run,_,id,ok,fd,fp,maxDiff,setUp,tearDown\n\n# Include test-specific naming patterns\nmethod-rgx=[a-z_][a-z0-9_]{2,50}$|test[A-Z_][a-zA-Z0-9]*$|assert[A-Z][a-zA-Z0-9]*$\n"
  },
  {
    "path": "tests/annotation_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom collections.abc import Sequence\nimport dataclasses\nimport inspect\nimport textwrap\nfrom typing import Type\nfrom unittest import mock\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import annotation\nfrom langextract import prompting\nfrom langextract import resolver as resolver_lib\nfrom langextract.core import data\nfrom langextract.core import exceptions\nfrom langextract.core import tokenizer\nfrom langextract.core import types\nfrom langextract.providers import gemini\n\n\nclass AnnotatorTest(absltest.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    self.mock_language_model = self.enter_context(\n        mock.patch.object(gemini, \"GeminiLanguageModel\", autospec=True)\n    )\n    self.annotator = annotation.Annotator(\n        language_model=self.mock_language_model,\n        prompt_template=prompting.PromptTemplateStructured(description=\"\"),\n    )\n\n  def assert_char_interval_match_source(\n      self, source_text: str, extractions: Sequence[data.Extraction]\n  ):\n    \"\"\"Case-insensitive assertion that char_interval matches source text.\n\n    For each extraction, this function extracts the substring from the source\n    text using the extraction's char_interval and asserts that it matches the\n    extraction's text. 
Note the Alignment process between tokens is also\n    case-insensitive.\n\n    Args:\n      source_text: The original source text.\n      extractions: A sequence of extractions to check.\n    \"\"\"\n    for extraction in extractions:\n      if extraction.alignment_status == data.AlignmentStatus.MATCH_EXACT:\n        assert (\n            extraction.char_interval is not None\n        ), \"char_interval should not be None for AlignmentStatus.MATCH_EXACT\"\n\n        char_int = extraction.char_interval\n        start = char_int.start_pos\n        end = char_int.end_pos\n        self.assertIsNotNone(start, \"start_pos should not be None\")\n        self.assertIsNotNone(end, \"end_pos should not be None\")\n        extracted = source_text[start:end]\n        self.assertEqual(\n            extracted.lower(),\n            extraction.extraction_text.lower(),\n            f\"Extraction '{extraction.extraction_text}' does not match\"\n            f\" extracted '{extracted}' using char_interval {char_int}\",\n        )\n\n  def test_annotate_text_single_chunk(self):\n    text = (\n        \"Patient Jane Doe, ID 67890, received 10mg of Lisinopril daily for\"\n        \" hypertension diagnosed on 2023-03-15.\"\n    )\n    self.mock_language_model.infer.return_value = [[\n        types.ScoredOutput(\n            score=1.0,\n            output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - patient: \"Jane Doe\"\n                patient_index: 1\n                patient_id: \"67890\"\n                patient_id_index: 4\n                dosage: \"10mg\"\n                dosage_index: 6\n                medication: \"Lisinopril\"\n                medication_index: 8\n                frequency: \"daily\"\n                frequency_index: 9\n                condition: \"hypertension\"\n                condition_index: 11\n                diagnosis_date: \"2023-03-15\"\n                diagnosis_date_index: 13\n              
```\"\"\"),\n        )\n    ]]\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n    expected_annotated_text = data.AnnotatedDocument(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=\"patient\",\n                extraction_index=1,\n                extraction_text=\"Jane Doe\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=1, end_index=3\n                ),\n                char_interval=data.CharInterval(start_pos=8, end_pos=16),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"patient_id\",\n                extraction_index=4,\n                extraction_text=\"67890\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=5, end_index=6\n                ),\n                char_interval=data.CharInterval(start_pos=21, end_pos=26),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"dosage\",\n                extraction_index=6,\n                extraction_text=\"10mg\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=8, end_index=10\n                ),\n                char_interval=data.CharInterval(start_pos=37, end_pos=41),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"medication\",\n                extraction_index=8,\n                extraction_text=\"Lisinopril\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    
start_index=11, end_index=12\n                ),\n                char_interval=data.CharInterval(start_pos=45, end_pos=55),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"frequency\",\n                extraction_index=9,\n                extraction_text=\"daily\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=12, end_index=13\n                ),\n                char_interval=data.CharInterval(start_pos=56, end_pos=61),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"condition\",\n                extraction_index=11,\n                extraction_text=\"hypertension\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=14, end_index=15\n                ),\n                char_interval=data.CharInterval(start_pos=66, end_pos=78),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"diagnosis_date\",\n                extraction_index=13,\n                extraction_text=\"2023-03-15\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=17, end_index=22\n                ),\n                char_interval=data.CharInterval(start_pos=92, end_pos=102),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n        ],\n    )\n\n    actual_annotated_text = self.annotator.annotate_text(\n        text, resolver=resolver\n    )\n    self.assertDataclassEqual(expected_annotated_text, actual_annotated_text)\n    self.assert_char_interval_match_source(\n        text, actual_annotated_text.extractions\n    )\n    
self.mock_language_model.infer.assert_called_once_with(\n        batch_prompts=[f\"\\n\\nQ: {text}\\nA: \"],\n    )\n\n  def test_annotate_text_without_index_suffix(self):\n    text = (\n        \"Patient Jane Doe, ID 67890, received 10mg of Lisinopril daily for\"\n        \" hypertension diagnosed on 2023-03-15.\"\n    )\n    self.mock_language_model.infer.return_value = [[\n        types.ScoredOutput(\n            score=1.0,\n            output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - patient: \"Jane Doe\"\n                patient_id: \"67890\"\n                dosage: \"10mg\"\n                medication: \"Lisinopril\"\n                frequency: \"daily\"\n                condition: \"hypertension\"\n                diagnosis_date: \"2023-03-15\"\n              ```\"\"\"),\n        )\n    ]]\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=None,\n    )\n    expected_annotated_text = data.AnnotatedDocument(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=\"patient\",\n                extraction_index=1,\n                extraction_text=\"Jane Doe\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=1, end_index=3\n                ),\n                char_interval=data.CharInterval(start_pos=8, end_pos=16),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"patient_id\",\n                extraction_index=2,\n                extraction_text=\"67890\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=5, end_index=6\n                ),\n                char_interval=data.CharInterval(start_pos=21, end_pos=26),\n                
alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"dosage\",\n                extraction_index=3,\n                extraction_text=\"10mg\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=8, end_index=10\n                ),\n                char_interval=data.CharInterval(start_pos=37, end_pos=41),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"medication\",\n                extraction_index=4,\n                extraction_text=\"Lisinopril\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=11, end_index=12\n                ),\n                char_interval=data.CharInterval(start_pos=45, end_pos=55),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"frequency\",\n                extraction_index=5,\n                extraction_text=\"daily\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=12, end_index=13\n                ),\n                char_interval=data.CharInterval(start_pos=56, end_pos=61),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"condition\",\n                extraction_index=6,\n                extraction_text=\"hypertension\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=14, end_index=15\n                ),\n                char_interval=data.CharInterval(start_pos=66, end_pos=78),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n        
    data.Extraction(\n                extraction_class=\"diagnosis_date\",\n                extraction_index=7,\n                extraction_text=\"2023-03-15\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=17, end_index=22\n                ),\n                char_interval=data.CharInterval(start_pos=92, end_pos=102),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n        ],\n    )\n\n    actual_annotated_text = self.annotator.annotate_text(\n        text, resolver=resolver\n    )\n    self.assertDataclassEqual(expected_annotated_text, actual_annotated_text)\n    self.assert_char_interval_match_source(\n        text, actual_annotated_text.extractions\n    )\n    self.mock_language_model.infer.assert_called_once_with(\n        batch_prompts=[f\"\\n\\nQ: {text}\\nA: \"],\n    )\n\n  def test_annotate_text_with_attributes_suffix(self):\n    text = (\n        \"Patient Jane Doe, ID 67890, received 10mg of Lisinopril daily for\"\n        \" hypertension diagnosed on 2023-03-15.\"\n    )\n    self.mock_language_model.infer.return_value = [[\n        types.ScoredOutput(\n            score=1.0,\n            output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - patient: \"Jane Doe\"\n                patient_attributes:\n                  status: \"IDENTIFIABLE\"\n                patient_id: \"67890\"\n                patient_id_attributes:\n                  type: \"UNIQUE_IDENTIFIER\"\n                dosage: \"10mg\"\n                dosage_attributes:\n                  frequency: \"DAILY\"\n                medication: \"Lisinopril\"\n                medication_attributes:\n                  class: \"ANTIHYPERTENSIVE\"\n                frequency: \"daily\"\n                frequency_attributes:\n                  time: \"DAILY\"\n                condition: \"hypertension\"\n                
condition_attributes:\n                  type: \"CHRONIC\"\n                diagnosis_date: \"2023-03-15\"\n                diagnosis_date_attributes:\n                  status: \"RELEVANT\"\n              ```\"\"\"),\n        )\n    ]]\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=None,\n        extraction_attributes_suffix=data.ATTRIBUTE_SUFFIX,\n    )\n    expected_annotated_text = data.AnnotatedDocument(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=\"patient\",\n                extraction_index=1,\n                extraction_text=\"Jane Doe\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=1, end_index=3\n                ),\n                char_interval=data.CharInterval(start_pos=8, end_pos=16),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\n                    \"status\": \"IDENTIFIABLE\",\n                },\n            ),\n            data.Extraction(\n                extraction_class=\"patient_id\",\n                extraction_index=2,\n                extraction_text=\"67890\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=5, end_index=6\n                ),\n                char_interval=data.CharInterval(start_pos=21, end_pos=26),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\"type\": \"UNIQUE_IDENTIFIER\"},\n            ),\n            data.Extraction(\n                extraction_class=\"dosage\",\n                extraction_index=3,\n                extraction_text=\"10mg\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=8, end_index=10\n                ),\n                
char_interval=data.CharInterval(start_pos=37, end_pos=41),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\"frequency\": \"DAILY\"},\n            ),\n            data.Extraction(\n                extraction_class=\"medication\",\n                extraction_index=4,\n                extraction_text=\"Lisinopril\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=11, end_index=12\n                ),\n                char_interval=data.CharInterval(start_pos=45, end_pos=55),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\"class\": \"ANTIHYPERTENSIVE\"},\n            ),\n            data.Extraction(\n                extraction_class=\"frequency\",\n                extraction_index=5,\n                extraction_text=\"daily\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=12, end_index=13\n                ),\n                char_interval=data.CharInterval(start_pos=56, end_pos=61),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\"time\": \"DAILY\"},\n            ),\n            data.Extraction(\n                extraction_class=\"condition\",\n                extraction_index=6,\n                extraction_text=\"hypertension\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=14, end_index=15\n                ),\n                char_interval=data.CharInterval(start_pos=66, end_pos=78),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\"type\": \"CHRONIC\"},\n            ),\n            data.Extraction(\n                extraction_class=\"diagnosis_date\",\n                extraction_index=7,\n                extraction_text=\"2023-03-15\",\n             
   group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=17, end_index=22\n                ),\n                char_interval=data.CharInterval(start_pos=92, end_pos=102),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                attributes={\"status\": \"RELEVANT\"},\n            ),\n        ],\n    )\n\n    actual_annotated_text = self.annotator.annotate_text(\n        text,\n        resolver=resolver,\n    )\n    self.assertDataclassEqual(expected_annotated_text, actual_annotated_text)\n    self.assert_char_interval_match_source(\n        text, actual_annotated_text.extractions\n    )\n    self.mock_language_model.infer.assert_called_once_with(\n        batch_prompts=[f\"\\n\\nQ: {text}\\nA: \"],\n    )\n\n  def test_annotate_text_multiple_chunks(self):\n    self.mock_language_model.infer.side_effect = [\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n                  ```yaml\n                  {data.EXTRACTIONS_KEY}:\n                  - medication: \"Aspirin\"\n                    medication_index: 4\n                    reason: \"headache\"\n                    reason_index: 8\n                  ```\"\"\"),\n            )\n        ]],\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n                  ```yaml\n                  {data.EXTRACTIONS_KEY}:\n                  - condition: \"fever\"\n                    condition_index: 2\n                  ```\"\"\"),\n            )\n        ]],\n    ]\n\n    # Simulating tokenization for text broken into two chunks:\n    # Chunk 1: 'Patient takes one Aspirin for headaches.'\n    # Chunk 2: 'Pt has fever.'\n    text = \"Patient takes one Aspirin for headaches. 
Pt has fever.\"\n\n    # Indexes Aligned with Tokens\n    # -------------------------------------------------------------------------\n    # Index | 0        1     2    3        4    5         6  7    8    9     10\n    # Token | Patient  takes one  Aspirin  for  headaches .  Pt   has  fever  .\n\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n    expected_annotated_text = data.AnnotatedDocument(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=\"medication\",\n                extraction_index=4,\n                extraction_text=\"Aspirin\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=3, end_index=4\n                ),\n                char_interval=data.CharInterval(start_pos=18, end_pos=25),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n            data.Extraction(\n                extraction_class=\"reason\",\n                extraction_index=8,\n                extraction_text=\"headache\",\n                group_index=0,\n            ),\n            data.Extraction(\n                extraction_class=\"condition\",\n                extraction_index=2,\n                extraction_text=\"fever\",\n                group_index=0,\n                token_interval=tokenizer.TokenInterval(\n                    start_index=9, end_index=10\n                ),\n                char_interval=data.CharInterval(start_pos=48, end_pos=53),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n            ),\n        ],\n    )\n\n    actual_annotated_text = self.annotator.annotate_text(\n        text,\n        max_char_buffer=40,\n        batch_length=1,\n        resolver=resolver,\n        enable_fuzzy_alignment=False,\n    )\n    self.assertDataclassEqual(expected_annotated_text, 
actual_annotated_text)\n    self.assert_char_interval_match_source(\n        text, actual_annotated_text.extractions\n    )\n    self.mock_language_model.infer.assert_has_calls([\n        mock.call(\n            batch_prompts=[\n                \"\\n\\nQ: Patient takes one Aspirin for headaches.\\nA: \"\n            ],\n            enable_fuzzy_alignment=False,\n        ),\n        mock.call(\n            batch_prompts=[\"\\n\\nQ: Pt has fever.\\nA: \"],\n            enable_fuzzy_alignment=False,\n        ),\n    ])\n\n  def test_annotate_text_no_extractions(self):\n    text = \"Text without extractions.\"\n    self.mock_language_model.infer.return_value = [[\n        types.ScoredOutput(\n            score=1.0,\n            output=textwrap.dedent(f\"\"\"\\\n            ```yaml\n            {data.EXTRACTIONS_KEY}: []\n            ```\"\"\"),\n        )\n    ]]\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n    expected_annotated_text = data.AnnotatedDocument(text=text, extractions=[])\n\n    actual_annotated_text = self.annotator.annotate_text(\n        text, resolver=resolver\n    )\n    self.assertDataclassEqual(expected_annotated_text, actual_annotated_text)\n    self.mock_language_model.infer.assert_called_once_with(\n        batch_prompts=[f\"\\n\\nQ: {text}\\nA: \"],\n    )\n\n\nclass AnnotatorMultipleDocumentTest(parameterized.TestCase):\n\n  _FIXED_DOCUMENT_CONTENT = \"Patient reports migraine.\"\n\n  _LLM_INFERENCE = textwrap.dedent(f\"\"\"\\\n    ```yaml\n    {data.EXTRACTIONS_KEY}:\n    - PATIENT: \"Patient\"\n      PATIENT_index: 0\n    - SYMPTOM: \"migraine\"\n      SYMPTOM_index: 2\n    ```\"\"\")\n\n  _ANNOTATED_DOCUMENT = data.AnnotatedDocument(\n      document_id=\"\",\n      extractions=[\n          data.Extraction(\n              extraction_class=\"PATIENT\",\n              extraction_text=\"Patient\",\n              
token_interval=tokenizer.TokenInterval(\n                  start_index=0, end_index=1\n              ),\n              char_interval=data.CharInterval(start_pos=0, end_pos=7),\n              alignment_status=data.AlignmentStatus.MATCH_EXACT,\n              extraction_index=0,\n              group_index=0,\n          ),\n          data.Extraction(\n              extraction_class=\"SYMPTOM\",\n              extraction_text=\"migraine\",\n              token_interval=tokenizer.TokenInterval(\n                  start_index=2, end_index=3\n              ),\n              char_interval=data.CharInterval(start_pos=16, end_pos=24),\n              alignment_status=data.AlignmentStatus.MATCH_EXACT,\n              extraction_index=2,\n              group_index=1,\n          ),\n      ],\n      text=\"Patient reports migraine.\",\n  )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"single_document\",\n          documents=[\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n          ],\n          expected_result=[\n              dataclasses.replace(\n                  _ANNOTATED_DOCUMENT,\n                  document_id=\"doc1\",\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"multiple_documents\",\n          documents=[\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc2\"},\n          ],\n          expected_result=[\n              dataclasses.replace(\n                  _ANNOTATED_DOCUMENT,\n                  document_id=\"doc1\",\n              ),\n              dataclasses.replace(\n                  _ANNOTATED_DOCUMENT,\n                  document_id=\"doc2\",\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"zero_documents\",\n          documents=[],\n          expected_result=[],\n      ),\n      dict(\n          
testcase_name=\"multiple_documents_same_batch\",\n          documents=[\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc2\"},\n          ],\n          expected_result=[\n              dataclasses.replace(\n                  _ANNOTATED_DOCUMENT,\n                  document_id=\"doc1\",\n              ),\n              dataclasses.replace(\n                  _ANNOTATED_DOCUMENT,\n                  document_id=\"doc2\",\n              ),\n          ],\n          batch_length=10,\n      ),\n  )\n  def test_annotate_documents(\n      self,\n      documents: Sequence[dict[str, str]],\n      expected_result: Sequence[data.AnnotatedDocument],\n      batch_length: int = 1,\n  ):\n    mock_language_model = self.enter_context(\n        mock.patch.object(gemini, \"GeminiLanguageModel\", autospec=True)\n    )\n\n    # Define a side effect so the number of outputs matches the batch length.\n    def mock_infer_side_effect(batch_prompts, **kwargs):\n      for _ in batch_prompts:\n        yield [\n            types.ScoredOutput(\n                score=1.0,\n                output=self._LLM_INFERENCE,\n            )\n        ]\n\n    mock_language_model.infer.side_effect = mock_infer_side_effect\n\n    annotator = annotation.Annotator(\n        language_model=mock_language_model,\n        prompt_template=prompting.PromptTemplateStructured(description=\"\"),\n    )\n\n    document_objects = [\n        data.Document(\n            text=doc[\"text\"],\n            document_id=doc[\"document_id\"],\n        )\n        for doc in documents\n    ]\n    actual_annotations = list(\n        annotator.annotate_documents(\n            document_objects,\n            resolver=resolver_lib.Resolver(\n                fence_output=True,\n                format_type=data.FormatType.YAML,\n                extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n            ),\n            
max_char_buffer=200,\n            batch_length=batch_length,\n            debug=False,\n        )\n    )\n\n    self.assertLen(actual_annotations, len(expected_result))\n    for actual_annotation, expected_annotation in zip(\n        actual_annotations, expected_result\n    ):\n      self.assertDataclassEqual(expected_annotation, actual_annotation)\n\n    self.assertGreaterEqual(mock_language_model.infer.call_count, 0)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"same_document_id_contiguous\",\n          documents=[\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n          ],\n          expected_exception=exceptions.InvalidDocumentError,\n      ),\n      dict(\n          testcase_name=\"same_document_id_separated\",\n          documents=[\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc2\"},\n              {\"text\": _FIXED_DOCUMENT_CONTENT, \"document_id\": \"doc1\"},\n          ],\n          expected_exception=exceptions.InvalidDocumentError,\n      ),\n  )\n  def test_annotate_documents_exceptions(\n      self,\n      documents: Sequence[dict[str, str]],\n      expected_exception: Type[exceptions.InvalidDocumentError],\n      batch_length: int = 1,\n  ):\n    mock_language_model = self.enter_context(\n        mock.patch.object(gemini, \"GeminiLanguageModel\", autospec=True)\n    )\n    mock_language_model.infer.return_value = [\n        [\n            types.ScoredOutput(\n                score=1.0,\n                output=self._LLM_INFERENCE,\n            )\n        ]\n    ]\n    annotator = annotation.Annotator(\n        language_model=mock_language_model,\n        prompt_template=prompting.PromptTemplateStructured(description=\"\"),\n    )\n\n    document_objects = [\n        data.Document(text=doc[\"text\"], 
document_id=doc[\"document_id\"])\n        for doc in documents\n    ]\n\n    with self.assertRaises(expected_exception):\n      list(\n          annotator.annotate_documents(\n              document_objects,\n              max_char_buffer=200,\n              batch_length=batch_length,\n              debug=False,\n          )\n      )\n\n\nclass AnnotatorMultiPassTest(absltest.TestCase):\n  \"\"\"Tests for multi-pass extraction functionality.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self.mock_language_model = self.enter_context(\n        mock.patch.object(gemini, \"GeminiLanguageModel\", autospec=True)\n    )\n    self.annotator = annotation.Annotator(\n        language_model=self.mock_language_model,\n        prompt_template=prompting.PromptTemplateStructured(description=\"\"),\n    )\n\n  def test_multipass_extraction_non_overlapping(self):\n    \"\"\"Test multi-pass extraction with non-overlapping extractions.\"\"\"\n    text = \"Patient John Smith has diabetes and takes insulin daily.\"\n\n    self.mock_language_model.infer.side_effect = [\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - patient: \"John Smith\"\n                patient_index: 1\n              - condition: \"diabetes\"\n                condition_index: 4\n              ```\"\"\"),\n            )\n        ]],\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - medication: \"insulin\"\n                medication_index: 7\n              - frequency: \"daily\"\n                frequency_index: 8\n              ```\"\"\"),\n            )\n        ]],\n    ]\n\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        
extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n\n    result = self.annotator.annotate_text(\n        text, resolver=resolver, extraction_passes=2, debug=False\n    )\n\n    self.assertLen(result.extractions, 4)\n    extraction_classes = [e.extraction_class for e in result.extractions]\n    self.assertCountEqual(\n        extraction_classes, [\"patient\", \"condition\", \"medication\", \"frequency\"]\n    )\n\n    self.assertEqual(self.mock_language_model.infer.call_count, 2)\n\n  def test_multipass_extraction_overlapping(self):\n    \"\"\"Test multi-pass extraction with overlapping extractions (first pass wins).\"\"\"\n    text = \"Dr. Smith prescribed aspirin.\"\n\n    # Mock overlapping extractions - both passes find \"Smith\" but differently\n    self.mock_language_model.infer.side_effect = [\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - doctor: \"Dr. 
Smith\"\n                doctor_index: 0\n              ```\"\"\"),\n            )\n        ]],\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - patient: \"Smith\"\n                patient_index: 1\n              - medication: \"aspirin\"\n                medication_index: 2\n              ```\"\"\"),\n            )\n        ]],\n    ]\n\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n\n    result = self.annotator.annotate_text(\n        text, resolver=resolver, extraction_passes=2, debug=False\n    )\n\n    self.assertLen(result.extractions, 2)\n    extraction_classes = [e.extraction_class for e in result.extractions]\n    self.assertCountEqual(extraction_classes, [\"doctor\", \"medication\"])\n\n    # Verify \"Dr. Smith\" from first pass is kept, not \"Smith\" from second pass\n    doctor_extraction = next(\n        e for e in result.extractions if e.extraction_class == \"doctor\"\n    )\n    self.assertEqual(doctor_extraction.extraction_text, \"Dr. 
Smith\")\n\n  def test_multipass_extraction_single_pass(self):\n    \"\"\"Test that extraction_passes=1 behaves like normal single-pass extraction.\"\"\"\n    text = \"Patient has fever.\"\n\n    self.mock_language_model.infer.return_value = [[\n        types.ScoredOutput(\n            score=1.0,\n            output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - patient: \"Patient\"\n                patient_index: 0\n              - condition: \"fever\"\n                condition_index: 2\n              ```\"\"\"),\n        )\n    ]]\n\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n\n    result = self.annotator.annotate_text(\n        text, resolver=resolver, extraction_passes=1, debug=False  # Single pass\n    )\n\n    self.assertLen(result.extractions, 2)\n    self.assertEqual(self.mock_language_model.infer.call_count, 1)\n\n  def test_multipass_extraction_empty_passes(self):\n    \"\"\"Test multi-pass extraction when some passes return no extractions.\"\"\"\n    text = \"Test text.\"\n\n    self.mock_language_model.infer.side_effect = [\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - test: \"Test\"\n                test_index: 0\n              ```\"\"\"),\n            )\n        ]],\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n              ```yaml\n              {data.EXTRACTIONS_KEY}: []\n              ```\"\"\"),\n            )\n        ]],\n    ]\n\n    resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.YAML,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n\n    result = self.annotator.annotate_text(\n        text, 
resolver=resolver, extraction_passes=2, debug=False\n    )\n\n    self.assertLen(result.extractions, 1)\n    self.assertEqual(result.extractions[0].extraction_class, \"test\")\n\n\nclass MultiPassHelperFunctionsTest(parameterized.TestCase):\n  \"\"\"Tests for multi-pass helper functions.\"\"\"\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"empty_list\",\n          all_extractions=[],\n          expected_count=0,\n          expected_classes=[],\n      ),\n      dict(\n          testcase_name=\"single_pass\",\n          all_extractions=[[\n              data.Extraction(\n                  \"class1\", \"text1\", char_interval=data.CharInterval(0, 5)\n              ),\n              data.Extraction(\n                  \"class2\", \"text2\", char_interval=data.CharInterval(10, 15)\n              ),\n          ]],\n          expected_count=2,\n          expected_classes=[\"class1\", \"class2\"],\n      ),\n      dict(\n          testcase_name=\"non_overlapping_passes\",\n          all_extractions=[\n              [\n                  data.Extraction(\n                      \"class1\", \"text1\", char_interval=data.CharInterval(0, 5)\n                  )\n              ],\n              [\n                  data.Extraction(\n                      \"class2\", \"text2\", char_interval=data.CharInterval(10, 15)\n                  )\n              ],\n          ],\n          expected_count=2,\n          expected_classes=[\"class1\", \"class2\"],\n      ),\n      dict(\n          testcase_name=\"overlapping_passes_first_wins\",\n          all_extractions=[\n              [\n                  data.Extraction(\n                      \"class1\", \"text1\", char_interval=data.CharInterval(0, 10)\n                  )\n              ],\n              [\n                  data.Extraction(\n                      \"class2\", \"text2\", char_interval=data.CharInterval(5, 15)\n                  ),  # Overlaps\n                  data.Extraction(\n            
          \"class3\", \"text3\", char_interval=data.CharInterval(20, 25)\n                  ),  # No overlap\n              ],\n          ],\n          expected_count=2,\n          expected_classes=[\n              \"class1\",\n              \"class3\",\n          ],  # class2 excluded due to overlap\n      ),\n  )\n  def test_merge_non_overlapping_extractions(\n      self, all_extractions, expected_count, expected_classes\n  ):\n    \"\"\"Test merging extractions from multiple passes.\"\"\"\n    result = annotation._merge_non_overlapping_extractions(all_extractions)\n\n    self.assertLen(result, expected_count)\n    if expected_classes:\n      extraction_classes = [e.extraction_class for e in result]\n      self.assertCountEqual(extraction_classes, expected_classes)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"overlapping_intervals\",\n          ext1=data.Extraction(\n              \"class1\", \"text1\", char_interval=data.CharInterval(0, 10)\n          ),\n          ext2=data.Extraction(\n              \"class2\", \"text2\", char_interval=data.CharInterval(5, 15)\n          ),\n          expected=True,\n      ),\n      dict(\n          testcase_name=\"non_overlapping_intervals\",\n          ext1=data.Extraction(\n              \"class1\", \"text1\", char_interval=data.CharInterval(0, 5)\n          ),\n          ext2=data.Extraction(\n              \"class2\", \"text2\", char_interval=data.CharInterval(10, 15)\n          ),\n          expected=False,\n      ),\n      dict(\n          testcase_name=\"adjacent_intervals\",\n          ext1=data.Extraction(\n              \"class1\", \"text1\", char_interval=data.CharInterval(0, 5)\n          ),\n          ext2=data.Extraction(\n              \"class2\", \"text2\", char_interval=data.CharInterval(5, 10)\n          ),\n          expected=False,\n      ),\n      dict(\n          testcase_name=\"none_interval_first\",\n          ext1=data.Extraction(\"class1\", \"text1\", 
char_interval=None),\n          ext2=data.Extraction(\n              \"class2\", \"text2\", char_interval=data.CharInterval(5, 15)\n          ),\n          expected=False,\n      ),\n      dict(\n          testcase_name=\"none_interval_second\",\n          ext1=data.Extraction(\n              \"class1\", \"text1\", char_interval=data.CharInterval(0, 5)\n          ),\n          ext2=data.Extraction(\"class2\", \"text2\", char_interval=None),\n          expected=False,\n      ),\n      dict(\n          testcase_name=\"both_none_intervals\",\n          ext1=data.Extraction(\"class1\", \"text1\", char_interval=None),\n          ext2=data.Extraction(\"class2\", \"text2\", char_interval=None),\n          expected=False,\n      ),\n  )\n  def test_extractions_overlap(self, ext1, ext2, expected):\n    \"\"\"Test overlap detection between extractions.\"\"\"\n    result = annotation._extractions_overlap(ext1, ext2)\n    self.assertEqual(result, expected)\n\n\nclass AnnotateDocumentsGeneratorTest(absltest.TestCase):\n  \"\"\"Tests that annotate_documents uses 'yield from' for proper delegation.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self.mock_language_model = self.enter_context(\n        mock.patch.object(gemini, \"GeminiLanguageModel\", autospec=True)\n    )\n\n    def mock_infer(batch_prompts, **_):\n      \"\"\"Return medication extractions based on prompt content.\"\"\"\n      for prompt in batch_prompts:\n        if \"Ibuprofen\" in prompt:\n          text = textwrap.dedent(f\"\"\"\\\n            ```yaml\n            {data.EXTRACTIONS_KEY}:\n            - medication: \"Ibuprofen\"\n              medication_index: 4\n            ```\"\"\")\n        elif \"Cefazolin\" in prompt:\n          text = textwrap.dedent(f\"\"\"\\\n            ```yaml\n            {data.EXTRACTIONS_KEY}:\n            - medication: \"Cefazolin\"\n              medication_index: 4\n            ```\"\"\")\n        else:\n          text = f\"```yaml\\n{data.EXTRACTIONS_KEY}: []\\n```\"\n 
       yield [types.ScoredOutput(score=1.0, output=text)]\n\n    self.mock_language_model.infer.side_effect = mock_infer\n\n    self.annotator = annotation.Annotator(\n        language_model=self.mock_language_model,\n        prompt_template=prompting.PromptTemplateStructured(description=\"\"),\n    )\n\n  def test_yields_documents_not_generators(self):\n    \"\"\"Verifies annotate_documents yields AnnotatedDocument, not generators.\"\"\"\n    docs = [\n        data.Document(\n            text=\"Patient took 400 mg PO Ibuprofen q4h for two days.\",\n            document_id=\"doc1\",\n        ),\n        data.Document(\n            text=\"Patient was given 250 mg IV Cefazolin TID for one week.\",\n            document_id=\"doc2\",\n        ),\n    ]\n\n    results = list(\n        self.annotator.annotate_documents(\n            docs,\n            resolver=resolver_lib.Resolver(\n                fence_output=True,\n                format_type=data.FormatType.YAML,\n                extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n            ),\n            show_progress=False,\n            debug=False,\n        )\n    )\n\n    self.assertLen(results, 2)\n    self.assertFalse(\n        any(inspect.isgenerator(item) for item in results),\n        msg=\"Must use 'yield from' to delegate, not 'yield'\",\n    )\n    meds_doc1 = {\n        e.extraction_text\n        for e in results[0].extractions\n        if e.extraction_class == \"medication\"\n    }\n    meds_doc2 = {\n        e.extraction_text\n        for e in results[1].extractions\n        if e.extraction_class == \"medication\"\n    }\n    self.assertIn(\"Ibuprofen\", meds_doc1)\n    self.assertNotIn(\"Cefazolin\", meds_doc1)\n    self.assertIn(\"Cefazolin\", meds_doc2)\n    self.assertNotIn(\"Ibuprofen\", meds_doc2)\n\n\nclass CrossChunkContextTest(absltest.TestCase):\n  \"\"\"Tests for cross-chunk context window feature.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self.mock_language_model = 
self.enter_context(\n        mock.patch.object(gemini, \"GeminiLanguageModel\", autospec=True)\n    )\n    self.annotator = annotation.Annotator(\n        language_model=self.mock_language_model,\n        prompt_template=prompting.PromptTemplateStructured(description=\"\"),\n    )\n\n  def test_context_window_includes_previous_chunk_text(self):\n    \"\"\"Verifies that context_window_chars passes previous chunk text.\"\"\"\n    # Chunk 1: \"Dr. Sarah Johnson is a cardiologist.\"\n    # Chunk 2: \"She specializes in heart surgery.\"\n    text = (\n        \"Dr. Sarah Johnson is a cardiologist. She specializes in heart surgery.\"\n    )\n    self.mock_language_model.infer.side_effect = [\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n                  ```yaml\n                  {data.EXTRACTIONS_KEY}:\n                  - person: \"Dr. Sarah Johnson\"\n                  ```\"\"\"),\n            )\n        ]],\n        [[\n            types.ScoredOutput(\n                score=1.0,\n                output=textwrap.dedent(f\"\"\"\\\n                  ```yaml\n                  {data.EXTRACTIONS_KEY}:\n                  - specialization: \"heart surgery\"\n                  ```\"\"\"),\n            )\n        ]],\n    ]\n    resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML)\n\n    _ = self.annotator.annotate_text(\n        text,\n        max_char_buffer=40,\n        batch_length=1,\n        resolver=resolver,\n        context_window_chars=30,\n        enable_fuzzy_alignment=False,\n    )\n\n    calls = self.mock_language_model.infer.call_args_list\n    self.assertLen(calls, 2)\n\n    first_prompt = calls[0].kwargs[\"batch_prompts\"][0]\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n    self.assertNotIn(context_prefix, first_prompt)\n\n    second_prompt = calls[1].kwargs[\"batch_prompts\"][0]\n    self.assertIn(context_prefix, second_prompt)\n    
self.assertIn(\"cardiologist\", second_prompt)\n\n  def test_no_context_included_when_disabled(self):\n    \"\"\"Verifies that no context is included when context_window_chars=None.\"\"\"\n    text = (\n        \"Dr. Sarah Johnson is a cardiologist. She specializes in heart surgery.\"\n    )\n    self.mock_language_model.infer.side_effect = [\n        [[\n            types.ScoredOutput(\n                score=1.0, output=f\"```yaml\\n{data.EXTRACTIONS_KEY}: []\\n```\"\n            )\n        ]],\n        [[\n            types.ScoredOutput(\n                score=1.0, output=f\"```yaml\\n{data.EXTRACTIONS_KEY}: []\\n```\"\n            )\n        ]],\n    ]\n    resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML)\n\n    _ = self.annotator.annotate_text(\n        text,\n        max_char_buffer=40,\n        batch_length=1,\n        resolver=resolver,\n        context_window_chars=None,  # Disabled\n        enable_fuzzy_alignment=False,\n    )\n\n    calls = self.mock_language_model.infer.call_args_list\n    self.assertLen(calls, 2)\n\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n    first_prompt = calls[0].kwargs[\"batch_prompts\"][0]\n    second_prompt = calls[1].kwargs[\"batch_prompts\"][0]\n\n    self.assertNotIn(context_prefix, first_prompt)\n    self.assertNotIn(context_prefix, second_prompt)\n\n  def test_context_window_per_document_isolation(self):\n    \"\"\"Verifies context is tracked per document, not across documents.\"\"\"\n    docs = [\n        data.Document(text=\"Doc1 chunk1. Doc1 chunk2.\", document_id=\"doc1\"),\n        data.Document(text=\"Doc2 chunk1. 
Doc2 chunk2.\", document_id=\"doc2\"),\n    ]\n    empty_response = [[\n        types.ScoredOutput(\n            score=1.0, output=f\"```yaml\\n{data.EXTRACTIONS_KEY}: []\\n```\"\n        )\n    ]]\n    self.mock_language_model.infer.side_effect = [\n        empty_response,  # Doc1 chunk1\n        empty_response,  # Doc1 chunk2\n        empty_response,  # Doc2 chunk1\n        empty_response,  # Doc2 chunk2\n    ]\n    resolver = resolver_lib.Resolver(format_type=data.FormatType.YAML)\n\n    _ = list(\n        self.annotator.annotate_documents(\n            docs,\n            resolver=resolver,\n            max_char_buffer=15,\n            batch_length=1,\n            context_window_chars=20,  # Large enough to capture \"Doc1 chunk1.\"\n            show_progress=False,\n        )\n    )\n\n    calls = self.mock_language_model.infer.call_args_list\n    self.assertLen(calls, 4)\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n\n    # Extract prompts in order: doc1_chunk1, doc1_chunk2, doc2_chunk1, doc2_chunk2\n    doc1_chunk1_prompt = calls[0].kwargs[\"batch_prompts\"][0]\n    doc1_chunk2_prompt = calls[1].kwargs[\"batch_prompts\"][0]\n    doc2_chunk1_prompt = calls[2].kwargs[\"batch_prompts\"][0]\n    doc2_chunk2_prompt = calls[3].kwargs[\"batch_prompts\"][0]\n\n    # First chunks of each document should NOT have context prefix\n    self.assertNotIn(context_prefix, doc1_chunk1_prompt)\n    self.assertNotIn(context_prefix, doc2_chunk1_prompt)\n\n    # Second chunks should have context from their own document only\n    self.assertIn(context_prefix, doc1_chunk2_prompt)\n    self.assertIn(\"Doc1\", doc1_chunk2_prompt)\n\n    self.assertIn(context_prefix, doc2_chunk2_prompt)\n    self.assertIn(\"Doc2\", doc2_chunk2_prompt)\n\n    # Doc2's chunks should never contain Doc1 content\n    self.assertNotIn(\"Doc1\", doc2_chunk1_prompt)\n    self.assertNotIn(\"Doc1\", doc2_chunk2_prompt)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/chunking_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport textwrap\nfrom unittest import mock\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import chunking\nfrom langextract.core import data\nfrom langextract.core import tokenizer\n\n\nclass SentenceIterTest(absltest.TestCase):\n\n  def test_basic(self):\n    text = \"This is a sentence. This is a longer sentence. Mr. Bond\\nasks\\nwhy?\"\n    tokenized_text = tokenizer.tokenize(text)\n    sentence_iter = chunking.SentenceIterator(tokenized_text)\n    sentence_interval = next(sentence_iter)\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=0, end_index=5), sentence_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, sentence_interval),\n        \"This is a sentence.\",\n    )\n    sentence_interval = next(sentence_iter)\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=5, end_index=11), sentence_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, sentence_interval),\n        \"This is a longer sentence.\",\n    )\n    sentence_interval = next(sentence_iter)\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=11, end_index=17), sentence_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, sentence_interval),\n        \"Mr. 
Bond\\nasks\\nwhy?\",\n    )\n    with self.assertRaises(StopIteration):\n      next(sentence_iter)\n\n  def test_empty(self):\n    text = \"\"\n    tokenized_text = tokenizer.tokenize(text)\n    sentence_iter = chunking.SentenceIterator(tokenized_text)\n    with self.assertRaises(StopIteration):\n      next(sentence_iter)\n\n\nclass ChunkIteratorTest(absltest.TestCase):\n\n  def test_multi_sentence_chunk(self):\n    text = \"This is a sentence. This is a longer sentence. Mr. Bond\\nasks\\nwhy?\"\n    tokenized_text = tokenizer.tokenize(text)\n    chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=50,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=0, end_index=11), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"This is a sentence. This is a longer sentence.\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=11, end_index=17), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"Mr. Bond\\nasks\\nwhy?\",\n    )\n    with self.assertRaises(StopIteration):\n      next(chunk_iter)\n\n  def test_sentence_with_multiple_newlines_and_right_interval(self):\n    text = (\n        \"This is a sentence\\n\\n\"\n        + \"This is a longer sentence\\n\\n\"\n        + \"Mr\\n\\nBond\\n\\nasks why?\"\n    )\n    tokenized_text = tokenizer.tokenize(text)\n    chunk_interval = tokenizer.TokenInterval(\n        start_index=0, end_index=len(tokenized_text.tokens)\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        text,\n    )\n\n  def test_break_sentence(self):\n    text = \"This is a sentence. 
This is a longer sentence. Mr. Bond\\nasks\\nwhy?\"\n    tokenized_text = tokenizer.tokenize(text)\n    chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=12,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=0, end_index=3), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"This is a\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=3, end_index=5), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"sentence.\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=5, end_index=8), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"This is a\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=8, end_index=9), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"longer\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=9, end_index=11), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"sentence.\",\n    )\n    for _ in range(2):\n      next(chunk_iter)\n    with self.assertRaises(StopIteration):\n      next(chunk_iter)\n\n  def test_long_token_gets_own_chunk(self):\n    text = \"This is a sentence. This is a longer sentence. Mr. 
Bond\\nasks\\nwhy?\"\n    tokenized_text = tokenizer.tokenize(text)\n    chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=7,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=0, end_index=2), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"This is\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=2, end_index=3), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval), \"a\"\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=3, end_index=4), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval),\n        \"sentence\",\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(start_index=4, end_index=5), chunk_interval\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval), \".\"\n    )\n    for _ in range(9):\n      next(chunk_iter)\n    with self.assertRaises(StopIteration):\n      next(chunk_iter)\n\n  def test_newline_at_chunk_boundary_does_not_create_empty_interval(self):\n    \"\"\"Test that newlines at chunk boundaries don't create empty token intervals.\n\n    When a newline occurs exactly at a chunk boundary, the chunking algorithm\n    should not attempt to create an empty interval (where start_index == end_index).\n    This was causing a ValueError in create_token_interval().\n    \"\"\"\n    text = \"First sentence.\\nSecond sentence that is longer.\\nThird sentence.\"\n    tokenized_text = tokenizer.tokenize(text)\n\n 
   chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=20,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    chunks = list(chunk_iter)\n\n    for chunk in chunks:\n      self.assertLess(\n          chunk.token_interval.start_index,\n          chunk.token_interval.end_index,\n          \"Chunk should have non-empty interval\",\n      )\n\n    expected_intervals = [(0, 3), (3, 6), (6, 9), (9, 12)]\n    actual_intervals = [\n        (chunk.token_interval.start_index, chunk.token_interval.end_index)\n        for chunk in chunks\n    ]\n    self.assertEqual(actual_intervals, expected_intervals)\n\n  def test_chunk_unicode_text(self):\n    text = textwrap.dedent(\"\"\"\\\n    Chief Complaint:\n    ‘swelling of tongue and difficulty breathing and swallowing’\n    History of Present Illness:\n    77 y o woman in NAD with a h/o CAD, DM2, asthma and HTN on altace.\"\"\")\n    tokenized_text = tokenizer.tokenize(text)\n    chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=200,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    chunk_interval = next(chunk_iter).token_interval\n    self.assertEqual(\n        tokenizer.TokenInterval(\n            start_index=0, end_index=len(tokenized_text.tokens)\n        ),\n        chunk_interval,\n    )\n    self.assertEqual(\n        chunking.get_token_interval_text(tokenized_text, chunk_interval), text\n    )\n\n  def test_newlines_is_secondary_sentence_break(self):\n    text = textwrap.dedent(\"\"\"\\\n    Medications:\n    Theophyline (Uniphyl) 600 mg qhs – bronchodilator by increasing cAMP used\n    for treating asthma\n    Diltiazem 300 mg qhs – Ca channel blocker used to control hypertension\n    Simvistatin (Zocor) 20 mg qhs- HMGCo Reductase inhibitor for\n    hypercholesterolemia\n    Ramipril (Altace) 10 mg BID – ACEI for hypertension and diabetes for\n    renal protective effect\"\"\")\n    tokenized_text = tokenizer.tokenize(text)\n 
   chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=200,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n\n    first_chunk = next(chunk_iter)\n    expected_first_chunk_text = textwrap.dedent(\"\"\"\\\n    Medications:\n    Theophyline (Uniphyl) 600 mg qhs – bronchodilator by increasing cAMP used\n    for treating asthma\n    Diltiazem 300 mg qhs – Ca channel blocker used to control hypertension\"\"\")\n    self.assertEqual(\n        chunking.get_token_interval_text(\n            tokenized_text, first_chunk.token_interval\n        ),\n        expected_first_chunk_text,\n    )\n\n    self.assertGreater(\n        first_chunk.token_interval.end_index,\n        first_chunk.token_interval.start_index,\n    )\n\n    second_chunk = next(chunk_iter)\n    expected_second_chunk_text = textwrap.dedent(\"\"\"\\\n    Simvistatin (Zocor) 20 mg qhs- HMGCo Reductase inhibitor for\n    hypercholesterolemia\n    Ramipril (Altace) 10 mg BID – ACEI for hypertension and diabetes for\n    renal protective effect\"\"\")\n    self.assertEqual(\n        chunking.get_token_interval_text(\n            tokenized_text, second_chunk.token_interval\n        ),\n        expected_second_chunk_text,\n    )\n\n    with self.assertRaises(StopIteration):\n      next(chunk_iter)\n\n  def test_tokenizer_propagation(self):\n    \"\"\"Test that tokenizer is correctly propagated to TextChunk's Document.\"\"\"\n    text = \"Some text.\"\n    mock_tokenizer = mock.Mock(spec=tokenizer.Tokenizer)\n    mock_tokens = [\n        tokenizer.Token(\n            index=0,\n            token_type=tokenizer.TokenType.WORD,\n            char_interval=data.CharInterval(start_pos=0, end_pos=4),\n        ),\n        tokenizer.Token(\n            index=1,\n            token_type=tokenizer.TokenType.WORD,\n            char_interval=data.CharInterval(start_pos=5, end_pos=9),\n        ),\n        tokenizer.Token(\n            index=2,\n            
token_type=tokenizer.TokenType.PUNCTUATION,\n            char_interval=data.CharInterval(start_pos=9, end_pos=10),\n        ),\n    ]\n    mock_tokenized_text = tokenizer.TokenizedText(text=text, tokens=mock_tokens)\n    mock_tokenizer.tokenize.return_value = mock_tokenized_text\n\n    chunk_iter = chunking.ChunkIterator(\n        text=text, max_char_buffer=100, tokenizer_impl=mock_tokenizer\n    )\n    text_chunk = next(chunk_iter)\n\n    self.assertEqual(text_chunk.document_text, mock_tokenized_text)\n    self.assertEqual(text_chunk.chunk_text, text)\n\n\nclass BatchingTest(parameterized.TestCase):\n\n  _SAMPLE_DOCUMENT = data.Document(\n      text=(\n          \"Sample text with numerical values such as 120/80 mmHg, 98.6°F, and\"\n          \" 50mg.\"\n      ),\n  )\n\n  @parameterized.named_parameters(\n      (\n          \"test_with_data\",\n          _SAMPLE_DOCUMENT.tokenized_text,\n          15,\n          10,\n          [[\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=0, end_index=1\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=1, end_index=3\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=3, end_index=4\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=4, end_index=5\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=5, 
end_index=7\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=7, end_index=10\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=10, end_index=14\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=14, end_index=19\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n              chunking.TextChunk(\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=19, end_index=22\n                  ),\n                  document=_SAMPLE_DOCUMENT,\n              ),\n          ]],\n      ),\n      (\n          \"test_empty_input\",\n          \"\",\n          15,\n          10,\n          [],\n      ),\n  )\n  def test_make_batches_of_textchunk(\n      self,\n      tokenized_text: tokenizer.TokenizedText,\n      batch_length: int,\n      max_char_buffer: int,\n      expected_batches: list[list[chunking.TextChunk]],\n  ):\n    chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    batches_iter = chunking.make_batches_of_textchunk(chunk_iter, batch_length)\n    actual_batches = [list(batch) for batch in batches_iter]\n\n    self.assertListEqual(\n        actual_batches,\n        expected_batches,\n        \"Batched chunks should match expected structure\",\n    )\n\n\nclass TextChunkTest(absltest.TestCase):\n\n  def test_string_output(self):\n    text = \"Example input text.\"\n    expected = textwrap.dedent(\"\"\"\\\n    
TextChunk(\n      interval=[start_index: 0, end_index: 1],\n      Document ID: test_doc_123,\n      Chunk Text: 'Example'\n    )\"\"\")\n    document = data.Document(text=text, document_id=\"test_doc_123\")\n    tokenized_text = tokenizer.tokenize(text)\n    chunk_iter = chunking.ChunkIterator(\n        tokenized_text,\n        max_char_buffer=7,\n        document=document,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    text_chunk = next(chunk_iter)\n    self.assertEqual(str(text_chunk), expected)\n\n\nclass TextAdditionalContextTest(absltest.TestCase):\n\n  _ADDITIONAL_CONTEXT = \"Some additional context for prompt...\"\n\n  def test_text_chunk_additional_context(self):\n    document = data.Document(\n        text=\"Sample text.\", additional_context=self._ADDITIONAL_CONTEXT\n    )\n    chunk_iter = chunking.ChunkIterator(\n        text=document.tokenized_text,\n        max_char_buffer=100,\n        document=document,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    text_chunk = next(chunk_iter)\n    self.assertEqual(text_chunk.additional_context, self._ADDITIONAL_CONTEXT)\n\n  def test_chunk_iterator_without_additional_context(self):\n    document = data.Document(text=\"Sample text.\")\n    chunk_iter = chunking.ChunkIterator(\n        text=document.tokenized_text,\n        max_char_buffer=100,\n        document=document,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    text_chunk = next(chunk_iter)\n    self.assertIsNone(text_chunk.additional_context)\n\n  def test_multiple_chunks_with_additional_context(self):\n    text = \"Sentence one. Sentence two. 
Sentence three.\"\n    document = data.Document(\n        text=text, additional_context=self._ADDITIONAL_CONTEXT\n    )\n    chunk_iter = chunking.ChunkIterator(\n        text=document.tokenized_text,\n        max_char_buffer=15,\n        document=document,\n        tokenizer_impl=tokenizer.RegexTokenizer(),\n    )\n    chunks = list(chunk_iter)\n    self.assertGreater(\n        len(chunks), 1, \"Should create multiple chunks with small buffer\"\n    )\n    additional_contexts = [chunk.additional_context for chunk in chunks]\n    expected_additional_contexts = [self._ADDITIONAL_CONTEXT] * len(chunks)\n    self.assertListEqual(additional_contexts, expected_additional_contexts)\n\n\nclass TextChunkPropertyTest(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      {\n          \"testcase_name\": \"with_document\",\n          \"document\": data.Document(\n              text=\"Sample text.\",\n              document_id=\"doc123\",\n              additional_context=\"Additional info\",\n          ),\n          \"expected_id\": \"doc123\",\n          \"expected_text\": \"Sample text.\",\n          \"expected_context\": \"Additional info\",\n      },\n      {\n          \"testcase_name\": \"no_document\",\n          \"document\": None,\n          \"expected_id\": None,\n          \"expected_text\": None,\n          \"expected_context\": None,\n      },\n      {\n          \"testcase_name\": \"no_additional_context\",\n          \"document\": data.Document(\n              text=\"Sample text.\",\n              document_id=\"doc123\",\n          ),\n          \"expected_id\": \"doc123\",\n          \"expected_text\": \"Sample text.\",\n          \"expected_context\": None,\n      },\n  )\n  def test_text_chunk_properties(\n      self, document, expected_id, expected_text, expected_context\n  ):\n    chunk = chunking.TextChunk(\n        token_interval=tokenizer.TokenInterval(start_index=0, end_index=1),\n        document=document,\n    )\n    
self.assertEqual(chunk.document_id, expected_id)\n    if chunk.document_text:\n      self.assertEqual(chunk.document_text.text, expected_text)\n    else:\n      self.assertIsNone(chunk.document_text)\n    self.assertEqual(chunk.additional_context, expected_context)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/data_lib_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport numpy as np\n\nfrom langextract import data_lib\nfrom langextract import io\nfrom langextract.core import data\nfrom langextract.core import tokenizer\n\n\nclass DataLibToDictParameterizedTest(parameterized.TestCase):\n  \"\"\"Tests conversion of AnnotatedDocument objects to JSON dicts.\n\n  Verifies that `annotated_document_to_dict` correctly serializes documents by:\n  - Excluding private fields (e.g., token_interval).\n  - Converting all expected extraction attributes properly.\n  - Handling int64 values for extraction indexes.\n  \"\"\"\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"single_extraction_no_token_interval\",\n          annotated_doc=data.AnnotatedDocument(\n              document_id=\"docA\",\n              text=\"Just a short sentence.\",\n              extractions=[\n                  data.Extraction(\n                      extraction_class=\"note\",\n                      extraction_text=\"short sentence\",\n                      extraction_index=1,\n                      group_index=0,\n                  ),\n              ],\n          ),\n          expected_dict={\n              \"document_id\": \"docA\",\n              \"extractions\": [\n                  {\n                      \"extraction_class\": \"note\",\n      
                \"extraction_text\": \"short sentence\",\n                      \"char_interval\": None,\n                      \"alignment_status\": None,\n                      \"extraction_index\": 1,\n                      \"group_index\": 0,\n                      \"description\": None,\n                      \"attributes\": None,\n                  },\n              ],\n              \"text\": \"Just a short sentence.\",\n          },\n      ),\n      dict(\n          testcase_name=\"multiple_extractions_with_token_interval\",\n          annotated_doc=data.AnnotatedDocument(\n              document_id=\"docB\",\n              text=\"Patient Jane reported a headache.\",\n              extractions=[\n                  data.Extraction(\n                      extraction_class=\"patient\",\n                      extraction_text=\"Jane\",\n                      extraction_index=1,\n                      group_index=0,\n                  ),\n                  data.Extraction(\n                      extraction_class=\"symptom\",\n                      extraction_text=\"headache\",\n                      extraction_index=2,\n                      group_index=0,\n                      char_interval=data.CharInterval(start_pos=24, end_pos=32),\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=4, end_index=5\n                      ),  # should be ignored\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  ),\n              ],\n          ),\n          expected_dict={\n              \"document_id\": \"docB\",\n              \"extractions\": [\n                  {\n                      \"extraction_class\": \"patient\",\n                      \"extraction_text\": \"Jane\",\n                      \"char_interval\": None,\n                      \"alignment_status\": None,\n                      \"extraction_index\": 1,\n                      \"group_index\": 0,\n                      
\"description\": None,\n                      \"attributes\": None,\n                  },\n                  {\n                      \"extraction_class\": \"symptom\",\n                      \"extraction_text\": \"headache\",\n                      \"char_interval\": {\"start_pos\": 24, \"end_pos\": 32},\n                      \"alignment_status\": \"match_exact\",\n                      \"extraction_index\": 2,\n                      \"group_index\": 0,\n                      \"description\": None,\n                      \"attributes\": None,\n                  },\n              ],\n              \"text\": \"Patient Jane reported a headache.\",\n          },\n      ),\n      dict(\n          testcase_name=\"extraction_with_attributes_and_token_interval\",\n          annotated_doc=data.AnnotatedDocument(\n              document_id=\"docC\",\n              text=\"He has mild chest pain and a cough.\",\n              extractions=[\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"chest pain\",\n                      extraction_index=2,\n                      group_index=1,\n                      attributes={\n                          \"severity\": \"mild\",\n                          \"persistence\": \"persistent\",\n                      },\n                      char_interval=data.CharInterval(start_pos=12, end_pos=22),\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=5\n                      ),  # should be ignored\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  ),\n                  data.Extraction(\n                      extraction_class=\"symptom\",\n                      extraction_text=\"cough\",\n                      extraction_index=3,\n                      group_index=1,\n                  ),\n              ],\n          ),\n          expected_dict={\n            
  \"document_id\": \"docC\",\n              \"extractions\": [\n                  {\n                      \"extraction_class\": \"condition\",\n                      \"extraction_text\": \"chest pain\",\n                      \"char_interval\": {\"start_pos\": 12, \"end_pos\": 22},\n                      \"alignment_status\": \"match_exact\",\n                      \"extraction_index\": 2,\n                      \"group_index\": 1,\n                      \"description\": None,\n                      \"attributes\": {\n                          \"severity\": \"mild\",\n                          \"persistence\": \"persistent\",\n                      },\n                  },\n                  {\n                      \"extraction_class\": \"symptom\",\n                      \"extraction_text\": \"cough\",\n                      \"char_interval\": None,\n                      \"alignment_status\": None,\n                      \"extraction_index\": 3,\n                      \"group_index\": 1,\n                      \"description\": None,\n                      \"attributes\": None,\n                  },\n              ],\n              \"text\": \"He has mild chest pain and a cough.\",\n          },\n      ),\n  )\n  def test_annotated_document_to_dict(self, annotated_doc, expected_dict):\n    actual_dict = data_lib.annotated_document_to_dict(annotated_doc)\n    self.assertDictEqual(\n        actual_dict,\n        expected_dict,\n        \"annotated_document_to_dict() output differs from expected JSON dict.\",\n    )\n\n  def test_annotated_document_to_dict_with_int64(self):\n    doc = data.AnnotatedDocument(\n        document_id=\"doc_int64\",\n        text=\"Sample text with int64 index\",\n        extractions=[\n            data.Extraction(\n                extraction_class=\"demo_extraction\",\n                extraction_text=\"placeholder\",\n                extraction_index=np.int64(42),  # pytype: disable=wrong-arg-types\n            ),\n        ],\n    )\n\n 
   doc_dict = data_lib.annotated_document_to_dict(doc)\n\n    json_str = json.dumps(doc_dict, ensure_ascii=False)\n    self.assertIn('\"extraction_index\": 42', json_str)\n\n\nclass IsUrlTest(absltest.TestCase):\n  \"\"\"Tests for io.is_url function validation.\"\"\"\n\n  def test_valid_urls(self):\n    \"\"\"Test that valid URLs are recognized.\"\"\"\n    self.assertTrue(io.is_url(\"http://example.com\"))\n    self.assertTrue(io.is_url(\"https://www.example.com\"))\n    self.assertTrue(io.is_url(\"http://localhost:8080\"))\n    self.assertTrue(io.is_url(\"http://192.168.1.1\"))\n    self.assertTrue(io.is_url(\"http://[2001:db8::1]\"))  # IPv6\n    self.assertTrue(io.is_url(\"http://[::1]:8080\"))  # IPv6 localhost with port\n\n  def test_invalid_urls_with_text(self):\n    \"\"\"Test that URLs with additional text are rejected.\"\"\"\n    # Validates fix for issue where text starting with URL was incorrectly fetched\n    self.assertFalse(io.is_url(\"http://example.com is a website\"))\n    self.assertFalse(io.is_url(\"http://medical-journal.com published a study\"))\n\n  def test_invalid_urls_no_scheme(self):\n    \"\"\"Test that URLs without proper scheme are rejected.\"\"\"\n    self.assertFalse(io.is_url(\"example.com\"))\n    self.assertFalse(io.is_url(\"www.example.com\"))\n    self.assertFalse(io.is_url(\"ftp://example.com\"))\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/extract_precedence_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for parameter precedence in extract().\"\"\"\n\nfrom unittest import mock\n\nfrom absl.testing import absltest\n\nfrom langextract import factory\nimport langextract as lx\nfrom langextract.core import data\nfrom langextract.providers import openai\n\n\nclass ExtractParameterPrecedenceTest(absltest.TestCase):\n  \"\"\"Tests ensuring correct precedence among extract() parameters.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self.examples = [\n        data.ExampleData(\n            text=\"example\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"example\",\n                )\n            ],\n        )\n    ]\n    self.description = \"description\"\n\n  @mock.patch(\"langextract.annotation.Annotator\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_model_overrides_all_other_parameters(\n      self, mock_create_model, mock_annotator_cls\n  ):\n    \"\"\"Test that model parameter overrides all other model-related parameters.\"\"\"\n    provided_model = mock.MagicMock()\n    mock_annotator = mock_annotator_cls.return_value\n    mock_annotator.annotate_text.return_value = \"ok\"\n\n    config = factory.ModelConfig(model_id=\"config-id\")\n\n    result = lx.extract(\n        text_or_documents=\"text\",\n        
prompt_description=self.description,\n        examples=self.examples,\n        model=provided_model,\n        config=config,\n        model_id=\"ignored-model\",\n        api_key=\"ignored-key\",\n        language_model_type=openai.OpenAILanguageModel,\n        use_schema_constraints=False,\n    )\n\n    mock_create_model.assert_not_called()\n    _, kwargs = mock_annotator_cls.call_args\n    self.assertIs(kwargs[\"language_model\"], provided_model)\n    self.assertEqual(result, \"ok\")\n\n  @mock.patch(\"langextract.annotation.Annotator\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_config_overrides_model_id_and_language_model_type(\n      self, mock_create_model, mock_annotator_cls\n  ):\n    \"\"\"Test that config parameter overrides model_id and language_model_type.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"config-model\", provider_kwargs={\"api_key\": \"config-key\"}\n    )\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n    mock_annotator = mock_annotator_cls.return_value\n    mock_annotator.annotate_text.return_value = \"ok\"\n\n    with mock.patch(\n        \"langextract.extraction.factory.ModelConfig\"\n    ) as mock_model_config:\n      result = lx.extract(\n          text_or_documents=\"text\",\n          prompt_description=self.description,\n          examples=self.examples,\n          config=config,\n          model_id=\"other-model\",\n          api_key=\"other-key\",\n          language_model_type=openai.OpenAILanguageModel,\n          use_schema_constraints=False,\n      )\n      mock_model_config.assert_not_called()\n\n    mock_create_model.assert_called_once()\n    called_config = mock_create_model.call_args[1][\"config\"]\n    self.assertEqual(called_config.model_id, \"config-model\")\n    self.assertEqual(called_config.provider_kwargs, {\"api_key\": \"config-key\"})\n\n    _, kwargs = mock_annotator_cls.call_args\n  
  self.assertIs(kwargs[\"language_model\"], mock_model)\n    self.assertEqual(result, \"ok\")\n\n  @mock.patch(\"langextract.annotation.Annotator\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_model_id_and_base_kwargs_override_language_model_type(\n      self, mock_create_model, mock_annotator_cls\n  ):\n    \"\"\"Test that model_id and other kwargs are used when no model or config.\"\"\"\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n    mock_annotator_cls.return_value.annotate_text.return_value = \"ok\"\n    mock_config = mock.MagicMock()\n\n    with mock.patch(\n        \"langextract.extraction.factory.ModelConfig\", return_value=mock_config\n    ) as mock_model_config:\n      with self.assertWarns(FutureWarning):\n        result = lx.extract(\n            text_or_documents=\"text\",\n            prompt_description=self.description,\n            examples=self.examples,\n            model_id=\"model-123\",\n            api_key=\"api-key\",\n            temperature=0.9,\n            model_url=\"http://model\",\n            language_model_type=openai.OpenAILanguageModel,\n            use_schema_constraints=False,\n        )\n\n    mock_model_config.assert_called_once()\n    _, kwargs = mock_model_config.call_args\n    self.assertEqual(kwargs[\"model_id\"], \"model-123\")\n    provider_kwargs = kwargs[\"provider_kwargs\"]\n    self.assertEqual(provider_kwargs[\"api_key\"], \"api-key\")\n    self.assertEqual(provider_kwargs[\"temperature\"], 0.9)\n    self.assertEqual(provider_kwargs[\"model_url\"], \"http://model\")\n    self.assertEqual(provider_kwargs[\"base_url\"], \"http://model\")\n    mock_create_model.assert_called_once()\n    self.assertEqual(result, \"ok\")\n\n  @mock.patch(\"langextract.annotation.Annotator\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_language_model_type_only_emits_warning_and_works(\n      
self, mock_create_model, mock_annotator_cls\n  ):\n    \"\"\"Test that language_model_type emits deprecation warning but still works.\"\"\"\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n    mock_annotator_cls.return_value.annotate_text.return_value = \"ok\"\n    mock_config = mock.MagicMock()\n\n    with mock.patch(\n        \"langextract.extraction.factory.ModelConfig\", return_value=mock_config\n    ) as mock_model_config:\n      with self.assertWarns(FutureWarning):\n        result = lx.extract(\n            text_or_documents=\"text\",\n            prompt_description=self.description,\n            examples=self.examples,\n            language_model_type=openai.OpenAILanguageModel,\n            use_schema_constraints=False,\n        )\n\n    mock_model_config.assert_called_once()\n    _, kwargs = mock_model_config.call_args\n    self.assertEqual(kwargs[\"model_id\"], \"gemini-2.5-flash\")\n    mock_create_model.assert_called_once()\n    self.assertEqual(result, \"ok\")\n\n  @mock.patch(\"langextract.annotation.Annotator\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_use_schema_constraints_warns_with_config(\n      self, mock_create_model, mock_annotator_cls\n  ):\n    \"\"\"Test that use_schema_constraints emits warning when used with config.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-2.5-flash\", provider_kwargs={\"api_key\": \"test-key\"}\n    )\n\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n    mock_annotator = mock_annotator_cls.return_value\n    mock_annotator.annotate_text.return_value = \"ok\"\n\n    with self.assertWarns(UserWarning) as cm:\n      result = lx.extract(\n          text_or_documents=\"text\",\n          prompt_description=self.description,\n          examples=self.examples,\n          config=config,\n          
use_schema_constraints=True,\n      )\n\n    self.assertIn(\"schema constraints\", str(cm.warning))\n    self.assertIn(\"applied\", str(cm.warning))\n    mock_create_model.assert_called_once()\n    called_config = mock_create_model.call_args[1][\"config\"]\n    self.assertEqual(called_config.model_id, \"gemini-2.5-flash\")\n    self.assertEqual(result, \"ok\")\n\n  @mock.patch(\"langextract.annotation.Annotator\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_use_schema_constraints_warns_with_model(\n      self, mock_create_model, mock_annotator_cls\n  ):\n    \"\"\"Test that use_schema_constraints emits warning when used with model.\"\"\"\n    provided_model = mock.MagicMock()\n    mock_annotator = mock_annotator_cls.return_value\n    mock_annotator.annotate_text.return_value = \"ok\"\n\n    with self.assertWarns(UserWarning) as cm:\n      result = lx.extract(\n          text_or_documents=\"text\",\n          prompt_description=self.description,\n          examples=self.examples,\n          model=provided_model,\n          use_schema_constraints=True,\n      )\n\n    self.assertIn(\"use_schema_constraints\", str(cm.warning))\n    self.assertIn(\"ignored\", str(cm.warning))\n    mock_create_model.assert_not_called()\n    self.assertEqual(result, \"ok\")\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/extract_schema_integration_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Integration tests for extract function with new schema system.\"\"\"\n\nfrom unittest import mock\nimport warnings\n\nfrom absl.testing import absltest\n\nimport langextract as lx\nfrom langextract.core import data\n\n\nclass ExtractSchemaIntegrationTest(absltest.TestCase):\n  \"\"\"Tests for extract function with schema system integration.\"\"\"\n\n  def setUp(self):\n    \"\"\"Set up test fixtures.\"\"\"\n    super().setUp()\n    self.examples = [\n        data.ExampleData(\n            text=\"Patient has diabetes\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"condition\",\n                    extraction_text=\"diabetes\",\n                    attributes={\"severity\": \"moderate\"},\n                )\n            ],\n        )\n    ]\n    self.test_text = \"Patient has hypertension\"\n\n  @mock.patch.dict(\"os.environ\", {\"GEMINI_API_KEY\": \"test_key\"})\n  def test_extract_with_gemini_uses_schema(self):\n    \"\"\"Test that extract with Gemini automatically uses schema.\"\"\"\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      with mock.patch(\n          \"langextract.providers.gemini.GeminiLanguageModel.infer\",\n          return_value=iter([[mock.Mock(output='{\"extractions\": []}')]]),\n      ):\n   
     with mock.patch(\n            \"langextract.annotation.Annotator.annotate_text\",\n            return_value=data.AnnotatedDocument(\n                text=self.test_text, extractions=[]\n            ),\n        ):\n          result = lx.extract(\n              text_or_documents=self.test_text,\n              prompt_description=\"Extract conditions\",\n              examples=self.examples,\n              model_id=\"gemini-2.5-flash\",\n              use_schema_constraints=True,\n              fence_output=None,  # Let it compute\n          )\n\n          # Should have been called with response_schema\n          call_kwargs = mock_init.call_args[1]\n          self.assertIn(\"response_schema\", call_kwargs)\n\n          # Result should be an AnnotatedDocument\n          self.assertIsInstance(result, data.AnnotatedDocument)\n\n  @mock.patch.dict(\"os.environ\", {\"OLLAMA_BASE_URL\": \"http://localhost:11434\"})\n  def test_extract_with_ollama_uses_json_mode(self):\n    \"\"\"Test that extract with Ollama uses JSON mode.\"\"\"\n    with mock.patch(\n        \"langextract.providers.ollama.OllamaLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      with mock.patch(\n          \"langextract.providers.ollama.OllamaLanguageModel.infer\",\n          return_value=iter([[mock.Mock(output='{\"extractions\": []}')]]),\n      ):\n        with mock.patch(\n            \"langextract.annotation.Annotator.annotate_text\",\n            return_value=data.AnnotatedDocument(\n                text=self.test_text, extractions=[]\n            ),\n        ):\n          result = lx.extract(\n              text_or_documents=self.test_text,\n              prompt_description=\"Extract conditions\",\n              examples=self.examples,\n              model_id=\"gemma2:2b\",\n              use_schema_constraints=True,\n              fence_output=None,  # Let it compute\n          )\n\n          # Should have been called with format=\"json\"\n          call_kwargs = 
mock_init.call_args[1]\n          self.assertIn(\"format\", call_kwargs)\n          self.assertEqual(call_kwargs[\"format\"], \"json\")\n\n          # Result should be an AnnotatedDocument\n          self.assertIsInstance(result, data.AnnotatedDocument)\n\n  def test_extract_explicit_fence_respected(self):\n    \"\"\"Test that explicit fence_output is respected in extract.\"\"\"\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n    ):\n      with mock.patch(\n          \"langextract.providers.gemini.GeminiLanguageModel.infer\",\n          return_value=iter([[mock.Mock(output='{\"extractions\": []}')]]),\n      ):\n        with mock.patch(\n            \"langextract.annotation.Annotator.__init__\", return_value=None\n        ) as mock_annotator_init:\n          with mock.patch(\n              \"langextract.annotation.Annotator.annotate_text\",\n              return_value=data.AnnotatedDocument(\n                  text=self.test_text, extractions=[]\n              ),\n          ):\n            _ = lx.extract(\n                text_or_documents=self.test_text,\n                prompt_description=\"Extract conditions\",\n                examples=self.examples,\n                model_id=\"gemini-2.5-flash\",\n                api_key=\"test_key\",\n                use_schema_constraints=True,\n                fence_output=True,  # Explicitly set\n            )\n\n            # Annotator should be created with format_handler that has use_fences=True\n            call_kwargs = mock_annotator_init.call_args[1]\n            self.assertIn(\"format_handler\", call_kwargs)\n            self.assertTrue(call_kwargs[\"format_handler\"].use_fences)\n\n  def test_extract_gemini_schema_deprecation_warning(self):\n    \"\"\"Test that passing gemini_schema triggers deprecation warning.\"\"\"\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n  
  ):\n      with mock.patch(\n          \"langextract.providers.gemini.GeminiLanguageModel.infer\",\n          return_value=iter([[mock.Mock(output='{\"extractions\": []}')]]),\n      ):\n        with mock.patch(\n            \"langextract.annotation.Annotator.annotate_text\",\n            return_value=data.AnnotatedDocument(\n                text=self.test_text, extractions=[]\n            ),\n        ):\n          with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n\n            _ = lx.extract(\n                text_or_documents=self.test_text,\n                prompt_description=\"Extract conditions\",\n                examples=self.examples,\n                model_id=\"gemini-2.5-flash\",\n                api_key=\"test_key\",\n                language_model_params={\n                    \"gemini_schema\": \"some_schema\"\n                },  # Deprecated\n            )\n\n            # Should have triggered deprecation warning\n            deprecation_warnings = [\n                warning\n                for warning in w\n                if issubclass(warning.category, FutureWarning)\n                and \"gemini_schema\" in str(warning.message)\n            ]\n            self.assertGreater(len(deprecation_warnings), 0)\n\n  def test_extract_no_schema_when_disabled(self):\n    \"\"\"Test that no schema is used when use_schema_constraints=False.\"\"\"\n    # Create a mock instance with required attributes\n    mock_model = mock.MagicMock()\n    mock_model._schema = None\n    mock_model._fence_output_override = None\n    mock_model.gemini_schema = None\n    mock_model.requires_fence_output = True\n    mock_model.infer.return_value = iter(\n        [[mock.Mock(output='{\"extractions\": []}')]]\n    )\n\n    # Track the kwargs passed to the constructor\n    constructor_kwargs = {}\n\n    def mock_constructor(**kwargs):\n      constructor_kwargs.update(kwargs)\n      return mock_model\n\n    with mock.patch(\n        
\"langextract.providers.gemini.GeminiLanguageModel\",\n        side_effect=mock_constructor,\n    ):\n      with mock.patch(\n          \"langextract.annotation.Annotator.annotate_text\",\n          return_value=data.AnnotatedDocument(\n              text=self.test_text, extractions=[]\n          ),\n      ):\n        _ = lx.extract(\n            text_or_documents=self.test_text,\n            prompt_description=\"Extract conditions\",\n            examples=self.examples,\n            model_id=\"gemini-2.5-flash\",\n            api_key=\"test_key\",\n            use_schema_constraints=False,  # Disabled\n        )\n\n        # Should NOT have response_schema when schema constraints are disabled\n        self.assertNotIn(\"response_schema\", constructor_kwargs)\n        self.assertNotIn(\"gemini_schema\", constructor_kwargs)\n\n  @mock.patch(\"langextract.factory.create_model\")\n  def test_validation_triggers_warning_for_gemini(self, mock_create_model):\n    \"\"\"Test that Gemini schema validation triggers warnings.\"\"\"\n\n    # Setup mock model with Gemini schema\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = True\n    mock_model.infer.return_value = [\n        [mock.MagicMock(output='{\"extractions\": []}', score=1.0)]\n    ]\n\n    # Create a mock Gemini schema with validate_format that issues warnings\n    mock_schema = mock.MagicMock()\n\n    def mock_validate_format(format_handler, level=None):\n      # Simulate the warning that would be issued\n      warnings.warn(\n          \"Gemini outputs native JSON via\"\n          \" response_mime_type='application/json'\",\n          UserWarning,\n          stacklevel=3,\n      )\n\n    mock_schema.validate_format = mock_validate_format\n    mock_model.schema = mock_schema\n\n    mock_create_model.return_value = mock_model\n\n    # Run extraction with warnings captured\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter(\"always\")\n\n      result = 
lx.extract(\n          text_or_documents=\"Sample text\",\n          prompt_description=\"Extract entities\",\n          examples=self.examples,\n          model_id=\"gemini-pro\",\n          api_key=\"test_key\",\n          use_schema_constraints=True,\n      )\n\n      # Check that a warning was issued\n      warning_messages = [str(warning.message) for warning in w]\n      self.assertTrue(\n          any(\"Gemini outputs native JSON\" in msg for msg in warning_messages),\n          f\"Expected Gemini-specific warning not found in: {warning_messages}\",\n      )\n\n    # Result should still be returned\n    self.assertIsNotNone(result)\n\n  @mock.patch(\"langextract.factory.create_model\")\n  def test_no_validation_without_schema(self, mock_create_model):\n    \"\"\"Test that validation is skipped when no schema is present.\"\"\"\n\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = False\n    mock_model.schema = None  # No schema\n    mock_model.infer.return_value = [\n        [mock.MagicMock(output='{\"extractions\": []}', score=1.0)]\n    ]\n\n    mock_create_model.return_value = mock_model\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter(\"always\")\n\n      result = lx.extract(\n          text_or_documents=\"Sample text\",\n          prompt_description=\"Extract\",\n          examples=self.examples,\n          model_id=\"some-model\",\n          api_key=\"key\",\n          use_schema_constraints=False,  # No schema constraints\n      )\n\n      # No format compatibility warnings should be issued\n      warning_messages = [str(warning.message) for warning in w]\n      self.assertFalse(\n          any(\"Format compatibility\" in msg for msg in warning_messages),\n          f\"Unexpected format warning found in: {warning_messages}\",\n      )\n\n    self.assertIsNotNone(result)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/factory_schema_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for factory schema integration and fence defaulting.\"\"\"\n\nfrom unittest import mock\n\nfrom absl.testing import absltest\n\nfrom langextract import factory\nfrom langextract import schema\nfrom langextract.core import base_model\nfrom langextract.core import data\n\n\nclass FactorySchemaIntegrationTest(absltest.TestCase):\n  \"\"\"Tests for create_model_with_schema factory function.\"\"\"\n\n  def setUp(self):\n    \"\"\"Set up test fixtures.\"\"\"\n    super().setUp()\n    self.examples = [\n        data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test_class\",\n                    extraction_text=\"test extraction\",\n                )\n            ],\n        )\n    ]\n\n  @mock.patch.dict(\"os.environ\", {\"GEMINI_API_KEY\": \"test_key\"})\n  def test_gemini_with_schema_returns_false_fence(self):\n    \"\"\"Test that Gemini with schema returns fence_output=False.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-2.5-flash\", provider_kwargs={\"api_key\": \"test_key\"}\n    )\n\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      model = factory._create_model_with_schema(\n          config=config,\n          examples=self.examples,\n  
        use_schema_constraints=True,\n          fence_output=None,\n      )\n\n      mock_init.assert_called_once()\n      call_kwargs = mock_init.call_args[1]\n      self.assertIn(\"response_schema\", call_kwargs)\n\n      self.assertFalse(model.requires_fence_output)\n\n  @mock.patch.dict(\"os.environ\", {\"OLLAMA_BASE_URL\": \"http://localhost:11434\"})\n  def test_ollama_with_schema_returns_false_fence(self):\n    \"\"\"Test that Ollama with JSON mode returns fence_output=False.\"\"\"\n    config = factory.ModelConfig(model_id=\"gemma2:2b\")\n\n    with mock.patch(\n        \"langextract.providers.ollama.OllamaLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      model = factory._create_model_with_schema(\n          config=config,\n          examples=self.examples,\n          use_schema_constraints=True,\n          fence_output=None,\n      )\n\n      mock_init.assert_called_once()\n      call_kwargs = mock_init.call_args[1]\n      self.assertIn(\"format\", call_kwargs)\n      self.assertEqual(call_kwargs[\"format\"], \"json\")\n\n      self.assertFalse(model.requires_fence_output)\n\n  def test_explicit_fence_output_respected(self):\n    \"\"\"Test that explicit fence_output is not overridden.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-2.5-flash\", provider_kwargs={\"api_key\": \"test_key\"}\n    )\n\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n    ):\n      model = factory._create_model_with_schema(\n          config=config,\n          examples=self.examples,\n          use_schema_constraints=True,\n          fence_output=True,\n      )\n\n      self.assertTrue(model.requires_fence_output)\n\n  def test_no_schema_defaults_to_true_fence(self):\n    \"\"\"Test that models without schema support default to fence_output=True.\"\"\"\n\n    class NoSchemaModel(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n  
    def infer(self, batch_prompts, **kwargs):\n        yield []\n\n    config = factory.ModelConfig(model_id=\"test-model\")\n\n    with mock.patch(\n        \"langextract.providers.registry.resolve\", return_value=NoSchemaModel\n    ):\n      with mock.patch.object(NoSchemaModel, \"__init__\", return_value=None):\n        model = factory._create_model_with_schema(\n            config=config,\n            examples=self.examples,\n            use_schema_constraints=True,\n            fence_output=None,\n        )\n\n        self.assertTrue(model.requires_fence_output)\n\n  def test_schema_disabled_returns_true_fence(self):\n    \"\"\"Test that disabling schema constraints returns fence_output=True.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-2.5-flash\", provider_kwargs={\"api_key\": \"test_key\"}\n    )\n\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      model = factory._create_model_with_schema(\n          config=config,\n          examples=self.examples,\n          use_schema_constraints=False,\n          fence_output=None,\n      )\n\n      call_kwargs = mock_init.call_args[1]\n      self.assertNotIn(\"response_schema\", call_kwargs)\n\n      self.assertTrue(model.requires_fence_output)\n\n  def test_caller_overrides_schema_config(self):\n    \"\"\"Test that caller's provider_kwargs override schema configuration.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemma2:2b\",\n        provider_kwargs={\"format\": \"yaml\"},\n    )\n\n    with mock.patch(\n        \"langextract.providers.ollama.OllamaLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      _ = factory._create_model_with_schema(\n          config=config,\n          examples=self.examples,\n          use_schema_constraints=True,\n          fence_output=None,\n      )\n\n      mock_init.assert_called_once()\n      call_kwargs = 
mock_init.call_args[1]\n      self.assertIn(\"format\", call_kwargs)\n      self.assertEqual(call_kwargs[\"format\"], \"yaml\")\n\n  def test_no_examples_no_schema(self):\n    \"\"\"Test that no examples means no schema is created.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-2.5-flash\", provider_kwargs={\"api_key\": \"test_key\"}\n    )\n\n    with mock.patch(\n        \"langextract.providers.gemini.GeminiLanguageModel.__init__\",\n        return_value=None,\n    ) as mock_init:\n      model = factory._create_model_with_schema(\n          config=config,\n          examples=None,\n          use_schema_constraints=True,\n          fence_output=None,\n      )\n\n      call_kwargs = mock_init.call_args[1]\n      self.assertNotIn(\"response_schema\", call_kwargs)\n\n      self.assertTrue(model.requires_fence_output)\n\n\nclass SchemaApplicationTest(absltest.TestCase):\n  \"\"\"Tests for apply_schema being called on models.\"\"\"\n\n  def test_apply_schema_called_when_supported(self):\n    \"\"\"Test that apply_schema is called on models that support it.\"\"\"\n    examples = [\n        data.ExampleData(\n            text=\"Test\",\n            extractions=[\n                data.Extraction(extraction_class=\"test\", extraction_text=\"test\")\n            ],\n        )\n    ]\n\n    class SchemaAwareModel(base_model.BaseLanguageModel):\n\n      @classmethod\n      def get_schema_class(cls):\n        return schema.GeminiSchema\n\n      def infer(self, batch_prompts, **kwargs):\n        yield []\n\n    config = factory.ModelConfig(model_id=\"test-model\")\n\n    with mock.patch(\n        \"langextract.providers.registry.resolve\", return_value=SchemaAwareModel\n    ):\n      with mock.patch.object(SchemaAwareModel, \"__init__\", return_value=None):\n        with mock.patch.object(SchemaAwareModel, \"apply_schema\") as mock_apply:\n          _ = factory._create_model_with_schema(\n              config=config,\n              examples=examples,\n      
        use_schema_constraints=True,\n          )\n\n          mock_apply.assert_called_once()\n          schema_arg = mock_apply.call_args[0][0]\n          self.assertIsInstance(schema_arg, schema.GeminiSchema)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/factory_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for the factory module.\n\nNote: This file tests the deprecated registry module which is now an alias\nfor router. The no-name-in-module warning for providers.registry is expected.\n\"\"\"\n# pylint: disable=no-name-in-module\n\nimport os\nfrom unittest import mock\n\nfrom absl.testing import absltest\n\nfrom langextract import exceptions\nfrom langextract import factory\nfrom langextract.core import base_model\nfrom langextract.core import types\nfrom langextract.providers import router\n\n\nclass FakeGeminiProvider(base_model.BaseLanguageModel):\n  \"\"\"Fake Gemini provider for testing.\"\"\"\n\n  def __init__(self, model_id, api_key=None, **kwargs):\n    self.model_id = model_id\n    self.api_key = api_key\n    self.kwargs = kwargs\n    super().__init__()\n\n  def infer(self, batch_prompts, **kwargs):\n    return [[types.ScoredOutput(score=1.0, output=\"gemini\")]]\n\n  def infer_batch(self, prompts, batch_size=32):\n    return self.infer(prompts)\n\n\nclass FakeOpenAIProvider(base_model.BaseLanguageModel):\n  \"\"\"Fake OpenAI provider for testing.\"\"\"\n\n  def __init__(self, model_id, api_key=None, **kwargs):\n    if not api_key:\n      raise ValueError(\"API key required\")\n    self.model_id = model_id\n    self.api_key = api_key\n    self.kwargs = kwargs\n    super().__init__()\n\n  def infer(self, batch_prompts, **kwargs):\n    return 
[[types.ScoredOutput(score=1.0, output=\"openai\")]]\n\n  def infer_batch(self, prompts, batch_size=32):\n    return self.infer(prompts)\n\n\nclass FactoryTest(absltest.TestCase):  # pylint: disable=too-many-public-methods\n\n  def setUp(self):\n    super().setUp()\n    router.clear()\n    import langextract.providers as providers_module  # pylint: disable=import-outside-toplevel\n\n    providers_module._plugins_loaded = True\n    # Use direct registration for test providers to avoid module path issues\n    router.register(r\"^gemini\", priority=100)(FakeGeminiProvider)\n    router.register(r\"^gpt\", r\"^o1\", priority=100)(FakeOpenAIProvider)\n\n  def tearDown(self):\n    super().tearDown()\n    router.clear()\n    import langextract.providers as providers_module  # pylint: disable=import-outside-toplevel\n\n    providers_module._plugins_loaded = False\n\n  def test_create_model_basic(self):\n    \"\"\"Test basic model creation.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-pro\", provider_kwargs={\"api_key\": \"test-key\"}\n    )\n\n    model = factory.create_model(config)\n    self.assertIsInstance(model, FakeGeminiProvider)\n    self.assertEqual(model.model_id, \"gemini-pro\")\n    self.assertEqual(model.api_key, \"test-key\")\n\n  def test_create_model_from_id(self):\n    \"\"\"Test convenience function for creating model from ID.\"\"\"\n    model = factory.create_model_from_id(\"gemini-flash\", api_key=\"test-key\")\n\n    self.assertIsInstance(model, FakeGeminiProvider)\n    self.assertEqual(model.model_id, \"gemini-flash\")\n    self.assertEqual(model.api_key, \"test-key\")\n\n  @mock.patch.dict(os.environ, {\"GEMINI_API_KEY\": \"env-gemini-key\"})\n  def test_uses_gemini_api_key_from_environment(self):\n    \"\"\"Factory should use GEMINI_API_KEY from environment for Gemini models.\"\"\"\n    config = factory.ModelConfig(model_id=\"gemini-pro\")\n\n    model = factory.create_model(config)\n    self.assertEqual(model.api_key, 
\"env-gemini-key\")\n\n  @mock.patch.dict(os.environ, {\"OPENAI_API_KEY\": \"env-openai-key\"})\n  def test_uses_openai_api_key_from_environment(self):\n    \"\"\"Factory should use OPENAI_API_KEY from environment for OpenAI models.\"\"\"\n    config = factory.ModelConfig(model_id=\"gpt-4\")\n\n    model = factory.create_model(config)\n    self.assertEqual(model.api_key, \"env-openai-key\")\n\n  @mock.patch.dict(\n      os.environ, {\"LANGEXTRACT_API_KEY\": \"env-langextract-key\"}, clear=True\n  )\n  def test_falls_back_to_langextract_api_key_when_provider_key_missing(self):\n    \"\"\"Factory uses LANGEXTRACT_API_KEY when provider-specific key is missing.\"\"\"\n    config = factory.ModelConfig(model_id=\"gemini-pro\")\n\n    model = factory.create_model(config)\n    self.assertEqual(model.api_key, \"env-langextract-key\")\n\n  @mock.patch.dict(\n      os.environ,\n      {\n          \"GEMINI_API_KEY\": \"gemini-key\",\n          \"LANGEXTRACT_API_KEY\": \"langextract-key\",\n      },\n  )\n  def test_provider_specific_key_takes_priority_over_langextract_key(self):\n    \"\"\"Factory prefers provider-specific API key over LANGEXTRACT_API_KEY.\"\"\"\n    config = factory.ModelConfig(model_id=\"gemini-pro\")\n\n    model = factory.create_model(config)\n    self.assertEqual(model.api_key, \"gemini-key\")\n\n  def test_explicit_kwargs_override_env(self):\n    \"\"\"Test that explicit kwargs override environment variables.\"\"\"\n    with mock.patch.dict(os.environ, {\"GEMINI_API_KEY\": \"env-key\"}):\n      config = factory.ModelConfig(\n          model_id=\"gemini-pro\", provider_kwargs={\"api_key\": \"explicit-key\"}\n      )\n\n      model = factory.create_model(config)\n      self.assertEqual(model.api_key, \"explicit-key\")\n\n  @mock.patch.dict(os.environ, {}, clear=True)\n  def test_wraps_provider_initialization_error_in_inference_config_error(self):\n    \"\"\"Factory should wrap provider errors in InferenceConfigError.\"\"\"\n    config = 
factory.ModelConfig(model_id=\"gpt-4\")\n\n    with self.assertRaises(exceptions.InferenceConfigError) as cm:\n      factory.create_model(config)\n\n    self.assertIn(\"Failed to create provider\", str(cm.exception))\n    self.assertIn(\"API key required\", str(cm.exception))\n\n  def test_raises_error_when_no_provider_matches_model_id(self):\n    \"\"\"Factory should raise InferenceConfigError for unregistered model IDs.\"\"\"\n    config = factory.ModelConfig(model_id=\"unknown-model\")\n\n    with self.assertRaises(exceptions.InferenceConfigError) as cm:\n      factory.create_model(config)\n\n    self.assertIn(\"No provider registered\", str(cm.exception))\n\n  def test_additional_kwargs_passed_through(self):\n    \"\"\"Test that additional kwargs are passed to provider.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-pro\",\n        provider_kwargs={\n            \"api_key\": \"test-key\",\n            \"temperature\": 0.5,\n            \"max_tokens\": 100,\n            \"custom_param\": \"value\",\n        },\n    )\n\n    model = factory.create_model(config)\n    self.assertEqual(model.kwargs[\"temperature\"], 0.5)\n    self.assertEqual(model.kwargs[\"max_tokens\"], 100)\n    self.assertEqual(model.kwargs[\"custom_param\"], \"value\")\n\n  @mock.patch.dict(os.environ, {\"OLLAMA_BASE_URL\": \"http://custom:11434\"})\n  def test_ollama_uses_base_url_from_environment(self):\n    \"\"\"Factory should use OLLAMA_BASE_URL from environment for Ollama models.\"\"\"\n\n    @router.register(r\"^ollama\")\n    class FakeOllamaProvider(base_model.BaseLanguageModel):  # pylint: disable=unused-variable\n\n      def __init__(self, model_id, base_url=None, **kwargs):\n        self.model_id = model_id\n        self.base_url = base_url\n        super().__init__()\n\n      def infer(self, batch_prompts, **kwargs):\n        return [[types.ScoredOutput(score=1.0, output=\"ollama\")]]\n\n      def infer_batch(self, prompts, batch_size=32):\n        return 
self.infer(prompts)\n\n    config = factory.ModelConfig(model_id=\"ollama/llama2\")\n    model = factory.create_model(config)\n\n    self.assertEqual(model.base_url, \"http://custom:11434\")\n\n  def test_ollama_models_select_without_api_keys(self):\n    \"\"\"Test that Ollama models resolve without API keys or explicit type.\"\"\"\n\n    @router.register(r\"^llama\", r\"^gemma\", r\"^mistral\", r\"^qwen\", priority=100)\n    class FakeOllamaProvider(base_model.BaseLanguageModel):\n\n      def __init__(self, model_id, **kwargs):\n        self.model_id = model_id\n        super().__init__()\n\n      def infer(self, batch_prompts, **kwargs):\n        return [[types.ScoredOutput(score=1.0, output=\"test\")]]\n\n      def infer_batch(self, prompts, batch_size=32):\n        return self.infer(prompts)\n\n    test_models = [\"llama3\", \"gemma2:2b\", \"mistral:7b\", \"qwen3:0.6b\"]\n\n    for model_id in test_models:\n      with self.subTest(model_id=model_id):\n        with mock.patch.dict(os.environ, {}, clear=True):\n          config = factory.ModelConfig(model_id=model_id)\n          model = factory.create_model(config)\n          self.assertIsInstance(model, FakeOllamaProvider)\n          self.assertEqual(model.model_id, model_id)\n\n  def test_model_config_fields_are_immutable(self):\n    \"\"\"ModelConfig fields should not be modifiable after creation.\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-pro\", provider_kwargs={\"api_key\": \"test\"}\n    )\n\n    with self.assertRaises(AttributeError):\n      config.model_id = \"different\"\n\n  def test_model_config_allows_dict_contents_modification(self):\n    \"\"\"ModelConfig allows modification of dict contents (not deeply frozen).\"\"\"\n    config = factory.ModelConfig(\n        model_id=\"gemini-pro\", provider_kwargs={\"api_key\": \"test\"}\n    )\n\n    config.provider_kwargs[\"new_key\"] = \"value\"\n\n    self.assertEqual(config.provider_kwargs[\"new_key\"], \"value\")\n\n  def 
test_uses_highest_priority_provider_when_multiple_match(self):\n    \"\"\"Factory uses highest priority provider when multiple patterns match.\"\"\"\n\n    @router.register(r\"^gemini\", priority=90)\n    class AnotherGeminiProvider(base_model.BaseLanguageModel):  # pylint: disable=unused-variable\n\n      def __init__(self, model_id=None, **kwargs):\n        self.model_id = model_id or \"default-model\"\n        self.kwargs = kwargs\n        super().__init__()\n\n      def infer(self, batch_prompts, **kwargs):\n        return [[types.ScoredOutput(score=1.0, output=\"another\")]]\n\n      def infer_batch(self, prompts, batch_size=32):\n        return self.infer(prompts)\n\n    config = factory.ModelConfig(model_id=\"gemini-pro\")\n    model = factory.create_model(config)\n\n    self.assertIsInstance(model, FakeGeminiProvider)  # Priority 100 wins\n\n  def test_explicit_provider_overrides_pattern_matching(self):\n    \"\"\"Factory should use explicit provider even when pattern doesn't match.\"\"\"\n\n    @router.register(r\"^another\", priority=90)\n    class AnotherProvider(base_model.BaseLanguageModel):\n\n      def __init__(self, model_id=None, **kwargs):\n        self.model_id = model_id or \"default-model\"\n        self.kwargs = kwargs\n        super().__init__()\n\n      def infer(self, batch_prompts, **kwargs):\n        return [[types.ScoredOutput(score=1.0, output=\"another\")]]\n\n      def infer_batch(self, prompts, batch_size=32):\n        return self.infer(prompts)\n\n    config = factory.ModelConfig(\n        model_id=\"gemini-pro\", provider=\"AnotherProvider\"\n    )\n    model = factory.create_model(config)\n\n    self.assertIsInstance(model, AnotherProvider)\n    self.assertEqual(model.model_id, \"gemini-pro\")\n\n  def test_provider_without_model_id_uses_provider_default(self):\n    \"\"\"Factory should use provider's default model_id when none specified.\"\"\"\n\n    @router.register(r\"^default-provider$\", priority=50)\n    class 
DefaultProvider(base_model.BaseLanguageModel):\n\n      def __init__(self, model_id=\"default-model\", **kwargs):\n        self.model_id = model_id\n        self.kwargs = kwargs\n        super().__init__()\n\n      def infer(self, batch_prompts, **kwargs):\n        return [[types.ScoredOutput(score=1.0, output=\"default\")]]\n\n      def infer_batch(self, prompts, batch_size=32):\n        return self.infer(prompts)\n\n    config = factory.ModelConfig(provider=\"DefaultProvider\")\n    model = factory.create_model(config)\n\n    self.assertIsInstance(model, DefaultProvider)\n    self.assertEqual(model.model_id, \"default-model\")\n\n  def test_raises_error_when_neither_model_id_nor_provider_specified(self):\n    \"\"\"Factory raises ValueError when config has neither model_id nor provider.\"\"\"\n    config = factory.ModelConfig()\n\n    with self.assertRaises(ValueError) as cm:\n      factory.create_model(config)\n\n    self.assertIn(\n        \"Either model_id or provider must be specified\", str(cm.exception)\n    )\n\n  def test_gemini_vertexai_parameters_accepted(self):\n    \"\"\"Test that Vertex AI parameters are properly passed to Gemini provider.\"\"\"\n    original_entries = router._entries.copy()  # pylint: disable=protected-access\n    original_keys = router._entry_keys.copy()  # pylint: disable=protected-access\n\n    try:\n\n      @router.register(r\"^gemini\", priority=200)\n      class MockGeminiWithVertexAI(base_model.BaseLanguageModel):  # pylint: disable=unused-variable\n\n        def __init__(\n            self,\n            model_id=\"gemini-2.5-flash\",\n            api_key=None,\n            vertexai=False,\n            credentials=None,\n            project=None,\n            location=None,\n            **kwargs,\n        ):\n          self.model_id = model_id\n          self.api_key = api_key\n          self.vertexai = vertexai\n          self.credentials = credentials\n          self.project = project\n          self.location = location\n   
       super().__init__()\n\n        def infer(self, batch_prompts, **kwargs):\n          return [[types.ScoredOutput(score=1.0, output=\"vertexai-test\")]]\n\n      config = factory.ModelConfig(\n          model_id=\"gemini-pro\",\n          provider_kwargs={\n              \"vertexai\": True,\n              \"project\": \"test-project\",\n              \"location\": \"us-central1\",\n          },\n      )\n      model = factory.create_model(config)\n\n      self.assertTrue(model.vertexai)\n      self.assertEqual(model.project, \"test-project\")\n      self.assertEqual(model.location, \"us-central1\")\n      self.assertIsNone(model.api_key)\n    finally:\n      router._entries = original_entries  # pylint: disable=protected-access\n      router._entry_keys = original_keys  # pylint: disable=protected-access\n\n  def test_gemini_vertexai_with_credentials(self):\n    \"\"\"Test that Vertex AI credentials can be passed through.\"\"\"\n    original_entries = router._entries.copy()  # pylint: disable=protected-access\n    original_keys = router._entry_keys.copy()  # pylint: disable=protected-access\n\n    try:\n\n      @router.register(r\"^gemini\", priority=200)\n      class MockGeminiWithCredentials(base_model.BaseLanguageModel):  # pylint: disable=unused-variable\n\n        def __init__(\n            self, model_id=\"gemini-2.5-flash\", credentials=None, **kwargs\n        ):\n          self.model_id = model_id\n          self.credentials = credentials\n          super().__init__()\n\n        def infer(self, batch_prompts, **kwargs):\n          return [[types.ScoredOutput(score=1.0, output=\"creds-test\")]]\n\n      mock_credentials = {\"type\": \"service_account\"}  # Simplified mock\n      config = factory.ModelConfig(\n          model_id=\"gemini-2.5-flash\",\n          provider_kwargs={\"credentials\": mock_credentials},\n      )\n      model = factory.create_model(config)\n\n      self.assertEqual(model.credentials, mock_credentials)\n    finally:\n      
router._entries = original_entries  # pylint: disable=protected-access\n      router._entry_keys = original_keys  # pylint: disable=protected-access\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/format_handler_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for centralized format handler.\"\"\"\n\nimport textwrap\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import prompting\nfrom langextract import resolver\nfrom langextract.core import data\nfrom langextract.core import format_handler\n\n\nclass FormatHandlerTest(parameterized.TestCase):\n  \"\"\"Tests for FormatHandler.\"\"\"\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"json_with_wrapper_and_fences\",\n          format_type=data.FormatType.JSON,\n          use_wrapper=True,\n          wrapper_key=\"extractions\",\n          use_fences=True,\n          extraction_class=\"person\",\n          extraction_text=\"Alice\",\n          attributes={\"role\": \"engineer\"},\n          expected_fence=\"```json\",\n          expected_wrapper='\"extractions\":',\n          expected_extraction='\"person\": \"Alice\"',\n          model_output=textwrap.dedent(\"\"\"\n              Here is the result:\n              ```json\n              {\n                \"extractions\": [\n                  {\"person\": \"Bob\", \"person_attributes\": {\"role\": \"manager\"}}\n                ]\n              }\n              ```\n          \"\"\").strip(),\n          parsed_class=\"person\",\n          parsed_text=\"Bob\",\n      ),\n      dict(\n          
testcase_name=\"json_no_wrapper_no_fences\",\n          format_type=data.FormatType.JSON,\n          use_wrapper=False,\n          wrapper_key=None,\n          use_fences=False,\n          extraction_class=\"item\",\n          extraction_text=\"book\",\n          attributes=None,\n          expected_fence=None,\n          expected_wrapper=None,\n          expected_extraction='\"item\": \"book\"',\n          model_output='[{\"item\": \"pen\", \"item_attributes\": {}}]',\n          parsed_class=\"item\",\n          parsed_text=\"pen\",\n      ),\n      dict(\n          testcase_name=\"yaml_with_wrapper_and_fences\",\n          format_type=data.FormatType.YAML,\n          use_wrapper=True,\n          wrapper_key=\"extractions\",\n          use_fences=True,\n          extraction_class=\"city\",\n          extraction_text=\"Paris\",\n          attributes=None,\n          expected_fence=\"```yaml\",\n          expected_wrapper=\"extractions:\",\n          expected_extraction=\"city: Paris\",\n          model_output=textwrap.dedent(\"\"\"\n              ```yaml\n              extractions:\n                - city: London\n                  city_attributes: {}\n              ```\n          \"\"\").strip(),\n          parsed_class=\"city\",\n          parsed_text=\"London\",\n      ),\n  )\n  def test_format_and_parse(  # pylint: disable=too-many-arguments\n      self,\n      format_type,\n      use_wrapper,\n      wrapper_key,\n      use_fences,\n      extraction_class,\n      extraction_text,\n      attributes,\n      expected_fence,\n      expected_wrapper,\n      expected_extraction,\n      model_output,\n      parsed_class,\n      parsed_text,\n  ):\n    \"\"\"Test formatting and parsing with various configurations.\"\"\"\n    handler = format_handler.FormatHandler(\n        format_type=format_type,\n        use_wrapper=use_wrapper,\n        wrapper_key=wrapper_key,\n        use_fences=use_fences,\n    )\n\n    extractions = [\n        data.Extraction(\n            
extraction_class=extraction_class,\n            extraction_text=extraction_text,\n            attributes=attributes,\n        )\n    ]\n\n    formatted = handler.format_extraction_example(extractions)\n\n    if expected_fence:\n      self.assertIn(expected_fence, formatted)\n    else:\n      self.assertNotIn(\"```\", formatted)\n\n    if expected_wrapper:\n      self.assertIn(expected_wrapper, formatted)\n    else:\n      if wrapper_key:\n        self.assertNotIn(wrapper_key, formatted)\n\n    self.assertIn(expected_extraction, formatted)\n\n    parsed = handler.parse_output(model_output)\n    self.assertLen(parsed, 1)\n    self.assertEqual(parsed[0][parsed_class], parsed_text)\n\n  def test_end_to_end_integration_with_prompt_and_resolver(self):\n    \"\"\"Test that FormatHandler unifies prompt generation and parsing.\"\"\"\n    handler = format_handler.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=True,\n    )\n\n    template = prompting.PromptTemplateStructured(\n        description=\"Extract entities from text.\",\n        examples=[\n            data.ExampleData(\n                text=\"Alice is an engineer\",\n                extractions=[\n                    data.Extraction(\n                        extraction_class=\"person\",\n                        extraction_text=\"Alice\",\n                        attributes={\"role\": \"engineer\"},\n                    )\n                ],\n            )\n        ],\n    )\n\n    prompt_gen = prompting.QAPromptGenerator(\n        template=template,\n        format_handler=handler,\n    )\n\n    prompt = prompt_gen.render(\"Bob is a manager\")\n    self.assertIn(\"```json\", prompt, \"Prompt should contain JSON fence\")\n    self.assertIn('\"extractions\":', prompt, \"Prompt should contain wrapper key\")\n\n    test_resolver = resolver.Resolver(\n        format_handler=handler,\n        extraction_index_suffix=None,\n    
)\n\n    model_output = textwrap.dedent(\"\"\"\n        ```json\n        {\n          \"extractions\": [\n            {\n              \"person\": \"Bob\",\n              \"person_attributes\": {\"role\": \"manager\"}\n            }\n          ]\n        }\n        ```\n    \"\"\").strip()\n\n    extractions = test_resolver.resolve(model_output)\n    self.assertLen(extractions, 1, \"Should extract exactly one entity\")\n    self.assertEqual(\n        extractions[0].extraction_class,\n        \"person\",\n        \"Extraction class should be 'person'\",\n    )\n    self.assertEqual(\n        extractions[0].extraction_text, \"Bob\", \"Extraction text should be 'Bob'\"\n    )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"yaml_no_wrapper_no_fences\",\n          format_type=data.FormatType.YAML,\n          use_wrapper=False,\n          use_fences=False,\n      ),\n      dict(\n          testcase_name=\"json_with_wrapper_and_fences\",\n          format_type=data.FormatType.JSON,\n          use_wrapper=True,\n          wrapper_key=\"extractions\",\n          use_fences=True,\n      ),\n      dict(\n          testcase_name=\"yaml_with_wrapper_no_fences\",\n          format_type=data.FormatType.YAML,\n          use_wrapper=True,\n          wrapper_key=\"extractions\",\n          use_fences=False,\n      ),\n  )\n  def test_format_parse_roundtrip(\n      self, format_type, use_wrapper, use_fences, wrapper_key=None\n  ):\n    \"\"\"Test that what we format can be parsed back identically.\"\"\"\n    handler = format_handler.FormatHandler(\n        format_type=format_type,\n        use_wrapper=use_wrapper,\n        wrapper_key=wrapper_key,\n        use_fences=use_fences,\n    )\n\n    extractions = [\n        data.Extraction(\n            extraction_class=\"test\",\n            extraction_text=\"value\",\n            attributes={\"key\": \"data\"},\n        )\n    ]\n    formatted = handler.format_extraction_example(extractions)\n\n    parsed = 
handler.parse_output(formatted)\n    self.assertEqual(parsed[0][\"test\"], \"value\")\n    self.assertEqual(parsed[0][\"test_attributes\"][\"key\"], \"data\")\n\n\nclass NonGeminiModelParsingTest(parameterized.TestCase):\n  \"\"\"Regression tests for non-Gemini model parsing edge cases.\"\"\"\n\n  def test_think_tags_stripped_before_parsing(self):\n    # Reasoning models output <think> tags before JSON\n    handler = format_handler.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=False,\n    )\n    input_with_think = (\n        \"<think>Let me analyze this text...</think>\"\n        '{\"extractions\": [{\"person\": \"Alice\"}]}'\n    )\n    parsed = handler.parse_output(input_with_think)\n    self.assertLen(parsed, 1)\n    self.assertEqual(parsed[0][\"person\"], \"Alice\")\n\n  def test_top_level_list_accepted_as_fallback(self):\n    # Some models return [...] instead of {\"extractions\": [...]}\n    handler = format_handler.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=False,\n    )\n    input_list = '[{\"person\": \"Bob\"}, {\"person\": \"Carol\"}]'\n    parsed = handler.parse_output(input_list)\n    self.assertLen(parsed, 2)\n    self.assertEqual(parsed[0][\"person\"], \"Bob\")\n    self.assertEqual(parsed[1][\"person\"], \"Carol\")\n\n  def test_deepseek_r1_real_output(self):\n    # Real output captured from DeepSeek-R1:1.5b model\n    handler = format_handler.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=False,\n    )\n    deepseek_output = textwrap.dedent(\"\"\"\\\n        <think>\n        Alright, so I need to extract people from the given text.\n        I see John Smith is mentioned as an engineer.\n        </think>\n        {\"extractions\": [{\"person\": \"John 
Smith\"}]}\"\"\")\n    parsed = handler.parse_output(deepseek_output)\n    self.assertLen(parsed, 1)\n    self.assertEqual(parsed[0][\"person\"], \"John Smith\")\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/inference_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for inference module.\n\nNote: This file contains test helper classes that intentionally have\nfew public methods and define attributes outside __init__. These\npylint warnings are expected for test fixtures.\n\"\"\"\n# pylint: disable=attribute-defined-outside-init\n\nfrom unittest import mock\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import exceptions\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import types\nfrom langextract.providers import gemini\nfrom langextract.providers import ollama\nfrom langextract.providers import openai\n\n\nclass TestBaseLanguageModel(absltest.TestCase):\n\n  def test_merge_kwargs_with_none(self):\n    \"\"\"Test merge_kwargs handles None runtime_kwargs.\"\"\"\n\n    class TestModel(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n      def infer(self, batch_prompts, **kwargs):\n        return iter([])\n\n    model = TestModel()\n    model._extra_kwargs = {\"a\": 1, \"b\": 2}\n\n    result = model.merge_kwargs(None)\n    self.assertEqual(\n        {\"a\": 1, \"b\": 2},\n        result,\n        \"merge_kwargs(None) should return stored kwargs unchanged\",\n    )\n\n    result = model.merge_kwargs({})\n    self.assertEqual(\n        {\"a\": 1, \"b\": 2},\n        result,\n        
\"merge_kwargs({}) should return stored kwargs unchanged\",\n    )\n\n    result = model.merge_kwargs({\"b\": 3, \"c\": 4})\n    self.assertEqual(\n        {\"a\": 1, \"b\": 3, \"c\": 4},\n        result,\n        \"Runtime kwargs should override stored kwargs and add new keys\",\n    )\n\n  def test_merge_kwargs_without_extra_kwargs(self):\n    \"\"\"Test merge_kwargs when _extra_kwargs doesn't exist.\"\"\"\n\n    class TestModel(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n      def infer(self, batch_prompts, **kwargs):\n        return iter([])\n\n    model = TestModel()\n    # Intentionally not setting _extra_kwargs to test fallback behavior\n\n    result = model.merge_kwargs({\"a\": 1})\n    self.assertEqual(\n        {\"a\": 1},\n        result,\n        \"merge_kwargs should work even without _extra_kwargs attribute\",\n    )\n\n\nclass TestOllamaLanguageModel(absltest.TestCase):\n\n  @mock.patch(\"langextract.providers.ollama.OllamaLanguageModel._ollama_query\")\n  def test_ollama_infer(self, mock_ollama_query):\n\n    # Real gemma2 response structure from Ollama API for validation\n    gemma_response = {\n        \"model\": \"gemma2:latest\",\n        \"created_at\": \"2025-01-23T22:37:08.579440841Z\",\n        \"response\": \"{'bus' : '**autóbusz**'} \\n\\n\\n  \\n\",\n        \"done\": True,\n        \"done_reason\": \"stop\",\n        \"context\": [\n            106,\n            1645,\n            108,\n            1841,\n            603,\n            1986,\n            575,\n            59672,\n            235336,\n            107,\n            108,\n            106,\n            2516,\n            108,\n            9766,\n            6710,\n            235281,\n            865,\n            664,\n            688,\n            7958,\n            235360,\n            6710,\n            235306,\n            688,\n            12990,\n            235248,\n            110,\n            139,\n            108,\n        ],\n      
  \"total_duration\": 24038204381,\n        \"load_duration\": 21551375738,\n        \"prompt_eval_count\": 15,\n        \"prompt_eval_duration\": 633000000,\n        \"eval_count\": 17,\n        \"eval_duration\": 1848000000,\n    }\n    mock_ollama_query.return_value = gemma_response\n    model = ollama.OllamaLanguageModel(\n        model_id=\"gemma2:latest\",\n        model_url=\"http://localhost:11434\",\n        structured_output_format=\"json\",\n    )\n    batch_prompts = [\"What is bus in Hungarian?\"]\n    results = list(model.infer(batch_prompts))\n\n    mock_ollama_query.assert_called_once_with(\n        prompt=\"What is bus in Hungarian?\",\n        model=\"gemma2:latest\",\n        structured_output_format=\"json\",\n        model_url=\"http://localhost:11434\",\n    )\n    expected_results = [[\n        types.ScoredOutput(\n            score=1.0, output=\"{'bus' : '**autóbusz**'} \\n\\n\\n  \\n\"\n        )\n    ]]\n    self.assertEqual(results, expected_results)\n\n  @mock.patch(\"requests.post\")\n  def test_ollama_extra_kwargs_passed_to_api(self, mock_post):\n    \"\"\"Verify extra kwargs like timeout and keep_alive are passed to the API.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {\n        \"response\": '{\"test\": \"value\"}',\n        \"done\": True,\n    }\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id=\"test-model\",\n        timeout=300,\n        keep_alive=600,\n        num_threads=8,\n    )\n\n    prompts = [\"Test prompt\"]\n    list(model.infer(prompts))\n\n    mock_post.assert_called_once()\n    call_args = mock_post.call_args\n    json_payload = call_args.kwargs[\"json\"]\n\n    self.assertEqual(json_payload[\"options\"][\"keep_alive\"], 600)\n    self.assertEqual(json_payload[\"options\"][\"num_thread\"], 8)\n    # timeout is passed to requests.post, not in the JSON payload\n    
self.assertEqual(call_args.kwargs[\"timeout\"], 300)\n\n  @mock.patch(\"requests.post\")\n  def test_ollama_stop_and_top_p_passthrough(self, mock_post):\n    \"\"\"Verify stop and top_p parameters are passed to Ollama API.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {\n        \"response\": '{\"test\": \"value\"}',\n        \"done\": True,\n    }\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id=\"test-model\",\n        top_p=0.9,\n        stop=[\"\\\\n\\\\n\", \"END\"],\n    )\n\n    prompts = [\"Test prompt\"]\n    list(model.infer(prompts))\n\n    mock_post.assert_called_once()\n    call_args = mock_post.call_args\n    json_payload = call_args.kwargs[\"json\"]\n\n    # Ollama expects 'stop' at top level, not in options\n    self.assertEqual(json_payload[\"stop\"], [\"\\\\n\\\\n\", \"END\"])\n    self.assertEqual(json_payload[\"options\"][\"top_p\"], 0.9)\n\n  @mock.patch(\"requests.post\")\n  def test_ollama_defaults_when_unspecified(self, mock_post):\n    \"\"\"Verify Ollama uses correct defaults when parameters are not specified.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {\n        \"response\": '{\"test\": \"value\"}',\n        \"done\": True,\n    }\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(model_id=\"test-model\")\n\n    prompts = [\"Test prompt\"]\n    list(model.infer(prompts))\n\n    mock_post.assert_called_once()\n    call_args = mock_post.call_args\n    json_payload = call_args.kwargs[\"json\"]\n\n    self.assertEqual(json_payload[\"options\"][\"temperature\"], 0.1)\n    self.assertEqual(json_payload[\"options\"][\"keep_alive\"], 300)\n    self.assertEqual(json_payload[\"options\"][\"num_ctx\"], 2048)\n    self.assertEqual(call_args.kwargs[\"timeout\"], 120)\n\n  @mock.patch(\"requests.post\")\n  def 
test_ollama_runtime_kwargs_override_stored(self, mock_post):\n    \"\"\"Verify runtime kwargs override stored kwargs.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {\n        \"response\": '{\"test\": \"value\"}',\n        \"done\": True,\n    }\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id=\"test-model\",\n        temperature=0.5,\n        keep_alive=300,\n    )\n\n    prompts = [\"Test prompt\"]\n    list(model.infer(prompts, temperature=0.8, keep_alive=600))\n\n    mock_post.assert_called_once()\n    call_args = mock_post.call_args\n    json_payload = call_args.kwargs[\"json\"]\n\n    self.assertEqual(json_payload[\"options\"][\"temperature\"], 0.8)\n    self.assertEqual(json_payload[\"options\"][\"keep_alive\"], 600)\n\n  @mock.patch(\"requests.post\")\n  def test_ollama_temperature_zero(self, mock_post):\n    \"\"\"Test that temperature=0.0 is properly passed to Ollama.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {\n        \"response\": '{\"test\": \"value\"}',\n        \"done\": True,\n    }\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id=\"test-model\",\n        temperature=0.0,\n    )\n\n    list(model.infer([\"test prompt\"]))\n\n    mock_post.assert_called_once()\n    call_args = mock_post.call_args\n    json_payload = call_args.kwargs[\"json\"]\n\n    self.assertEqual(json_payload[\"options\"][\"temperature\"], 0.0)\n\n  def test_ollama_default_timeout(self):\n    \"\"\"Test that default timeout is used when not specified.\"\"\"\n    model = ollama.OllamaLanguageModel(\n        model_id=\"test-model\",\n        model_url=\"http://localhost:11434\",\n    )\n\n    mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n    mock_response.status_code = 200\n    mock_response.json.return_value = 
{\"response\": \"test output\"}\n\n    with mock.patch.object(\n        model._requests, \"post\", return_value=mock_response\n    ) as mock_post:\n      model._ollama_query(prompt=\"test prompt\")\n\n      mock_post.assert_called_once()\n      call_kwargs = mock_post.call_args[1]\n      self.assertEqual(\n          120,\n          call_kwargs[\"timeout\"],\n          \"Should use default timeout of 120 seconds\",\n      )\n\n  def test_ollama_timeout_through_infer(self):\n    \"\"\"Test that timeout flows correctly through the infer() method.\"\"\"\n    model = ollama.OllamaLanguageModel(\n        model_id=\"test-model\",\n        model_url=\"http://localhost:11434\",\n        timeout=60,\n    )\n\n    mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n    mock_response.status_code = 200\n    mock_response.json.return_value = {\"response\": \"test output\"}\n\n    with mock.patch.object(\n        model._requests, \"post\", return_value=mock_response\n    ) as mock_post:\n      list(model.infer([\"test prompt\"]))\n\n      mock_post.assert_called_once()\n      call_kwargs = mock_post.call_args[1]\n      self.assertEqual(\n          60,\n          call_kwargs[\"timeout\"],\n          \"Timeout from constructor should flow through infer()\",\n      )\n\n\nclass TestGeminiLanguageModel(absltest.TestCase):\n\n  @mock.patch(\"google.genai.Client\")\n  def test_gemini_allowlist_filtering(self, mock_client_class):\n    \"\"\"Test that only allow-listed keys are passed through.\"\"\"\n    mock_client = mock.Mock()\n    mock_client_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.text = '{\"result\": \"test\"}'\n    mock_client.models.generate_content.return_value = mock_response\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        api_key=\"test-key\",\n        # Allow-listed parameters\n        tools=[\"tool1\", \"tool2\"],\n        stop_sequences=[\"\\n\\n\"],\n        
system_instruction=\"Be helpful\",\n        # Unknown parameters to test filtering\n        unknown_param=\"should_be_ignored\",\n        another_unknown=\"also_ignored\",\n    )\n\n    expected_extra_kwargs = {\n        \"tools\": [\"tool1\", \"tool2\"],\n        \"stop_sequences\": [\"\\n\\n\"],\n        \"system_instruction\": \"Be helpful\",\n    }\n    self.assertEqual(\n        expected_extra_kwargs,\n        model._extra_kwargs,\n        \"Only allow-listed kwargs should be stored in _extra_kwargs\",\n    )\n\n    prompts = [\"Test prompt\"]\n    list(model.infer(prompts))\n\n    mock_client.models.generate_content.assert_called_once()\n    call_args = mock_client.models.generate_content.call_args\n    config = call_args.kwargs[\"config\"]\n\n    for key in [\"tools\", \"stop_sequences\", \"system_instruction\"]:\n      self.assertIn(key, config, f\"Expected {key} to be in API config\")\n      self.assertEqual(\n          expected_extra_kwargs[key],\n          config[key],\n          f\"Config value for {key} should match what was provided\",\n      )\n\n  @mock.patch(\"google.genai.Client\")\n  def test_gemini_runtime_kwargs_filtered(self, mock_client_class):\n    \"\"\"Test that runtime kwargs are also filtered by allow-list.\"\"\"\n    mock_client = mock.Mock()\n    mock_client_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.text = '{\"result\": \"test\"}'\n    mock_client.models.generate_content.return_value = mock_response\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        api_key=\"test-key\",\n    )\n\n    prompts = [\"Test prompt\"]\n    list(\n        model.infer(\n            prompts,\n            candidate_count=2,\n            safety_settings={\"HARM_CATEGORY_DANGEROUS\": \"BLOCK_NONE\"},\n            unknown_runtime_param=\"ignored\",\n        )\n    )\n\n    call_args = mock_client.models.generate_content.call_args\n    config = call_args.kwargs[\"config\"]\n\n    
self.assertEqual(\n        2,\n        config.get(\"candidate_count\"),\n        \"candidate_count should be passed through to API\",\n    )\n    self.assertEqual(\n        {\"HARM_CATEGORY_DANGEROUS\": \"BLOCK_NONE\"},\n        config.get(\"safety_settings\"),\n        \"safety_settings should be passed through to API\",\n    )\n    self.assertNotIn(\n        \"unknown_runtime_param\", config, \"Unknown kwargs should be filtered out\"\n    )\n\n  def test_gemini_requires_auth_config(self):\n    \"\"\"Test that Gemini requires either API key or Vertex AI config.\"\"\"\n    with self.assertRaises(exceptions.InferenceConfigError) as cm:\n      gemini.GeminiLanguageModel()\n\n    self.assertIn(\"Gemini models require either\", str(cm.exception))\n    self.assertIn(\"API key\", str(cm.exception))\n    self.assertIn(\"Vertex AI\", str(cm.exception))\n\n  def test_gemini_vertexai_requires_project_and_location(self):\n    \"\"\"Test that Vertex AI mode requires both project and location.\"\"\"\n    with self.assertRaises(exceptions.InferenceConfigError) as cm:\n      gemini.GeminiLanguageModel(vertexai=True)\n\n    self.assertIn(\"requires both project and location\", str(cm.exception))\n\n  @mock.patch(\"google.genai.Client\")\n  def test_gemini_vertexai_initialization(self, mock_client_class):\n    \"\"\"Test successful initialization with Vertex AI config.\"\"\"\n    mock_client = mock.Mock()\n    mock_client_class.return_value = mock_client\n\n    model = gemini.GeminiLanguageModel(\n        vertexai=True, project=\"test-project\", location=\"us-central1\"\n    )\n\n    self.assertIsNone(model.api_key)\n    self.assertTrue(model.vertexai)\n    self.assertEqual(model.project, \"test-project\")\n    self.assertEqual(model.location, \"us-central1\")\n    mock_client_class.assert_called_once_with(\n        api_key=None,\n        vertexai=True,\n        credentials=None,\n        project=\"test-project\",\n        location=\"us-central1\",\n        http_options=None,\n    
)\n\n  @mock.patch(\"absl.logging.warning\")\n  @mock.patch(\"google.genai.Client\")\n  def test_gemini_warns_when_both_auth_provided(\n      self, mock_client_class, mock_warning\n  ):\n    \"\"\"Test that warning is logged when both API key and Vertex AI are provided.\"\"\"\n    mock_client = mock.Mock()\n    mock_client_class.return_value = mock_client\n\n    gemini.GeminiLanguageModel(\n        api_key=\"test-key\",\n        vertexai=True,\n        project=\"test-project\",\n        location=\"us-central1\",\n    )\n\n    mock_warning.assert_called_once()\n    warning_msg = mock_warning.call_args[0][0]\n    self.assertIn(\"Both API key and Vertex AI\", warning_msg)\n    self.assertIn(\"API key will take precedence\", warning_msg)\n\n  @mock.patch(\"google.genai.Client\")\n  def test_gemini_vertexai_with_http_options(self, mock_client_class):\n    \"\"\"Test that http_options are passed to genai.Client for VPC endpoints.\"\"\"\n    mock_client = mock.Mock()\n    mock_client_class.return_value = mock_client\n\n    http_options = {\"base_url\": \"https://custom-vpc.p.googleapis.com\"}\n    model = gemini.GeminiLanguageModel(\n        vertexai=True,\n        project=\"test-project\",\n        location=\"us-central1\",\n        http_options=http_options,\n    )\n\n    self.assertEqual(model.http_options, http_options)\n    mock_client_class.assert_called_once_with(\n        api_key=None,\n        vertexai=True,\n        credentials=None,\n        project=\"test-project\",\n        location=\"us-central1\",\n        http_options=http_options,\n    )\n\n\nclass TestOpenAILanguageModelInference(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      (\"without\", \"test-api-key\", None, \"gpt-4o-mini\", 0.5),\n      (\"with\", \"test-api-key\", \"http://127.0.0.1:9001/v1\", \"gpt-4o-mini\", 0.5),\n  )\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_infer_with_parameters(\n      self, api_key, base_url, model_id, temperature, mock_openai_class\n  ):\n 
   mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"name\": \"John\", \"age\": 30}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id=model_id,\n        api_key=api_key,\n        base_url=base_url,\n        temperature=temperature,\n    )\n\n    batch_prompts = [\"Extract name and age from: John is 30 years old\"]\n    results = list(model.infer(batch_prompts))\n\n    # JSON format adds a system message; only explicitly set params are passed\n    mock_client.chat.completions.create.assert_called_once()\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs[\"model\"], \"gpt-4o-mini\")\n    self.assertEqual(call_args.kwargs[\"temperature\"], temperature)\n    self.assertEqual(call_args.kwargs[\"n\"], 1)\n    self.assertEqual(len(call_args.kwargs[\"messages\"]), 2)\n    self.assertEqual(call_args.kwargs[\"messages\"][0][\"role\"], \"system\")\n    self.assertEqual(call_args.kwargs[\"messages\"][1][\"role\"], \"user\")\n\n    expected_results = [\n        [types.ScoredOutput(score=1.0, output='{\"name\": \"John\", \"age\": 30}')]\n    ]\n    self.assertEqual(results, expected_results)\n\n\nclass TestOpenAILanguageModel(absltest.TestCase):\n\n  def test_openai_parse_output_json(self):\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\", format_type=data.FormatType.JSON\n    )\n\n    output = '{\"key\": \"value\", \"number\": 42}'\n    parsed = model.parse_output(output)\n    self.assertEqual(parsed, {\"key\": \"value\", \"number\": 42})\n\n    with self.assertRaises(ValueError) as context:\n      model.parse_output(\"invalid json\")\n    self.assertIn(\"Failed to parse output as JSON\", str(context.exception))\n\n  def test_openai_parse_output_yaml(self):\n    model = 
openai.OpenAILanguageModel(\n        api_key=\"test-key\", format_type=data.FormatType.YAML\n    )\n\n    output = \"key: value\\nnumber: 42\"\n    parsed = model.parse_output(output)\n    self.assertEqual(parsed, {\"key\": \"value\", \"number\": 42})\n\n    with self.assertRaises(ValueError) as context:\n      model.parse_output(\"invalid: yaml: bad\")\n    self.assertIn(\"Failed to parse output as YAML\", str(context.exception))\n\n  def test_openai_no_api_key_raises_error(self):\n    with self.assertRaises(exceptions.InferenceConfigError) as context:\n      openai.OpenAILanguageModel(api_key=None)\n    self.assertEqual(str(context.exception), \"API key not provided.\")\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_extra_kwargs_passed(self, mock_openai_class):\n    \"\"\"Test that extra kwargs are passed to OpenAI API.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\",\n        frequency_penalty=0.5,\n        presence_penalty=0.7,\n        seed=42,\n    )\n\n    list(model.infer([\"test prompt\"]))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs[\"frequency_penalty\"], 0.5)\n    self.assertEqual(call_args.kwargs[\"presence_penalty\"], 0.7)\n    self.assertEqual(call_args.kwargs[\"seed\"], 42)\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_runtime_kwargs_override(self, mock_openai_class):\n    \"\"\"Test that runtime kwargs override stored kwargs.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    
]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\",\n        temperature=0.5,\n        seed=123,\n    )\n\n    list(model.infer([\"test prompt\"], temperature=0.8, seed=456))\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs[\"temperature\"], 0.8)\n    self.assertEqual(call_args.kwargs[\"seed\"], 456)\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_json_response_format(self, mock_openai_class):\n    \"\"\"Test that JSON format adds response_format parameter.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\", format_type=data.FormatType.JSON\n    )\n\n    list(model.infer([\"test prompt\"]))\n\n    mock_client.chat.completions.create.assert_called_once()\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(\n        call_args.kwargs[\"response_format\"], {\"type\": \"json_object\"}\n    )\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_temperature_zero(self, mock_openai_class):\n    \"\"\"Verify temperature=0.0 is properly passed to the API.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(api_key=\"test-key\", temperature=0.0)\n\n    list(model.infer([\"test prompt\"]))\n\n    mock_client.chat.completions.create.assert_called_once()\n    call_args = 
mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs[\"temperature\"], 0.0)\n    self.assertEqual(call_args.kwargs[\"model\"], \"gpt-4o-mini\")\n    self.assertEqual(call_args.kwargs[\"n\"], 1)\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_temperature_none_not_sent(self, mock_openai_class):\n    \"\"\"Test that temperature=None is not sent to the API.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    # Test with temperature=None in model init\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\",\n        temperature=None,\n    )\n\n    list(model.infer([\"test prompt\"]))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertNotIn(\"temperature\", call_args.kwargs)\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_none_values_filtered(self, mock_openai_class):\n    \"\"\"Test that None values are not passed to the API.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\",\n        top_p=0.9,\n    )\n\n    list(model.infer([\"test prompt\"], top_p=None, seed=None))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertNotIn(\"top_p\", call_args.kwargs)\n    self.assertNotIn(\"seed\", call_args.kwargs)\n\n  @mock.patch(\"openai.OpenAI\")\n  def test_openai_no_system_message_when_not_json_yaml(self, mock_openai_class):\n    \"\"\"Test that no system message is sent when format_type 
is not JSON/YAML.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content=\"test output\"))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        api_key=\"test-key\",\n        format_type=None,\n    )\n\n    list(model.infer([\"test prompt\"]))\n\n    call_args = mock_client.chat.completions.create.call_args\n    messages = call_args.kwargs[\"messages\"]\n\n    self.assertEqual(len(messages), 1)\n    self.assertEqual(messages[0][\"role\"], \"user\")\n    self.assertEqual(messages[0][\"content\"], \"test prompt\")\n\n  @mock.patch(\"google.genai.Client\")\n  def test_gemini_none_values_filtered(self, mock_client_class):\n    \"\"\"Test that None values are not passed to Gemini API.\"\"\"\n    mock_client = mock.Mock()\n    mock_client_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.text = '{\"result\": \"test\"}'\n    mock_client.models.generate_content.return_value = mock_response\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        api_key=\"test-key\",\n    )\n\n    list(model.infer([\"test prompt\"], candidate_count=None))\n\n    call_args = mock_client.models.generate_content.call_args\n    config = call_args.kwargs[\"config\"]\n\n    self.assertNotIn(\"candidate_count\", config)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/init_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for the main package functions in __init__.py.\"\"\"\n\nimport textwrap\nfrom unittest import mock\nimport warnings\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import prompting\nimport langextract as lx\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import format_handler as fh\nfrom langextract.core import schema\nfrom langextract.core import types\nfrom langextract.providers import schemas\n\n\nclass InitTest(parameterized.TestCase):\n  \"\"\"Test cases for the main package functions.\"\"\"\n\n  @mock.patch.object(\n      schemas.gemini.GeminiSchema, \"from_examples\", autospec=True\n  )\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_lang_extract_as_lx_extract(\n      self, mock_create_model, mock_gemini_schema\n  ):\n\n    input_text = \"Patient takes Aspirin 100mg every morning.\"\n\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [[\n        types.ScoredOutput(\n            output=textwrap.dedent(\"\"\"\\\n            ```json\n            {\n              \"extractions\": [\n                {\n                  \"entity\": \"Aspirin\",\n                  \"entity_attributes\": {\n                    \"class\": \"medication\"\n                  }\n                },\n                {\n             
     \"entity\": \"100mg\",\n                  \"entity_attributes\": {\n                    \"frequency\": \"every morning\",\n                    \"class\": \"dosage\"\n                  }\n                }\n              ]\n            }\n            ```\"\"\"),\n            score=0.9,\n        )\n    ]]\n\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n\n    mock_gemini_schema.return_value = None\n\n    expected_result = data.AnnotatedDocument(\n        document_id=None,\n        extractions=[\n            data.Extraction(\n                extraction_class=\"entity\",\n                extraction_text=\"Aspirin\",\n                char_interval=data.CharInterval(start_pos=14, end_pos=21),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                extraction_index=1,\n                group_index=0,\n                description=None,\n                attributes={\"class\": \"medication\"},\n            ),\n            data.Extraction(\n                extraction_class=\"entity\",\n                extraction_text=\"100mg\",\n                char_interval=data.CharInterval(start_pos=22, end_pos=27),\n                alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                extraction_index=2,\n                group_index=1,\n                description=None,\n                attributes={\"frequency\": \"every morning\", \"class\": \"dosage\"},\n            ),\n        ],\n        text=\"Patient takes Aspirin 100mg every morning.\",\n    )\n\n    mock_description = textwrap.dedent(\"\"\"\\\n        Extract medication and dosage information in order of occurrence.\n        \"\"\")\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Patient takes Tylenol 500mg daily.\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"Tylenol\",\n                    
attributes={\n                        \"type\": \"analgesic\",\n                        \"class\": \"medication\",\n                    },\n                ),\n            ],\n        )\n    ]\n    mock_prompt_template = prompting.PromptTemplateStructured(\n        description=mock_description, examples=mock_examples\n    )\n\n    format_handler = fh.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=True,\n    )\n\n    prompt_generator = prompting.QAPromptGenerator(\n        template=mock_prompt_template, format_handler=format_handler\n    )\n\n    actual_result = lx.extract(\n        text_or_documents=input_text,\n        prompt_description=mock_description,\n        examples=mock_examples,\n        api_key=\"some_api_key\",\n        fence_output=True,\n        use_schema_constraints=False,\n    )\n\n    mock_gemini_schema.assert_not_called()\n    mock_create_model.assert_called_once()\n    mock_model.infer.assert_called_once_with(\n        batch_prompts=[prompt_generator.render(input_text)],\n        max_workers=10,\n    )\n\n    self.assertDataclassEqual(expected_result, actual_result)\n\n  @mock.patch(\"langextract.extraction.resolver.Resolver.align\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_resolver_params_alignment_passthrough(\n      self, mock_create_model, mock_align\n  ):\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [\n        [types.ScoredOutput(output='{\"extractions\":[]}')]\n    ]\n    mock_model.requires_fence_output = False\n    mock_create_model.return_value = mock_model\n    mock_align.return_value = []\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Patient takes Tylenol 500mg daily.\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"Tylenol\",\n           
         attributes={\n                        \"type\": \"analgesic\",\n                        \"class\": \"medication\",\n                    },\n                ),\n            ],\n        )\n    ]\n\n    lx.extract(\n        text_or_documents=\"test text\",\n        prompt_description=\"desc\",\n        examples=mock_examples,\n        api_key=\"test_key\",\n        resolver_params={\n            \"enable_fuzzy_alignment\": False,\n            \"fuzzy_alignment_threshold\": 0.8,\n            \"accept_match_lesser\": False,\n        },\n    )\n\n    mock_align.assert_called()\n    _, kwargs = mock_align.call_args\n    self.assertFalse(kwargs.get(\"enable_fuzzy_alignment\"))\n    self.assertEqual(kwargs.get(\"fuzzy_alignment_threshold\"), 0.8)\n    self.assertFalse(kwargs.get(\"accept_match_lesser\"))\n\n  @mock.patch(\"langextract.annotation.Annotator.annotate_text\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_resolver_params_suppress_parse_errors(\n      self, mock_create_model, mock_annotate\n  ):\n    \"\"\"Test that suppress_parse_errors can be passed through resolver_params.\"\"\"\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = False\n    mock_model.schema = None\n    mock_create_model.return_value = mock_model\n\n    mock_annotate.return_value = lx.data.AnnotatedDocument(\n        text=\"test\", extractions=[]\n    )\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Example text\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"example\",\n                ),\n            ],\n        )\n    ]\n\n    # This should not raise a TypeError about unknown key\n    lx.extract(\n        text_or_documents=\"test text\",\n        prompt_description=\"desc\",\n        examples=mock_examples,\n        api_key=\"test_key\",\n        resolver_params={\n            
\"suppress_parse_errors\": True,\n            \"enable_fuzzy_alignment\": False,\n        },\n    )\n\n    mock_annotate.assert_called()\n    _, kwargs = mock_annotate.call_args\n    self.assertIn(\"suppress_parse_errors\", kwargs)\n    self.assertTrue(kwargs.get(\"suppress_parse_errors\"))\n    self.assertFalse(kwargs.get(\"enable_fuzzy_alignment\"))\n\n  @mock.patch(\"langextract.extraction.resolver.Resolver\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_resolver_params_none_handling(\n      self, mock_create_model, mock_resolver_class\n  ):\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [\n        [types.ScoredOutput(output='{\"extractions\":[]}')]\n    ]\n    mock_model.requires_fence_output = False\n    mock_create_model.return_value = mock_model\n\n    mock_resolver = mock.MagicMock()\n    mock_resolver_class.return_value = mock_resolver\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"test\",\n                ),\n            ],\n        )\n    ]\n\n    with mock.patch(\n        \"langextract.annotation.Annotator.annotate_text\"\n    ) as mock_annotate:\n      mock_annotate.return_value = lx.data.AnnotatedDocument(\n          text=\"test\", extractions=[]\n      )\n\n      lx.extract(\n          text_or_documents=\"test text\",\n          prompt_description=\"desc\",\n          examples=mock_examples,\n          api_key=\"test_key\",\n          resolver_params={\n              \"enable_fuzzy_alignment\": None,\n              \"fuzzy_alignment_threshold\": 0.8,\n          },\n      )\n\n      _, resolver_kwargs = mock_resolver_class.call_args\n      self.assertNotIn(\"enable_fuzzy_alignment\", resolver_kwargs)\n      self.assertNotIn(\"fuzzy_alignment_threshold\", resolver_kwargs)\n      
self.assertIn(\"format_handler\", resolver_kwargs)\n\n      _, annotate_kwargs = mock_annotate.call_args\n      self.assertNotIn(\"enable_fuzzy_alignment\", annotate_kwargs)\n      self.assertEqual(annotate_kwargs[\"fuzzy_alignment_threshold\"], 0.8)\n\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_resolver_params_typo_error(self, mock_create_model):\n    mock_model = mock.MagicMock()\n    mock_model.requires_fence_output = False\n    mock_create_model.return_value = mock_model\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Test\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"test\",\n                ),\n            ],\n        )\n    ]\n\n    with self.assertRaisesRegex(TypeError, \"Unknown key in resolver_params\"):\n      lx.extract(\n          text_or_documents=\"test\",\n          prompt_description=\"desc\",\n          examples=mock_examples,\n          api_key=\"test_key\",\n          resolver_params={\n              \"fuzzy_alignment_treshold\": (  # Typo: treshold instead of threshold\n                  0.5\n              ),\n          },\n      )\n\n  @mock.patch(\"langextract.annotation.Annotator.annotate_documents\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_resolver_params_docs_path_passthrough(\n      self, mock_create_model, mock_annotate_docs\n  ):\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [\n        [types.ScoredOutput(output='{\"extractions\":[]}')]\n    ]\n    mock_model.requires_fence_output = False\n    mock_create_model.return_value = mock_model\n    mock_annotate_docs.return_value = []\n\n    docs = [lx.data.Document(text=\"doc1\")]\n    examples = [\n        lx.data.ExampleData(\n            text=\"Example text\",\n            extractions=[\n                lx.data.Extraction(\n                    
extraction_class=\"entity\",\n                    extraction_text=\"example\",\n                ),\n            ],\n        )\n    ]\n\n    lx.extract(\n        text_or_documents=docs,\n        prompt_description=\"desc\",\n        examples=examples,\n        api_key=\"k\",\n        resolver_params={\n            \"enable_fuzzy_alignment\": False,\n            \"fuzzy_alignment_threshold\": 0.9,\n            \"accept_match_lesser\": False,\n        },\n    )\n\n    _, kwargs = mock_annotate_docs.call_args\n    self.assertFalse(kwargs.get(\"enable_fuzzy_alignment\"))\n    self.assertEqual(kwargs.get(\"fuzzy_alignment_threshold\"), 0.9)\n    self.assertFalse(kwargs.get(\"accept_match_lesser\"))\n\n  @mock.patch(\"langextract.annotation.Annotator.annotate_text\")\n  @mock.patch(\"langextract.extraction.resolver.Resolver\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_resolver_params_none_threshold(\n      self, mock_create_model, mock_resolver_cls, mock_annotate\n  ):\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [\n        [types.ScoredOutput(output='{\"extractions\":[]}')]\n    ]\n    mock_model.requires_fence_output = False\n    mock_create_model.return_value = mock_model\n    mock_resolver_cls.return_value = mock.MagicMock()\n    mock_annotate.return_value = lx.data.AnnotatedDocument(\n        text=\"t\", extractions=[]\n    )\n\n    lx.extract(\n        text_or_documents=\"t\",\n        prompt_description=\"d\",\n        examples=[\n            lx.data.ExampleData(\n                text=\"example\",\n                extractions=[\n                    lx.data.Extraction(\n                        extraction_class=\"entity\",\n                        extraction_text=\"ex\",\n                    ),\n                ],\n            )\n        ],\n        api_key=\"k\",\n        resolver_params={\"fuzzy_alignment_threshold\": None},\n    )\n\n    _, resolver_kwargs = mock_resolver_cls.call_args\n    
self.assertNotIn(\"fuzzy_alignment_threshold\", resolver_kwargs)\n\n    _, annotate_kwargs = mock_annotate.call_args\n    self.assertNotIn(\"fuzzy_alignment_threshold\", annotate_kwargs)\n\n  @mock.patch.object(\n      schemas.gemini.GeminiSchema, \"from_examples\", autospec=True\n  )\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_custom_params_reach_inference(\n      self, mock_create_model, mock_gemini_schema\n  ):\n    \"\"\"Sanity check that custom parameters reach the inference layer.\"\"\"\n    input_text = \"Test text\"\n\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [[\n        types.ScoredOutput(\n            output='```json\\n{\"extractions\": []}\\n```',\n            score=0.9,\n        )\n    ]]\n\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n    mock_gemini_schema.return_value = None\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Example\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"test\",\n                    extraction_text=\"example\",\n                ),\n            ],\n        )\n    ]\n\n    lx.extract(\n        text_or_documents=input_text,\n        prompt_description=\"Test extraction\",\n        examples=mock_examples,\n        api_key=\"test_key\",\n        max_workers=5,\n        fence_output=True,\n        use_schema_constraints=False,\n    )\n\n    mock_model.infer.assert_called_once()\n    _, kwargs = mock_model.infer.call_args\n    self.assertEqual(kwargs.get(\"max_workers\"), 5)\n\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_extract_with_custom_tokenizer(self, mock_create_model):\n    \"\"\"Test that a custom tokenizer can be passed to extract().\"\"\"\n    input_text = \"Test text\"\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [[\n        types.ScoredOutput(\n            
output='```json\\n{\"extractions\": []}\\n```',\n            score=0.9,\n        )\n    ]]\n    mock_model.requires_fence_output = True\n    mock_create_model.return_value = mock_model\n\n    def mock_tokenize(text):\n      if text == \"\\u241F\":  # Delimiter\n        return lx.tokenizer.TokenizedText(\n            text=text,\n            tokens=[\n                lx.tokenizer.Token(\n                    index=0,\n                    token_type=lx.tokenizer.TokenType.PUNCTUATION,\n                    char_interval=lx.tokenizer.CharInterval(0, 1),\n                )\n            ],\n        )\n      # Return dummy tokens for other text to avoid \"empty tokens\" error in aligner\n      return lx.tokenizer.TokenizedText(\n          text=text,\n          tokens=[\n              lx.tokenizer.Token(\n                  index=0,\n                  token_type=lx.tokenizer.TokenType.WORD,\n                  char_interval=lx.tokenizer.CharInterval(0, len(text)),\n              )\n          ],\n      )\n\n    mock_tokenizer = mock.MagicMock()\n    mock_tokenizer.tokenize.side_effect = mock_tokenize\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Example\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"test\",\n                    extraction_text=\"example\",\n                ),\n            ],\n        )\n    ]\n\n    lx.extract(\n        text_or_documents=input_text,\n        prompt_description=\"Test extraction\",\n        examples=mock_examples,\n        api_key=\"test_key\",\n        tokenizer=mock_tokenizer,\n    )\n\n    mock_tokenizer.tokenize.assert_called_with(input_text)\n\n  def test_data_module_exports_via_compatibility_shim(self):\n    \"\"\"Verify data module exports are accessible via lx.data.\"\"\"\n    expected_exports = [\n        \"AlignmentStatus\",\n        \"CharInterval\",\n        \"Extraction\",\n        \"Document\",\n        \"AnnotatedDocument\",\n        
\"ExampleData\",\n        \"FormatType\",\n    ]\n\n    for name in expected_exports:\n      with self.subTest(export=name):\n        self.assertTrue(\n            hasattr(lx.data, name),\n            f\"lx.data.{name} not accessible via compatibility shim\",\n        )\n\n  def test_tokenizer_module_exports_via_compatibility_shim(self):\n    \"\"\"Verify tokenizer module exports are accessible via lx.tokenizer.\"\"\"\n    expected_exports = [\n        \"BaseTokenizerError\",\n        \"InvalidTokenIntervalError\",\n        \"SentenceRangeError\",\n        \"CharInterval\",\n        \"TokenInterval\",\n        \"TokenType\",\n        \"Token\",\n        \"TokenizedText\",\n        \"tokenize\",\n        \"tokens_text\",\n        \"find_sentence_range\",\n    ]\n\n    for name in expected_exports:\n      with self.subTest(export=name):\n        self.assertTrue(\n            hasattr(lx.tokenizer, name),\n            f\"lx.tokenizer.{name} not accessible via compatibility shim\",\n        )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"show_progress_true_debug_false\",\n          show_progress=True,\n          debug=False,\n          expected_progress_disabled=False,\n      ),\n      dict(\n          testcase_name=\"show_progress_false_debug_false\",\n          show_progress=False,\n          debug=False,\n          expected_progress_disabled=True,\n      ),\n      dict(\n          testcase_name=\"show_progress_true_debug_true\",\n          show_progress=True,\n          debug=True,\n          expected_progress_disabled=False,\n      ),\n      dict(\n          testcase_name=\"show_progress_false_debug_true\",\n          show_progress=False,\n          debug=True,\n          expected_progress_disabled=True,\n      ),\n  )\n  @mock.patch(\"langextract.progress.create_extraction_progress_bar\")\n  @mock.patch(\"langextract.extraction.factory.create_model\")\n  def test_show_progress_controls_progress_bar(\n      self,\n      
mock_create_model,\n      mock_progress,\n      show_progress,\n      debug,\n      expected_progress_disabled,\n  ):\n    \"\"\"Test that show_progress parameter controls progress bar visibility.\"\"\"\n    mock_model = mock.MagicMock()\n    mock_model.infer.return_value = [\n        [\n            types.ScoredOutput(\n                output='{\"extractions\": []}',\n                score=0.9,\n            )\n        ]\n    ]\n    mock_model.requires_fence_output = False\n    mock_create_model.return_value = mock_model\n\n    mock_progress.side_effect = lambda iterable, **kwargs: iter(iterable)\n\n    mock_examples = [\n        lx.data.ExampleData(\n            text=\"Example text\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"example\",\n                ),\n            ],\n        )\n    ]\n\n    lx.extract(\n        text_or_documents=\"test text\",\n        prompt_description=\"extract entities\",\n        examples=mock_examples,\n        api_key=\"test_key\",\n        show_progress=show_progress,\n        debug=debug,\n    )\n\n    mock_progress.assert_called()\n    call_args = mock_progress.call_args\n    self.assertEqual(\n        call_args.kwargs.get(\"disable\", False), expected_progress_disabled\n    )\n\n  @mock.patch(\"langextract.factory.create_model\")\n  def test_schema_validation_warning_issued(self, mock_create_model):\n    \"\"\"Test that schema validation warnings are properly issued.\"\"\"\n    mock_model = mock.Mock(spec=base_model.BaseLanguageModel)\n    mock_model.requires_fence_output = True\n    mock_model.infer.return_value = [\n        [types.ScoredOutput(output='{\"extractions\": []}', score=1.0)]\n    ]\n\n    mock_schema = mock.Mock(spec=schema.BaseSchema)\n\n    def validate_format_side_effect(format_handler):\n      warnings.warn(\"Test validation warning\", UserWarning, stacklevel=3)\n\n    mock_schema.validate_format = 
mock.Mock(\n        side_effect=validate_format_side_effect\n    )\n    mock_model.schema = mock_schema\n\n    mock_create_model.return_value = mock_model\n    test_examples = [\n        lx.data.ExampleData(\n            text=\"test\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"test\",\n                ),\n            ],\n        )\n    ]\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter(\"always\")\n\n      result = lx.extract(\n          text_or_documents=\"Sample text\",\n          prompt_description=\"Extract\",\n          examples=test_examples,\n          model_id=\"test-model\",\n          api_key=\"key\",\n          use_schema_constraints=True,\n      )\n      warning_messages = [str(warning.message) for warning in w]\n      self.assertIn(\n          \"Test validation warning\",\n          \" \".join(warning_messages),\n          \"Schema validation warning should be issued\",\n      )\n\n    self.assertIsNotNone(result)\n\n  def test_gemini_schema_deprecation_warning(self):\n    \"\"\"Test that passing gemini_schema triggers deprecation warning.\"\"\"\n    mock_model = mock.MagicMock(spec=base_model.BaseLanguageModel)\n    mock_model.infer.return_value = iter(\n        [[mock.Mock(output='{\"extractions\": []}')]]\n    )\n    mock_model.requires_fence_output = True\n    mock_model.schema = None\n\n    self.enter_context(\n        mock.patch(\n            \"langextract.factory.create_model\",\n            return_value=mock_model,\n        )\n    )\n\n    self.enter_context(\n        mock.patch(\n            \"langextract.annotation.Annotator.annotate_text\",\n            return_value=data.AnnotatedDocument(text=\"test\", extractions=[]),\n        )\n    )\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter(\"always\")\n\n      _ = lx.extract(\n          text_or_documents=\"test\",\n   
       prompt_description=\"Extract conditions\",\n          examples=[\n              lx.data.ExampleData(\n                  text=\"test\",\n                  extractions=[\n                      lx.data.Extraction(\n                          extraction_class=\"entity\",\n                          extraction_text=\"test\",\n                      ),\n                  ],\n              )\n          ],\n          model_id=\"gemini-2.5-flash\",\n          api_key=\"test_key\",\n          language_model_params={\"gemini_schema\": \"deprecated\"},\n      )\n\n      self.assertTrue(\n          any(\n              issubclass(warning.category, FutureWarning)\n              and \"gemini_schema\" in str(warning.message)\n              for warning in w\n          ),\n          \"Expected deprecation warning for gemini_schema\",\n      )\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/progress_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for langextract.progress module.\"\"\"\n\nimport unittest\nfrom unittest import mock\n\nimport tqdm\n\nfrom langextract import progress\n\n\nclass ProgressTest(unittest.TestCase):\n\n  def test_download_progress_bar(self):\n    \"\"\"Test download progress bar creation.\"\"\"\n    pbar = progress.create_download_progress_bar(\n        1024, \"https://example.com/file.txt\"\n    )\n\n    self.assertIsInstance(pbar, tqdm.tqdm)\n    self.assertEqual(pbar.total, 1024)\n    self.assertIn(\"Downloading\", pbar.desc)\n\n  def test_extraction_progress_bar(self):\n    \"\"\"Test extraction progress bar creation.\"\"\"\n    pbar = progress.create_extraction_progress_bar(\n        range(10), \"gemini-2.0-flash\"\n    )\n\n    self.assertIsInstance(pbar, tqdm.tqdm)\n    self.assertIn(\"LangExtract\", pbar.desc)\n    self.assertIn(\"gemini-2.0-flash\", pbar.desc)\n\n  def test_save_load_progress_bars(self):\n    \"\"\"Test save and load progress bar creation.\"\"\"\n    save_pbar = progress.create_save_progress_bar(\"/path/file.json\")\n    load_pbar = progress.create_load_progress_bar(\"/path/file.json\")\n\n    self.assertIsInstance(save_pbar, tqdm.tqdm)\n    self.assertIsInstance(load_pbar, tqdm.tqdm)\n    self.assertIn(\"Saving\", save_pbar.desc)\n    self.assertIn(\"Loading\", load_pbar.desc)\n\n  def test_model_info_extraction(self):\n    \"\"\"Test extracting 
model info from objects.\"\"\"\n    mock_model = mock.MagicMock()\n    mock_model.model_id = \"gemini-1.5-pro\"\n    self.assertEqual(progress.get_model_info(mock_model), \"gemini-1.5-pro\")\n\n    mock_model = mock.MagicMock()\n    del mock_model.model_id\n    del mock_model.model_url\n    self.assertIsNone(progress.get_model_info(mock_model))\n\n  def test_formatting_functions(self):\n    \"\"\"Test message formatting functions.\"\"\"\n    stats = progress.format_extraction_stats(1500, 5000)\n    self.assertIn(\"1,500\", stats)\n    self.assertIn(\"5,000\", stats)\n\n    desc = progress.format_extraction_progress(\"gemini-2.0-flash\")\n    self.assertIn(\"LangExtract\", desc)\n    self.assertIn(\"gemini-2.0-flash\", desc)\n\n    desc_no_model = progress.format_extraction_progress(None)\n    self.assertIn(\"Processing\", desc_no_model)\n\n\nif __name__ == \"__main__\":\n  unittest.main()\n"
  },
  {
    "path": "tests/prompt_validation_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for prompt validation module.\"\"\"\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import extraction\nfrom langextract import prompt_validation\nfrom langextract.core import data\n\n\nclass PromptAlignmentValidationTest(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"exact_alignment\",\n          text=\"Patient takes lisinopril.\",\n          extraction_class=\"Medication\",\n          extraction_text=\"lisinopril\",\n          expected_issues=0,\n          expected_has_failed=False,\n          expected_has_non_exact=False,\n          expected_alignment_status=None,\n      ),\n      dict(\n          testcase_name=\"fuzzy_match_lesser\",\n          text=\"Type 2 diabetes.\",\n          extraction_class=\"Diagnosis\",\n          extraction_text=\"type-2 diabetes\",\n          expected_issues=1,\n          expected_has_failed=False,\n          expected_has_non_exact=True,\n          expected_alignment_status=data.AlignmentStatus.MATCH_LESSER,\n      ),\n      dict(\n          testcase_name=\"extraction_not_found\",\n          text=\"No medications mentioned in this text.\",\n          extraction_class=\"Medication\",\n          extraction_text=\"lisinopril\",\n          expected_issues=1,\n          expected_has_failed=True,\n          
expected_has_non_exact=False,\n          expected_alignment_status=None,\n      ),\n  )\n  def test_alignment_detection(\n      self,\n      text,\n      extraction_class,\n      extraction_text,\n      expected_issues,\n      expected_has_failed,\n      expected_has_non_exact,\n      expected_alignment_status,\n  ):\n    \"\"\"Test that different alignment types are correctly detected.\"\"\"\n    example = data.ExampleData(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=extraction_class,\n                extraction_text=extraction_text,\n                attributes={},\n            )\n        ],\n    )\n\n    report = prompt_validation.validate_prompt_alignment([example])\n\n    self.assertLen(report.issues, expected_issues)\n    self.assertEqual(report.has_failed, expected_has_failed)\n    self.assertEqual(report.has_non_exact, expected_has_non_exact)\n\n    if expected_issues > 0:\n      issue = report.issues[0]\n      self.assertEqual(issue.alignment_status, expected_alignment_status)\n      self.assertEqual(issue.extraction_class, extraction_class)\n      if expected_has_failed:\n        self.assertIsNone(issue.alignment_status)\n      elif expected_has_non_exact:\n        self.assertIsNotNone(issue.alignment_status)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"one_fails\",\n          text=\"Patient takes lisinopril and has diabetes mellitus.\",\n          extractions=[\n              (\"Medication\", \"lisinopril\"),  # PASSES - found exactly\n              (\"Diagnosis\", \"diabetes\"),  # PASSES - found exactly\n              (\"Medication\", \"metformin\"),  # FAILS - not in text\n          ],\n          expected_issues=1,\n          expected_has_failed=True,\n          expected_has_non_exact=False,\n          expected_failed_text=\"metformin\",\n      ),\n      dict(\n          testcase_name=\"all_pass\",\n          text=\"Patient takes lisinopril and aspirin for 
diabetes management.\",\n          extractions=[\n              (\"Medication\", \"lisinopril\"),\n              (\"Medication\", \"aspirin\"),\n              (\"Diagnosis\", \"diabetes\"),\n          ],\n          expected_issues=0,\n          expected_has_failed=False,\n          expected_has_non_exact=False,\n          expected_failed_text=None,\n      ),\n  )\n  def test_multiple_extractions_per_example(\n      self,\n      text,\n      extractions,\n      expected_issues,\n      expected_has_failed,\n      expected_has_non_exact,\n      expected_failed_text,\n  ):\n    \"\"\"Test validation with multiple extractions in a single example.\"\"\"\n    example = data.ExampleData(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=extraction_class,\n                extraction_text=extraction_text,\n                attributes={},\n            )\n            for extraction_class, extraction_text in extractions\n        ],\n    )\n\n    report = prompt_validation.validate_prompt_alignment([example])\n\n    self.assertLen(report.issues, expected_issues)\n    self.assertEqual(report.has_failed, expected_has_failed)\n    self.assertEqual(report.has_non_exact, expected_has_non_exact)\n\n    if expected_failed_text:\n      issue = report.issues[0]\n      self.assertIsNone(issue.alignment_status)\n      self.assertEqual(issue.extraction_text_preview, expected_failed_text)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"warning_mode_with_failed\",\n          text=\"Patient has no known allergies.\",\n          extraction_text=\"penicillin\",\n          validation_level=prompt_validation.PromptValidationLevel.WARNING,\n          strict_non_exact=False,\n      ),\n      dict(\n          testcase_name=\"off_mode_with_failed\",\n          text=\"Patient history incomplete.\",\n          extraction_text=\"aspirin\",\n          validation_level=prompt_validation.PromptValidationLevel.OFF,\n        
  strict_non_exact=False,\n      ),\n  )\n  def test_validation_levels_that_dont_raise(\n      self, text, extraction_text, validation_level, strict_non_exact\n  ):\n    \"\"\"Test that WARNING and OFF modes don't raise exceptions.\"\"\"\n    example = data.ExampleData(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=\"Medication\",\n                extraction_text=extraction_text,\n                attributes={},\n            )\n        ],\n    )\n\n    report = prompt_validation.validate_prompt_alignment([example])\n\n    # This should not raise an exception in WARNING or OFF modes\n    prompt_validation.handle_alignment_report(\n        report, validation_level, strict_non_exact=strict_non_exact\n    )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"error_mode_failed_alignment\",\n          text=\"Patient has no known allergies.\",\n          extraction_class=\"Medication\",\n          extraction_text=\"penicillin\",\n          strict_non_exact=False,\n          error_pattern=r\"1 extraction\\(s\\).*could not be aligned\",\n      ),\n      dict(\n          testcase_name=\"error_mode_strict_fuzzy_match\",\n          text=\"Type 2 diabetes.\",\n          extraction_class=\"Diagnosis\",\n          extraction_text=\"type-2 diabetes\",\n          strict_non_exact=True,\n          error_pattern=r\"strict mode.*1 non-exact\",\n      ),\n  )\n  def test_error_mode_raises_appropriately(\n      self,\n      text,\n      extraction_class,\n      extraction_text,\n      strict_non_exact,\n      error_pattern,\n  ):\n    \"\"\"Test that ERROR mode raises with appropriate messages.\"\"\"\n    example = data.ExampleData(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=extraction_class,\n                extraction_text=extraction_text,\n                attributes={},\n            )\n        ],\n    )\n\n    report = 
prompt_validation.validate_prompt_alignment([example])\n\n    with self.assertRaisesRegex(\n        prompt_validation.PromptAlignmentError, error_pattern\n    ):\n      prompt_validation.handle_alignment_report(\n          report,\n          prompt_validation.PromptValidationLevel.ERROR,\n          strict_non_exact=strict_non_exact,\n      )\n\n  def test_empty_examples_produces_empty_report(self):\n    report = prompt_validation.validate_prompt_alignment([])\n\n    self.assertEmpty(report.issues)\n    self.assertFalse(report.has_failed)\n    self.assertFalse(report.has_non_exact)\n\n  def test_multiple_examples_preserve_indices(self):\n    examples = [\n        data.ExampleData(  # Example 0: FAILS - \"metformin\" not in text\n            text=\"First patient record.\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"Medication\",\n                    extraction_text=\"metformin\",\n                    attributes={},\n                )\n            ],\n        ),\n        data.ExampleData(  # Example 1: PASSES - \"aspirin\" found exactly\n            text=\"Patient takes aspirin daily.\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"Medication\",\n                    extraction_text=\"aspirin\",\n                    attributes={},\n                )\n            ],\n        ),\n        data.ExampleData(  # Example 2: NON-EXACT - \"type-2\" fuzzy matches \"Type 2\"\n            text=\"Type 2 diabetes mellitus.\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"Diagnosis\",\n                    extraction_text=\"type-2 diabetes\",\n                    attributes={},\n                )\n            ],\n        ),\n    ]\n\n    report = prompt_validation.validate_prompt_alignment(examples)\n\n    # Expect 2 issues: example 0 (failed) and example 2 (non-exact)\n    self.assertLen(report.issues, 2)\n    
self.assertTrue(report.has_failed)\n    self.assertTrue(report.has_non_exact)\n\n    issue_by_index = {issue.example_index: issue for issue in report.issues}\n\n    # Example 0: Failed alignment (metformin not found)\n    self.assertIn(0, issue_by_index)\n    self.assertIsNone(issue_by_index[0].alignment_status)\n\n    # Example 1: No issue (aspirin found exactly)\n    self.assertNotIn(1, issue_by_index)\n\n    # Example 2: Non-exact match (type-2 vs Type 2)\n    self.assertIn(2, issue_by_index)\n    self.assertIsNotNone(issue_by_index[2].alignment_status)\n\n  def test_validation_does_not_mutate_input(self):\n    example = data.ExampleData(\n        text=\"Patient takes lisinopril 10mg daily.\",\n        extractions=[\n            data.Extraction(\n                extraction_class=\"Medication\",\n                extraction_text=\"lisinopril\",\n                attributes={},\n            )\n        ],\n    )\n\n    original_extraction = example.extractions[0]\n\n    self.assertIsNone(getattr(original_extraction, \"token_interval\", None))\n    self.assertIsNone(getattr(original_extraction, \"char_interval\", None))\n    self.assertIsNone(getattr(original_extraction, \"alignment_status\", None))\n\n    _ = prompt_validation.validate_prompt_alignment([example])\n\n    self.assertIsNone(getattr(original_extraction, \"token_interval\", None))\n    self.assertIsNone(getattr(original_extraction, \"char_interval\", None))\n    self.assertIsNone(getattr(original_extraction, \"alignment_status\", None))\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"fuzzy_disabled_rejects_non_exact\",\n          text=\"Patient has type 2 diabetes.\",\n          extraction_class=\"Diagnosis\",\n          extraction_text=\"Type-2 Diabetes\",\n          enable_fuzzy=False,\n          accept_lesser=False,\n          fuzzy_threshold=0.75,\n          expected_has_failed=True,\n          expected_has_non_exact=False,\n      ),\n      dict(\n          
testcase_name=\"fuzzy_enabled_accepts_close_match\",\n          text=\"Patient has type 2 diabetes.\",\n          extraction_class=\"Diagnosis\",\n          extraction_text=\"Type-2 Diabetes\",\n          enable_fuzzy=True,\n          accept_lesser=False,\n          fuzzy_threshold=0.75,\n          expected_has_failed=False,\n          expected_has_non_exact=True,\n      ),\n  )\n  def test_alignment_policies(\n      self,\n      text,\n      extraction_class,\n      extraction_text,\n      enable_fuzzy,\n      accept_lesser,\n      fuzzy_threshold,\n      expected_has_failed,\n      expected_has_non_exact,\n  ):\n    \"\"\"Test different alignment policy configurations.\"\"\"\n    example = data.ExampleData(\n        text=text,\n        extractions=[\n            data.Extraction(\n                extraction_class=extraction_class,\n                extraction_text=extraction_text,\n                attributes={},\n            )\n        ],\n    )\n\n    if not enable_fuzzy:\n      default_report = prompt_validation.validate_prompt_alignment([example])\n      self.assertFalse(default_report.has_failed)\n      self.assertTrue(default_report.has_non_exact)\n\n    policy = prompt_validation.AlignmentPolicy(\n        enable_fuzzy_alignment=enable_fuzzy,\n        accept_match_lesser=accept_lesser,\n        fuzzy_alignment_threshold=fuzzy_threshold,\n    )\n    report = prompt_validation.validate_prompt_alignment(\n        [example], policy=policy\n    )\n\n    self.assertEqual(report.has_failed, expected_has_failed)\n    self.assertEqual(report.has_non_exact, expected_has_non_exact)\n\n\nclass ExtractIntegrationTest(absltest.TestCase):\n  \"\"\"Minimal integration test for extract() entry point validation.\"\"\"\n\n  def test_extract_validates_in_error_mode(self):\n    \"\"\"Verify extract() runs validation when configured.\"\"\"\n    examples = [\n        data.ExampleData(\n            text=\"Patient takes aspirin.\",\n            extractions=[\n                
data.Extraction(\n                    extraction_class=\"Medication\",\n                    extraction_text=\"ibuprofen\",\n                    attributes={},\n                )\n            ],\n        )\n    ]\n\n    with self.assertRaisesRegex(\n        prompt_validation.PromptAlignmentError,\n        r\"1 extraction\\(s\\).*could not be aligned\",\n    ):\n      extraction.extract(\n          text_or_documents=\"Test document\",\n          prompt_description=\"Extract medications\",\n          examples=examples,\n          prompt_validation_level=prompt_validation.PromptValidationLevel.ERROR,\n          model_id=\"fake-model\",\n      )\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/prompting_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport textwrap\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import prompting\nfrom langextract.core import data\nfrom langextract.core import format_handler as fh\n\n\nclass QAPromptGeneratorTest(parameterized.TestCase):\n\n  def test_generate_prompt(self):\n    prompt_template_structured = prompting.PromptTemplateStructured(\n        description=(\n            \"You are an assistant specialized in extracting key extractions\"\n            \" from text.\\nIdentify and extract important extractions such as\"\n            \" people, places,\\norganizations, dates, and medical conditions\"\n            \" mentioned in the text.\\n**Please ensure that the extractions are\"\n            \" extracted in the same order as they\\nappear in the source\"\n            \" text.**\\nProvide the extracted extractions in a structured YAML\"\n            \" format.\"\n        ),\n        examples=[\n            data.ExampleData(\n                text=(\n                    \"The patient was diagnosed with hypertension and diabetes.\"\n                ),\n                extractions=[\n                    data.Extraction(\n                        extraction_text=\"hypertension\",\n                        extraction_class=\"medical_condition\",\n                        attributes={\n                            \"chronicity\": 
\"chronic\",\n                            \"system\": \"cardiovascular\",\n                        },\n                    ),\n                    data.Extraction(\n                        extraction_text=\"diabetes\",\n                        extraction_class=\"medical_condition\",\n                        attributes={\n                            \"chronicity\": \"chronic\",\n                            \"system\": \"endocrine\",\n                        },\n                    ),\n                ],\n            )\n        ],\n    )\n\n    format_handler = fh.FormatHandler(\n        format_type=data.FormatType.YAML,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=True,\n    )\n\n    prompt_generator = prompting.QAPromptGenerator(\n        template=prompt_template_structured,\n        format_handler=format_handler,\n        examples_heading=\"\",\n        question_prefix=\"\",\n        answer_prefix=\"\",\n    )\n\n    actual_prompt_text = prompt_generator.render(\n        \"The patient reports chest pain and shortness of breath.\"\n    )\n\n    expected_prompt_text = textwrap.dedent(f\"\"\"\\\n        You are an assistant specialized in extracting key extractions from text.\n        Identify and extract important extractions such as people, places,\n        organizations, dates, and medical conditions mentioned in the text.\n        **Please ensure that the extractions are extracted in the same order as they\n        appear in the source text.**\n        Provide the extracted extractions in a structured YAML format.\n\n\n        The patient was diagnosed with hypertension and diabetes.\n        ```yaml\n        {data.EXTRACTIONS_KEY}:\n        - medical_condition: hypertension\n          medical_condition_attributes:\n            chronicity: chronic\n            system: cardiovascular\n        - medical_condition: diabetes\n          medical_condition_attributes:\n            chronicity: chronic\n            system: 
endocrine\n        ```\n\n        The patient reports chest pain and shortness of breath.\n        \"\"\")\n    self.assertEqual(expected_prompt_text, actual_prompt_text)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"json_basic_format\",\n          format_type=data.FormatType.JSON,\n          example_text=\"Patient has diabetes and is prescribed insulin.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"diabetes\",\n                  extraction_class=\"medical_condition\",\n                  attributes={\"chronicity\": \"chronic\"},\n              ),\n              data.Extraction(\n                  extraction_text=\"insulin\",\n                  extraction_class=\"medication\",\n                  attributes={\"prescribed\": \"prescribed\"},\n              ),\n          ],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Patient has diabetes and is prescribed insulin.\n              ```json\n              {{\n                \"{data.EXTRACTIONS_KEY}\": [\n                  {{\n                    \"medical_condition\": \"diabetes\",\n                    \"medical_condition_attributes\": {{\n                      \"chronicity\": \"chronic\"\n                    }}\n                  }},\n                  {{\n                    \"medication\": \"insulin\",\n                    \"medication_attributes\": {{\n                      \"prescribed\": \"prescribed\"\n                    }}\n                  }}\n                ]\n              }}\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"yaml_basic_format\",\n          format_type=data.FormatType.YAML,\n          example_text=\"Patient has diabetes and is prescribed insulin.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"diabetes\",\n                  extraction_class=\"medical_condition\",\n       
           attributes={\"chronicity\": \"chronic\"},\n              ),\n              data.Extraction(\n                  extraction_text=\"insulin\",\n                  extraction_class=\"medication\",\n                  attributes={\"prescribed\": \"prescribed\"},\n              ),\n          ],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Patient has diabetes and is prescribed insulin.\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - medical_condition: diabetes\n                medical_condition_attributes:\n                  chronicity: chronic\n              - medication: insulin\n                medication_attributes:\n                  prescribed: prescribed\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"custom_attribute_suffix\",\n          format_type=data.FormatType.YAML,\n          example_text=\"Patient has a fever.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"fever\",\n                  extraction_class=\"symptom\",\n                  attributes={\"severity\": \"mild\"},\n              ),\n          ],\n          attribute_suffix=\"_props\",\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Patient has a fever.\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - symptom: fever\n                symptom_props:\n                  severity: mild\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"yaml_empty_extractions\",\n          format_type=data.FormatType.YAML,\n          example_text=\"Text with no extractions.\",\n          example_extractions=[],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Text with no extractions.\n              ```yaml\n              {data.EXTRACTIONS_KEY}: []\n              ```\n              \"\"\"),\n      ),\n      dict(\n        
  testcase_name=\"json_empty_extractions\",\n          format_type=data.FormatType.JSON,\n          example_text=\"Text with no extractions.\",\n          example_extractions=[],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Text with no extractions.\n              ```json\n              {{\n                \"{data.EXTRACTIONS_KEY}\": []\n              }}\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"yaml_empty_attributes\",\n          format_type=data.FormatType.YAML,\n          example_text=\"Patient is resting comfortably.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"Patient\",\n                  extraction_class=\"person\",\n                  attributes={},\n              ),\n          ],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Patient is resting comfortably.\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - person: Patient\n                person_attributes: {{}}\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"json_empty_attributes\",\n          format_type=data.FormatType.JSON,\n          example_text=\"Patient is resting comfortably.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"Patient\",\n                  extraction_class=\"person\",\n                  attributes={},\n              ),\n          ],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Patient is resting comfortably.\n              ```json\n              {{\n                \"{data.EXTRACTIONS_KEY}\": [\n                  {{\n                    \"person\": \"Patient\",\n                    \"person_attributes\": {{}}\n                  }}\n                ]\n              }}\n              ```\n              \"\"\"),\n      ),\n      dict(\n          
testcase_name=\"yaml_same_extraction_class_multiple_times\",\n          format_type=data.FormatType.YAML,\n          example_text=(\n              \"Patient has multiple medications: aspirin and lisinopril.\"\n          ),\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"aspirin\",\n                  extraction_class=\"medication\",\n                  attributes={\"dosage\": \"81mg\"},\n              ),\n              data.Extraction(\n                  extraction_text=\"lisinopril\",\n                  extraction_class=\"medication\",\n                  attributes={\"dosage\": \"10mg\"},\n              ),\n          ],\n          expected_formatted_example=textwrap.dedent(f\"\"\"\\\n              Patient has multiple medications: aspirin and lisinopril.\n              ```yaml\n              {data.EXTRACTIONS_KEY}:\n              - medication: aspirin\n                medication_attributes:\n                  dosage: 81mg\n              - medication: lisinopril\n                medication_attributes:\n                  dosage: 10mg\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"json_simplified_no_extractions_key\",\n          format_type=data.FormatType.JSON,\n          example_text=\"Patient has diabetes and is prescribed insulin.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"diabetes\",\n                  extraction_class=\"medical_condition\",\n                  attributes={\"chronicity\": \"chronic\"},\n              ),\n              data.Extraction(\n                  extraction_text=\"insulin\",\n                  extraction_class=\"medication\",\n                  attributes={\"prescribed\": \"prescribed\"},\n              ),\n          ],\n          require_extractions_key=False,\n          expected_formatted_example=textwrap.dedent(\"\"\"\\\n              Patient has diabetes and is prescribed 
insulin.\n              ```json\n              [\n                {\n                  \"medical_condition\": \"diabetes\",\n                  \"medical_condition_attributes\": {\n                    \"chronicity\": \"chronic\"\n                  }\n                },\n                {\n                  \"medication\": \"insulin\",\n                  \"medication_attributes\": {\n                    \"prescribed\": \"prescribed\"\n                  }\n                }\n              ]\n              ```\n              \"\"\"),\n      ),\n      dict(\n          testcase_name=\"yaml_simplified_no_extractions_key\",\n          format_type=data.FormatType.YAML,\n          example_text=\"Patient has a fever.\",\n          example_extractions=[\n              data.Extraction(\n                  extraction_text=\"fever\",\n                  extraction_class=\"symptom\",\n                  attributes={\"severity\": \"mild\"},\n              ),\n          ],\n          require_extractions_key=False,\n          expected_formatted_example=textwrap.dedent(\"\"\"\\\n              Patient has a fever.\n              ```yaml\n              - symptom: fever\n                symptom_attributes:\n                  severity: mild\n              ```\n              \"\"\"),\n      ),\n  )\n  def test_format_example(\n      self,\n      format_type,\n      example_text,\n      example_extractions,\n      expected_formatted_example,\n      attribute_suffix=\"_attributes\",\n      require_extractions_key=True,\n  ):\n    \"\"\"Tests formatting of examples in different formats and scenarios.\"\"\"\n    example_data = data.ExampleData(\n        text=example_text,\n        extractions=example_extractions,\n    )\n\n    structured_template = prompting.PromptTemplateStructured(\n        description=\"Extract information from the text.\",\n        examples=[example_data],\n    )\n\n    format_handler = fh.FormatHandler(\n        format_type=format_type,\n        
use_wrapper=require_extractions_key,\n        wrapper_key=\"extractions\" if require_extractions_key else None,\n        use_fences=True,\n        attribute_suffix=attribute_suffix,\n    )\n\n    prompt_generator = prompting.QAPromptGenerator(\n        template=structured_template,\n        format_handler=format_handler,\n        question_prefix=\"\",\n        answer_prefix=\"\",\n    )\n\n    actual_formatted_example = prompt_generator.format_example_as_text(\n        example_data\n    )\n    self.assertEqual(expected_formatted_example, actual_formatted_example)\n\n\nclass PromptBuilderTest(absltest.TestCase):\n  \"\"\"Tests for PromptBuilder base class.\"\"\"\n\n  def _create_generator(self):\n    \"\"\"Creates a simple QAPromptGenerator for testing.\"\"\"\n    template = prompting.PromptTemplateStructured(\n        description=\"Extract entities.\",\n        examples=[\n            data.ExampleData(\n                text=\"Sample text.\",\n                extractions=[\n                    data.Extraction(\n                        extraction_text=\"Sample\",\n                        extraction_class=\"entity\",\n                    )\n                ],\n            )\n        ],\n    )\n    format_handler = fh.FormatHandler(\n        format_type=data.FormatType.YAML,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=True,\n    )\n    return prompting.QAPromptGenerator(\n        template=template,\n        format_handler=format_handler,\n    )\n\n  def test_build_prompt_renders_chunk_text(self):\n    \"\"\"Verifies build_prompt includes chunk text in the rendered prompt.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.PromptBuilder(generator)\n\n    prompt = builder.build_prompt(\n        chunk_text=\"Test input text.\",\n        document_id=\"doc1\",\n    )\n\n    self.assertIn(\"Test input text.\", prompt)\n    self.assertIn(\"Extract entities.\", prompt)\n\n  def 
test_build_prompt_includes_additional_context(self):\n    \"\"\"Verifies build_prompt passes additional_context to renderer.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.PromptBuilder(generator)\n\n    prompt = builder.build_prompt(\n        chunk_text=\"Test input.\",\n        document_id=\"doc1\",\n        additional_context=\"Important context here.\",\n    )\n\n    self.assertIn(\"Important context here.\", prompt)\n\n\nclass ContextAwarePromptBuilderTest(absltest.TestCase):\n  \"\"\"Tests for ContextAwarePromptBuilder.\"\"\"\n\n  def _create_generator(self):\n    \"\"\"Creates a simple QAPromptGenerator for testing.\"\"\"\n    template = prompting.PromptTemplateStructured(\n        description=\"Extract entities.\",\n        examples=[\n            data.ExampleData(\n                text=\"Sample text.\",\n                extractions=[\n                    data.Extraction(\n                        extraction_text=\"Sample\",\n                        extraction_class=\"entity\",\n                    )\n                ],\n            )\n        ],\n    )\n    format_handler = fh.FormatHandler(\n        format_type=data.FormatType.YAML,\n        use_wrapper=True,\n        wrapper_key=\"extractions\",\n        use_fences=True,\n    )\n    return prompting.QAPromptGenerator(\n        template=template,\n        format_handler=format_handler,\n    )\n\n  def test_context_window_chars_property(self):\n    \"\"\"Verifies the context_window_chars property returns configured value.\"\"\"\n    generator = self._create_generator()\n\n    builder_none = prompting.ContextAwarePromptBuilder(generator)\n    self.assertIsNone(builder_none.context_window_chars)\n\n    builder_with_value = prompting.ContextAwarePromptBuilder(\n        generator, context_window_chars=100\n    )\n    self.assertEqual(100, builder_with_value.context_window_chars)\n\n  def test_first_chunk_has_no_previous_context(self):\n    \"\"\"Verifies the first chunk does not 
include previous context.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.ContextAwarePromptBuilder(\n        generator, context_window_chars=50\n    )\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n\n    prompt = builder.build_prompt(\n        chunk_text=\"First chunk text.\",\n        document_id=\"doc1\",\n    )\n\n    self.assertNotIn(context_prefix, prompt)\n    self.assertIn(\"First chunk text.\", prompt)\n\n  def test_second_chunk_includes_previous_context(self):\n    \"\"\"Verifies the second chunk includes text from the first chunk.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.ContextAwarePromptBuilder(\n        generator, context_window_chars=20\n    )\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n\n    builder.build_prompt(chunk_text=\"First chunk ending.\", document_id=\"doc1\")\n    second_prompt = builder.build_prompt(\n        chunk_text=\"Second chunk text.\",\n        document_id=\"doc1\",\n    )\n\n    self.assertIn(context_prefix, second_prompt)\n    self.assertIn(\"chunk ending.\", second_prompt)\n\n  def test_context_disabled_when_none(self):\n    \"\"\"Verifies no context is added when context_window_chars is None.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.ContextAwarePromptBuilder(\n        generator, context_window_chars=None\n    )\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n\n    builder.build_prompt(chunk_text=\"First chunk.\", document_id=\"doc1\")\n    second_prompt = builder.build_prompt(\n        chunk_text=\"Second chunk.\",\n        document_id=\"doc1\",\n    )\n\n    self.assertNotIn(context_prefix, second_prompt)\n\n  def test_context_isolated_per_document(self):\n    \"\"\"Verifies context tracking is isolated per document_id.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.ContextAwarePromptBuilder(\n        generator, 
context_window_chars=50\n    )\n\n    builder.build_prompt(chunk_text=\"Doc A chunk one.\", document_id=\"docA\")\n    builder.build_prompt(chunk_text=\"Doc B chunk one.\", document_id=\"docB\")\n\n    prompt_a2 = builder.build_prompt(\n        chunk_text=\"Doc A chunk two.\",\n        document_id=\"docA\",\n    )\n    prompt_b2 = builder.build_prompt(\n        chunk_text=\"Doc B chunk two.\",\n        document_id=\"docB\",\n    )\n\n    self.assertIn(\"Doc A chunk one\", prompt_a2)\n    self.assertNotIn(\"Doc B\", prompt_a2)\n    self.assertIn(\"Doc B chunk one\", prompt_b2)\n    self.assertNotIn(\"Doc A\", prompt_b2)\n\n  def test_combines_previous_context_with_additional_context(self):\n    \"\"\"Verifies both previous chunk context and additional_context are included.\"\"\"\n    generator = self._create_generator()\n    builder = prompting.ContextAwarePromptBuilder(\n        generator, context_window_chars=30\n    )\n    context_prefix = prompting.ContextAwarePromptBuilder._CONTEXT_PREFIX\n\n    builder.build_prompt(chunk_text=\"Previous chunk text.\", document_id=\"doc1\")\n    prompt = builder.build_prompt(\n        chunk_text=\"Current chunk.\",\n        document_id=\"doc1\",\n        additional_context=\"Extra info here.\",\n    )\n\n    self.assertIn(context_prefix, prompt)\n    self.assertIn(\"Previous chunk text.\", prompt)\n    self.assertIn(\"Extra info here.\", prompt)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/provider_plugin_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for provider plugin system.\n\nNote: This file contains test helper classes that intentionally have\nfew public methods. The too-few-public-methods warnings are expected.\n\"\"\"\n\nfrom importlib import metadata\nimport os\nfrom pathlib import Path\nimport subprocess\nimport sys\nimport tempfile\nimport textwrap\nimport types as builtin_types\nfrom unittest import mock\nimport uuid\n\nfrom absl.testing import absltest\nimport pytest\n\nimport langextract as lx\nfrom langextract.core import base_model\nfrom langextract.core import types\n\n\ndef _create_mock_entry_points(entry_points_list):\n  \"\"\"Create a mock EntryPoints object for testing.\n\n  Args:\n    entry_points_list: List of entry points to return for langextract.providers.\n\n  Returns:\n    A mock object that behaves like importlib.metadata.EntryPoints.\n  \"\"\"\n\n  class MockEntryPoints:  # pylint: disable=too-few-public-methods\n    \"\"\"Mock EntryPoints that implements select() method.\"\"\"\n\n    def select(self, group=None):\n      if group == \"langextract.providers\":\n        return entry_points_list\n      return []\n\n  return MockEntryPoints()\n\n\nclass PluginSmokeTest(absltest.TestCase):\n  \"\"\"Basic smoke tests for plugin loading functionality.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    lx.providers.registry.clear()\n    # Always reset both flags to ensure 
clean state\n    lx.providers._reset_for_testing()\n    # Register cleanup\n    self.addCleanup(lx.providers.registry.clear)\n    self.addCleanup(lx.providers._reset_for_testing)\n\n  def test_plugin_discovery_and_usage(self):\n    \"\"\"Test plugin discovery via entry points.\n\n    Entry points can return a class or module. Registration happens via\n    the @register decorator in both cases.\n    \"\"\"\n\n    def _ep_load():\n      @lx.providers.registry.register(r\"^plugin-model\")\n      class PluginProvider(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n        def __init__(self, model_id=None, **kwargs):\n          super().__init__()\n          self.model_id = model_id\n\n        def infer(self, batch_prompts, **kwargs):\n          return [[types.ScoredOutput(score=1.0, output=\"ok\")]]\n\n      return PluginProvider\n\n    ep = builtin_types.SimpleNamespace(\n        name=\"plugin_provider\",\n        group=\"langextract.providers\",\n        value=\"my_pkg:PluginProvider\",\n        load=_ep_load,\n    )\n\n    with mock.patch.object(\n        metadata, \"entry_points\", return_value=_create_mock_entry_points([ep])\n    ):\n      lx.providers.load_plugins_once()\n\n    resolved_cls = lx.providers.registry.resolve(\"plugin-model-123\")\n    self.assertEqual(\n        resolved_cls.__name__,\n        \"PluginProvider\",\n        \"Provider should be resolvable after plugin load\",\n    )\n\n    cfg = lx.factory.ModelConfig(model_id=\"plugin-model-123\")\n    model = lx.factory.create_model(cfg)\n\n    out = model.infer([\"hi\"])[0][0].output\n    self.assertEqual(out, \"ok\", \"Provider should return expected output\")\n\n  def test_plugin_disabled_by_env_var(self):\n    \"\"\"Test that LANGEXTRACT_DISABLE_PLUGINS=1 prevents plugin loading.\"\"\"\n\n    with mock.patch.dict(\"os.environ\", {\"LANGEXTRACT_DISABLE_PLUGINS\": \"1\"}):\n      with mock.patch.object(metadata, \"entry_points\") as mock_ep:\n        
lx.providers.load_plugins_once()\n        mock_ep.assert_not_called()\n\n  def test_handles_import_errors_gracefully(self):\n    \"\"\"Test that import errors during plugin loading don't crash.\"\"\"\n\n    def _bad_load():\n      raise ImportError(\"Plugin not found\")\n\n    bad_ep = builtin_types.SimpleNamespace(\n        name=\"bad_plugin\",\n        group=\"langextract.providers\",\n        value=\"bad_pkg:BadProvider\",\n        load=_bad_load,\n    )\n\n    with mock.patch.object(\n        metadata,\n        \"entry_points\",\n        return_value=_create_mock_entry_points([bad_ep]),\n    ):\n      lx.providers.load_plugins_once()\n\n      providers = lx.providers.registry.list_providers()\n      self.assertIsInstance(\n          providers,\n          list,\n          \"Registry should remain functional after import error\",\n      )\n      # Built-in providers should still be loaded even if plugin fails\n      self.assertGreater(\n          len(providers),\n          0,\n          \"Built-in providers should still be available after plugin failure\",\n      )\n\n  def test_load_plugins_once_is_idempotent(self):\n    \"\"\"Test that load_plugins_once only discovers once.\"\"\"\n\n    def _ep_load():\n      @lx.providers.registry.register(r\"^plugin-model\")\n      class Plugin(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n        def infer(self, *a, **k):\n          return [[types.ScoredOutput(score=1.0, output=\"ok\")]]\n\n      return Plugin\n\n    ep = builtin_types.SimpleNamespace(\n        name=\"plugin_provider\",\n        group=\"langextract.providers\",\n        value=\"pkg:Plugin\",\n        load=_ep_load,\n    )\n\n    with mock.patch.object(\n        metadata, \"entry_points\", return_value=_create_mock_entry_points([ep])\n    ) as m:\n      lx.providers.load_plugins_once()\n      lx.providers.load_plugins_once()  # should be a no-op\n      self.assertEqual(m.call_count, 1, \"Discovery should happen only once\")\n\n  
def test_non_subclass_entry_point_does_not_crash(self):\n    \"\"\"Test that non-BaseLanguageModel classes don't crash the system.\"\"\"\n\n    class NotAProvider:  # pylint: disable=too-few-public-methods\n      \"\"\"Dummy class to test non-provider handling.\"\"\"\n\n    bad_ep = builtin_types.SimpleNamespace(\n        name=\"bad\",\n        group=\"langextract.providers\",\n        value=\"bad:NotAProvider\",\n        load=lambda: NotAProvider,\n    )\n\n    with mock.patch.object(\n        metadata,\n        \"entry_points\",\n        return_value=_create_mock_entry_points([bad_ep]),\n    ):\n      lx.providers.load_plugins_once()\n      # The system should remain functional even if a bad provider is loaded\n      # Trying to use it would fail, but discovery shouldn't crash\n      providers = lx.providers.registry.list_providers()\n      self.assertIsInstance(\n          providers,\n          list,\n          \"Registry should remain functional with bad provider\",\n      )\n      with self.assertRaisesRegex(\n          lx.exceptions.InferenceConfigError, \"No provider registered\"\n      ):\n        lx.providers.registry.resolve(\"bad\")\n\n  def test_plugin_priority_override_core_provider(self):\n    \"\"\"Plugin with higher priority should override core provider on conflicts.\"\"\"\n\n    lx.providers.registry.clear()\n    lx.providers._plugins_loaded = False\n\n    def _ep_load():\n      @lx.providers.registry.register(r\"^gemini\", priority=50)\n      class OverrideGemini(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n        def infer(self, batch_prompts, **kwargs):\n          return [[types.ScoredOutput(score=1.0, output=\"override\")]]\n\n      return OverrideGemini\n\n    ep = builtin_types.SimpleNamespace(\n        name=\"override_gemini\",\n        group=\"langextract.providers\",\n        value=\"pkg:OverrideGemini\",\n        load=_ep_load,\n    )\n\n    with mock.patch.object(\n        metadata, \"entry_points\", 
return_value=_create_mock_entry_points([ep])\n    ):\n      lx.providers.load_plugins_once()\n\n    # Core gemini registers with priority 10 in providers.gemini\n    # Our plugin registered with priority 50; it should win.\n    resolved = lx.providers.registry.resolve(\"gemini-2.5-flash\")\n    self.assertEqual(resolved.__name__, \"OverrideGemini\")\n\n  def test_resolve_provider_for_plugin(self):\n    \"\"\"resolve_provider finds plugin by exact or case-insensitive partial name.\"\"\"\n\n    lx.providers.registry.clear()\n    lx.providers._plugins_loaded = False\n\n    def _ep_load():\n      @lx.providers.registry.register(r\"^plugin-resolve\")\n      class ResolveMePlease(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n        def infer(self, batch_prompts, **kwargs):\n          return [[types.ScoredOutput(score=1.0, output=\"ok\")]]\n\n      return ResolveMePlease\n\n    ep = builtin_types.SimpleNamespace(\n        name=\"resolver_plugin\",\n        group=\"langextract.providers\",\n        value=\"pkg:ResolveMePlease\",\n        load=_ep_load,\n    )\n\n    with mock.patch.object(\n        metadata, \"entry_points\", return_value=_create_mock_entry_points([ep])\n    ):\n      lx.providers.load_plugins_once()\n\n    cls_by_exact = lx.providers.registry.resolve_provider(\"ResolveMePlease\")\n    self.assertEqual(cls_by_exact.__name__, \"ResolveMePlease\")\n\n    cls_by_partial = lx.providers.registry.resolve_provider(\"resolveme\")\n    self.assertEqual(cls_by_partial.__name__, \"ResolveMePlease\")\n\n  def test_plugin_with_custom_schema(self):\n    \"\"\"Test that a plugin can provide its own schema implementation.\"\"\"\n\n    class TestPluginSchema(lx.schema.BaseSchema):\n      \"\"\"Test schema implementation.\"\"\"\n\n      def __init__(self, config):\n        self._config = config\n\n      @classmethod\n      def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n        return cls({\"generated\": True, 
\"count\": len(examples_data)})\n\n      def to_provider_config(self):\n        return {\"custom_schema\": self._config}\n\n      @property\n      def requires_raw_output(self):\n        return True\n\n    def _ep_load():\n      @lx.providers.registry.register(r\"^custom-schema-test\")\n      class SchemaTestProvider(base_model.BaseLanguageModel):\n\n        def __init__(self, model_id=None, **kwargs):\n          super().__init__()\n          self.model_id = model_id\n          self.schema_config = kwargs.get(\"custom_schema\")\n\n        @classmethod\n        def get_schema_class(cls):\n          return TestPluginSchema\n\n        def infer(self, batch_prompts, **kwargs):\n          output = (\n              f\"Schema={self.schema_config}\"\n              if self.schema_config\n              else \"No schema\"\n          )\n          return [[types.ScoredOutput(score=1.0, output=output)]]\n\n      return SchemaTestProvider\n\n    ep = builtin_types.SimpleNamespace(\n        name=\"schema_test\",\n        group=\"langextract.providers\",\n        value=\"test:SchemaTestProvider\",\n        load=_ep_load,\n    )\n\n    with mock.patch.object(\n        metadata, \"entry_points\", return_value=_create_mock_entry_points([ep])\n    ):\n      lx.providers.load_plugins_once()\n\n    provider_cls = lx.providers.registry.resolve(\"custom-schema-test-v1\")\n    self.assertEqual(\n        provider_cls.get_schema_class().__name__,\n        \"TestPluginSchema\",\n        \"Plugin should provide custom schema class\",\n    )\n\n    examples = [\n        lx.data.ExampleData(\n            text=\"Test\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"test\",\n                    extraction_text=\"test text\",\n                )\n            ],\n        )\n    ]\n\n    config = lx.factory.ModelConfig(model_id=\"custom-schema-test-v1\")\n    model = lx.factory._create_model_with_schema(\n        config=config,\n        
examples=examples,\n        use_schema_constraints=True,\n        fence_output=None,\n    )\n\n    self.assertIsNotNone(\n        model.schema_config,\n        \"Model should have schema config applied\",\n    )\n    self.assertTrue(\n        model.schema_config[\"generated\"],\n        \"Schema should be generated from examples\",\n    )\n    self.assertFalse(\n        model.requires_fence_output,\n        \"Schema outputs raw JSON, no fences needed\",\n    )\n\n\nclass PluginE2ETest(absltest.TestCase):\n  \"\"\"End-to-end test with actual pip installation.\n\n  This test is expensive and only runs when explicitly requested\n  via tox -e plugin-e2e or in CI when provider files change.\n  \"\"\"\n\n  def test_plugin_with_schema_e2e(self):\n    \"\"\"Test that a plugin with custom schema works end-to-end with extract().\"\"\"\n\n    class TestPluginSchema(lx.schema.BaseSchema):\n      \"\"\"Test schema implementation.\"\"\"\n\n      def __init__(self, config):\n        self._config = config\n\n      @classmethod\n      def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n        return cls({\"generated\": True, \"count\": len(examples_data)})\n\n      def to_provider_config(self):\n        return {\"custom_schema\": self._config}\n\n      @property\n      def requires_raw_output(self):\n        return True\n\n    def _ep_load():\n      @lx.providers.registry.register(r\"^e2e-schema-test\")\n      class SchemaE2EProvider(base_model.BaseLanguageModel):\n\n        def __init__(self, model_id=None, **kwargs):\n          super().__init__()\n          self.model_id = model_id\n          self.schema_config = kwargs.get(\"custom_schema\")\n\n        @classmethod\n        def get_schema_class(cls):\n          return TestPluginSchema\n\n        def infer(self, batch_prompts, **kwargs):\n          # Return a mock extraction that includes schema info\n          if self.schema_config:\n            output = (\n                '{\"extractions\": [{\"entity\": 
\"test\", '\n                '\"entity_attributes\": {\"schema\": \"applied\"}}]}'\n            )\n          else:\n            output = '{\"extractions\": []}'\n          return [[types.ScoredOutput(score=1.0, output=output)]]\n\n      return SchemaE2EProvider\n\n    ep = builtin_types.SimpleNamespace(\n        name=\"schema_e2e\",\n        group=\"langextract.providers\",\n        value=\"test:SchemaE2EProvider\",\n        load=_ep_load,\n    )\n\n    # Clear and set up registry\n    lx.providers.registry.clear()\n    lx.providers._plugins_loaded = False\n    self.addCleanup(lx.providers.registry.clear)\n    self.addCleanup(setattr, lx.providers, \"_plugins_loaded\", False)\n\n    with mock.patch.object(\n        metadata, \"entry_points\", return_value=_create_mock_entry_points([ep])\n    ):\n      lx.providers.load_plugins_once()\n\n      # Test with extract() using schema constraints\n      examples = [\n          lx.data.ExampleData(\n              text=\"Find entities\",\n              extractions=[\n                  lx.data.Extraction(\n                      extraction_class=\"entity\",\n                      extraction_text=\"example\",\n                      attributes={\"type\": \"test\"},\n                  )\n              ],\n          )\n      ]\n\n      result = lx.extract(\n          text_or_documents=\"Test text for extraction\",\n          prompt_description=\"Extract entities\",\n          examples=examples,\n          model_id=\"e2e-schema-test-v1\",\n          use_schema_constraints=True,\n          fence_output=False,  # Schema supports strict mode\n      )\n\n      # Verify we got results\n      self.assertIsInstance(result, lx.data.AnnotatedDocument)\n      self.assertIsNotNone(result.extractions)\n      self.assertGreater(len(result.extractions), 0)\n\n      # Verify the schema was applied by checking the extraction\n      extraction = result.extractions[0]\n      self.assertEqual(extraction.extraction_class, \"entity\")\n      
self.assertIn(\"schema\", extraction.attributes)\n      self.assertEqual(extraction.attributes[\"schema\"], \"applied\")\n\n  @pytest.mark.requires_pip\n  @pytest.mark.integration\n  def test_pip_install_discovery_and_cleanup(self):\n    \"\"\"Test complete plugin lifecycle: install, discovery, usage, uninstall.\n\n    This test:\n    1. Creates a Python package with a provider plugin\n    2. Installs it via pip\n    3. Verifies the plugin is discovered and usable\n    4. Uninstalls and verifies cleanup\n    \"\"\"\n\n    # Skip in Bazel environment where pip operations don't work\n    if os.environ.get(\"TEST_TMPDIR\") or os.environ.get(\n        \"BUILD_WORKING_DIRECTORY\"\n    ):\n      self.skipTest(\"pip install tests don't work in Bazel sandbox\")\n\n    # Also skip if pip is not available\n    try:\n      subprocess.run(\n          [sys.executable, \"-m\", \"pip\", \"--version\"],\n          capture_output=True,\n          check=True,\n      )\n    except (subprocess.CalledProcessError, FileNotFoundError):\n      self.skipTest(\"pip not available in test environment\")\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n      pkg_name = f\"test_langextract_plugin_{uuid.uuid4().hex[:8]}\"\n      pkg_dir = Path(tmpdir) / pkg_name\n      pkg_dir.mkdir()\n\n      (pkg_dir / pkg_name).mkdir()\n      (pkg_dir / pkg_name / \"__init__.py\").write_text(\"\")\n\n      (pkg_dir / pkg_name / \"provider.py\").write_text(textwrap.dedent(\"\"\"\n        import langextract as lx\n        from langextract.core import base_model\n        from langextract.core import types\n\n        USED_BY_EXTRACT = False\n\n        class TestPipSchema(lx.schema.BaseSchema):\n            '''Test schema for pip provider.'''\n\n            def __init__(self, config):\n                self._config = config\n\n            @classmethod\n            def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n                return cls({\"pip_schema\": True, \"examples\": 
len(examples_data)})\n\n            def to_provider_config(self):\n                return {\"schema_config\": self._config}\n\n            @property\n            def requires_raw_output(self):\n                return True\n\n        @lx.providers.registry.register(r'^test-pip-model', priority=50)\n        class TestPipProvider(base_model.BaseLanguageModel):\n            def __init__(self, model_id, **kwargs):\n                super().__init__()\n                self.model_id = model_id\n                self.schema_config = kwargs.get(\"schema_config\", {})\n\n            @classmethod\n            def get_schema_class(cls):\n                return TestPipSchema\n\n            def infer(self, batch_prompts, **kwargs):\n                global USED_BY_EXTRACT\n                USED_BY_EXTRACT = True\n                schema_info = \"with_schema\" if self.schema_config else \"no_schema\"\n                return [[types.ScoredOutput(score=1.0, output=f\"pip test response: {schema_info}\")]]\n      \"\"\"))\n\n      (pkg_dir / \"pyproject.toml\").write_text(textwrap.dedent(f\"\"\"\n        [build-system]\n        requires = [\"setuptools>=61.0\"]\n        build-backend = \"setuptools.build_meta\"\n\n        [project]\n        name = \"{pkg_name}\"\n        version = \"0.0.1\"\n        description = \"Test plugin for langextract\"\n\n        [project.entry-points.\"langextract.providers\"]\n        test_provider = \"{pkg_name}.provider:TestPipProvider\"\n      \"\"\"))\n\n      pip_env = {\n          **os.environ,\n          \"PIP_NO_INPUT\": \"1\",\n          \"PIP_DISABLE_PIP_VERSION_CHECK\": \"1\",\n      }\n      result = subprocess.run(\n          [\n              sys.executable,\n              \"-m\",\n              \"pip\",\n              \"install\",\n              \"-e\",\n              str(pkg_dir),\n              \"--no-deps\",\n              \"-q\",\n          ],\n          check=True,\n          capture_output=True,\n          text=True,\n          
env=pip_env,\n      )\n\n      self.assertEqual(result.returncode, 0, \"pip install failed\")\n      self.assertNotIn(\n          \"ERROR\",\n          result.stderr.upper(),\n          f\"pip install had errors: {result.stderr}\",\n      )\n\n      try:\n        test_script = Path(tmpdir) / \"test_plugin.py\"\n        test_script.write_text(textwrap.dedent(f\"\"\"\n          import langextract as lx\n          import sys\n\n          lx.providers.load_plugins_once()\n\n          # Test 1: Basic usage without schema\n          cfg = lx.factory.ModelConfig(model_id=\"test-pip-model-123\")\n          model = lx.factory.create_model(cfg)\n          result = model.infer([\"test prompt\"])\n          assert \"no_schema\" in result[0][0].output, f\"Got: {{result[0][0].output}}\"\n\n          # Test 2: With schema constraints\n          examples = [\n              lx.data.ExampleData(\n                  text=\"test\",\n                  extractions=[\n                      lx.data.Extraction(\n                          extraction_class=\"test\",\n                          extraction_text=\"test\",\n                      )\n                  ],\n              )\n          ]\n\n          cfg2 = lx.factory.ModelConfig(model_id=\"test-pip-model-456\")\n          model2 = lx.factory._create_model_with_schema(\n              config=cfg2,\n              examples=examples,\n              use_schema_constraints=True,\n              fence_output=None,\n          )\n          result2 = model2.infer([\"test prompt\"])\n          assert \"with_schema\" in result2[0][0].output, f\"Got: {{result2[0][0].output}}\"\n          assert model2.requires_fence_output == False, \"Schema outputs raw JSON, should not need fences\"\n\n          # Test 3: Verify schema class is available\n          provider_cls = lx.providers.registry.resolve(\"test-pip-model-xyz\")\n          assert provider_cls.__name__ == \"TestPipProvider\", \"Plugin should be resolvable\"\n          schema_cls = 
provider_cls.get_schema_class()\n          assert schema_cls.__name__ == \"TestPipSchema\", f\"Schema class should be TestPipSchema, got {{schema_cls.__name__}}\"\n\n          from {pkg_name}.provider import USED_BY_EXTRACT\n          assert USED_BY_EXTRACT, \"Provider infer() was not called\"\n\n          print(\"SUCCESS: Plugin test with schema passed\")\n        \"\"\"))\n\n        result = subprocess.run(\n            [sys.executable, str(test_script)],\n            capture_output=True,\n            text=True,\n            check=False,\n        )\n\n        self.assertIn(\n            \"SUCCESS\",\n            result.stdout,\n            f\"Test failed. stdout: {result.stdout}, stderr: {result.stderr}\",\n        )\n\n      finally:\n        subprocess.run(\n            [sys.executable, \"-m\", \"pip\", \"uninstall\", \"-y\", pkg_name],\n            check=False,\n            capture_output=True,\n            env=pip_env,\n        )\n\n        lx.providers.registry.clear()\n        lx.providers._plugins_loaded = False\n        lx.providers.load_plugins_once()\n\n        with self.assertRaisesRegex(\n            lx.exceptions.InferenceConfigError,\n            \"No provider registered for model_id='test-pip-model\",\n        ):\n          lx.providers.registry.resolve(\"test-pip-model-789\")\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/provider_schema_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for provider schema discovery and implementations.\"\"\"\n\nfrom unittest import mock\n\nfrom absl.testing import absltest\n\nfrom langextract import exceptions\nfrom langextract import factory\nfrom langextract import schema\nimport langextract as lx\nfrom langextract.core import data\nfrom langextract.providers import gemini\nfrom langextract.providers import ollama\nfrom langextract.providers import openai\nfrom langextract.providers import schemas\n\n\nclass ProviderSchemaDiscoveryTest(absltest.TestCase):\n  \"\"\"Tests for provider schema discovery via get_schema_class().\"\"\"\n\n  def test_gemini_returns_gemini_schema(self):\n    \"\"\"Test that GeminiLanguageModel returns GeminiSchema.\"\"\"\n    schema_class = gemini.GeminiLanguageModel.get_schema_class()\n    self.assertEqual(\n        schema_class,\n        schemas.gemini.GeminiSchema,\n        msg=\"GeminiLanguageModel should return GeminiSchema class\",\n    )\n\n  def test_ollama_returns_format_mode_schema(self):\n    \"\"\"Test that OllamaLanguageModel returns FormatModeSchema.\"\"\"\n    schema_class = ollama.OllamaLanguageModel.get_schema_class()\n    self.assertEqual(\n        schema_class,\n        schema.FormatModeSchema,\n        msg=\"OllamaLanguageModel should return FormatModeSchema class\",\n    )\n\n  def test_openai_returns_none(self):\n    \"\"\"Test that OpenAILanguageModel 
returns None (no schema support yet).\"\"\"\n    # OpenAI imports dependencies in __init__, not at module level\n    schema_class = openai.OpenAILanguageModel.get_schema_class()\n    self.assertIsNone(\n        schema_class,\n        msg=\"OpenAILanguageModel should return None (no schema support)\",\n    )\n\n\nclass FormatModeSchemaTest(absltest.TestCase):\n  \"\"\"Tests for FormatModeSchema implementation.\"\"\"\n\n  def test_from_examples_ignores_examples(self):\n    \"\"\"Test that FormatModeSchema ignores examples and returns JSON mode.\"\"\"\n    examples_data = [\n        data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test_class\",\n                    extraction_text=\"test extraction\",\n                    attributes={\"key\": \"value\"},\n                )\n            ],\n        )\n    ]\n\n    test_schema = schema.FormatModeSchema.from_examples(examples_data)\n    self.assertEqual(\n        test_schema._format,\n        \"json\",\n        msg=\"FormatModeSchema should default to JSON format\",\n    )\n\n  def test_to_provider_config_returns_format(self):\n    \"\"\"Test that to_provider_config returns format parameter.\"\"\"\n    examples_data = []\n    test_schema = schema.FormatModeSchema.from_examples(examples_data)\n\n    provider_config = test_schema.to_provider_config()\n\n    self.assertEqual(\n        provider_config,\n        {\"format\": \"json\"},\n        msg=\"Provider config should contain format: json\",\n    )\n\n  def test_requires_raw_output_returns_true(self):\n    \"\"\"Test that FormatModeSchema requires raw output for JSON.\"\"\"\n    examples_data = []\n    test_schema = schema.FormatModeSchema.from_examples(examples_data)\n\n    self.assertTrue(\n        test_schema.requires_raw_output,\n        msg=\"FormatModeSchema with JSON should require raw output\",\n    )\n\n  def test_different_examples_same_output(self):\n    
\"\"\"Test that different examples produce the same schema for Ollama.\"\"\"\n    examples1 = [\n        data.ExampleData(\n            text=\"Text 1\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"class1\", extraction_text=\"text1\"\n                )\n            ],\n        )\n    ]\n\n    examples2 = [\n        data.ExampleData(\n            text=\"Text 2\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"class2\",\n                    extraction_text=\"text2\",\n                    attributes={\"attr\": \"value\"},\n                )\n            ],\n        )\n    ]\n\n    schema1 = schema.FormatModeSchema.from_examples(examples1)\n    schema2 = schema.FormatModeSchema.from_examples(examples2)\n\n    # Examples are ignored by FormatModeSchema\n    self.assertEqual(\n        schema1.to_provider_config(),\n        schema2.to_provider_config(),\n        msg=\"Different examples should produce same config for Ollama\",\n    )\n\n\nclass OllamaFormatParameterTest(absltest.TestCase):\n  \"\"\"Tests for Ollama format parameter handling.\"\"\"\n\n  def test_ollama_json_format_in_request_payload(self):\n    \"\"\"Test that JSON format is passed to Ollama API by default.\"\"\"\n    with mock.patch(\"requests.post\", autospec=True) as mock_post:\n      mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n      mock_response.status_code = 200\n      mock_response.json.return_value = {\"response\": '{\"test\": \"value\"}'}\n      mock_post.return_value = mock_response\n\n      model = ollama.OllamaLanguageModel(\n          model_id=\"test-model\",\n          format_type=data.FormatType.JSON,\n      )\n\n      list(model.infer([\"Test prompt\"]))\n\n      mock_post.assert_called_once()\n      call_kwargs = mock_post.call_args[1]\n      payload = call_kwargs[\"json\"]\n\n      self.assertEqual(payload[\"format\"], \"json\", msg=\"Format should be json\")\n    
  self.assertEqual(\n          payload[\"model\"], \"test-model\", msg=\"Model ID should match\"\n      )\n      self.assertEqual(\n          payload[\"prompt\"], \"Test prompt\", msg=\"Prompt should match\"\n      )\n      self.assertFalse(payload[\"stream\"], msg=\"Stream should be False\")\n\n  def test_ollama_default_format_is_json(self):\n    \"\"\"Test that JSON is the default format when not specified.\"\"\"\n    with mock.patch(\"requests.post\", autospec=True) as mock_post:\n      mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n      mock_response.status_code = 200\n      mock_response.json.return_value = {\"response\": '{\"test\": \"value\"}'}\n      mock_post.return_value = mock_response\n\n      model = ollama.OllamaLanguageModel(model_id=\"test-model\")\n\n      list(model.infer([\"Test prompt\"]))\n\n      mock_post.assert_called_once()\n      call_kwargs = mock_post.call_args[1]\n      payload = call_kwargs[\"json\"]\n\n      self.assertEqual(\n          payload[\"format\"], \"json\", msg=\"Default format should be json\"\n      )\n\n  def test_extract_with_ollama_passes_json_format(self):\n    \"\"\"Test that lx.extract() correctly passes JSON format to Ollama API.\"\"\"\n    with mock.patch(\"requests.post\", autospec=True) as mock_post:\n      mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n      mock_response.status_code = 200\n      mock_response.json.return_value = {\n          \"response\": (\n              '{\"extractions\": [{\"extraction_class\": \"test\", \"extraction_text\":'\n              ' \"example\"}]}'\n          )\n      }\n      mock_post.return_value = mock_response\n\n      # Mock the registry to return OllamaLanguageModel\n      with mock.patch(\"langextract.providers.registry.resolve\") as mock_resolve:\n        mock_resolve.return_value = ollama.OllamaLanguageModel\n\n        examples = [\n            data.ExampleData(\n                text=\"Sample text\",\n                extractions=[\n            
        data.Extraction(\n                        extraction_class=\"test\",\n                        extraction_text=\"sample\",\n                    )\n                ],\n            )\n        ]\n\n        result = lx.extract(\n            text_or_documents=\"Test document\",\n            prompt_description=\"Extract test information\",\n            examples=examples,\n            model_id=\"gemma2:2b\",\n            model_url=\"http://localhost:11434\",\n            format_type=data.FormatType.JSON,\n            use_schema_constraints=True,\n        )\n\n        mock_post.assert_called()\n\n        last_call = mock_post.call_args_list[-1]\n        payload = last_call[1][\"json\"]\n\n        self.assertEqual(\n            payload[\"format\"],\n            \"json\",\n            msg=\"Format should be json in extract() call\",\n        )\n        self.assertEqual(\n            payload[\"model\"], \"gemma2:2b\", msg=\"Model ID should match\"\n        )\n\n        self.assertIsNotNone(result)\n        self.assertIsInstance(result, data.AnnotatedDocument)\n\n\nclass OllamaYAMLOverrideTest(absltest.TestCase):\n  \"\"\"Tests for Ollama YAML format override behavior.\"\"\"\n\n  def test_ollama_yaml_format_in_request_payload(self):\n    \"\"\"Test that YAML format override appears in Ollama request payload.\"\"\"\n    with mock.patch(\"requests.post\", autospec=True) as mock_post:\n      mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n      mock_response.status_code = 200\n      mock_response.json.return_value = {\"response\": '{\"extractions\": []}'}\n      mock_post.return_value = mock_response\n\n      model = ollama.OllamaLanguageModel(model_id=\"gemma2:2b\", format=\"yaml\")\n\n      list(model.infer([\"Test prompt\"]))\n\n      mock_post.assert_called_once()\n      call_kwargs = mock_post.call_args[1]\n      self.assertIn(\n          \"json\", call_kwargs, msg=\"Request should use json parameter\"\n      )\n      payload = call_kwargs[\"json\"]\n     
 self.assertIn(\"format\", payload, msg=\"Payload should contain format key\")\n      self.assertEqual(payload[\"format\"], \"yaml\", msg=\"Format should be yaml\")\n\n  def test_yaml_override_sets_fence_output_true(self):\n    \"\"\"Test that overriding to YAML format sets fence_output to True.\"\"\"\n\n    examples_data = [\n        data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test_class\",\n                    extraction_text=\"test extraction\",\n                )\n            ],\n        )\n    ]\n\n    with mock.patch(\"requests.post\", autospec=True) as mock_post:\n      mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n      mock_response.status_code = 200\n      mock_response.json.return_value = {\"response\": '{\"extractions\": []}'}\n      mock_post.return_value = mock_response\n\n      with mock.patch(\"langextract.providers.registry.resolve\") as mock_resolve:\n        mock_resolve.return_value = ollama.OllamaLanguageModel\n\n        config = factory.ModelConfig(\n            model_id=\"gemma2:2b\",\n            provider_kwargs={\"format\": \"yaml\"},\n        )\n\n        model = factory.create_model(\n            config=config,\n            examples=examples_data,\n            use_schema_constraints=True,\n            fence_output=None,  # Let it be computed\n        )\n\n        self.assertTrue(\n            model.requires_fence_output, msg=\"YAML format should require fences\"\n        )\n\n  def test_json_format_keeps_fence_output_false(self):\n    \"\"\"Test that JSON format keeps fence_output False.\"\"\"\n\n    examples_data = [\n        data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test_class\",\n                    extraction_text=\"test extraction\",\n                )\n            ],\n        )\n    ]\n\n    with 
mock.patch(\"requests.post\", autospec=True) as mock_post:\n      mock_response = mock.Mock(spec=[\"status_code\", \"json\"])\n      mock_response.status_code = 200\n      mock_response.json.return_value = {\"response\": '{\"extractions\": []}'}\n      mock_post.return_value = mock_response\n\n      with mock.patch(\"langextract.providers.registry.resolve\") as mock_resolve:\n        mock_resolve.return_value = ollama.OllamaLanguageModel\n\n        config = factory.ModelConfig(\n            model_id=\"gemma2:2b\",\n            provider_kwargs={\"format\": \"json\"},\n        )\n\n        model = factory.create_model(\n            config=config,\n            examples=examples_data,\n            use_schema_constraints=True,\n            fence_output=None,  # Let it be computed\n        )\n\n        self.assertFalse(\n            model.requires_fence_output,\n            msg=\"JSON format should not require fences\",\n        )\n\n\nclass GeminiSchemaProviderIntegrationTest(absltest.TestCase):\n  \"\"\"Tests for GeminiSchema provider integration.\"\"\"\n\n  def test_gemini_schema_to_provider_config(self):\n    \"\"\"Test that GeminiSchema.to_provider_config includes response_schema.\"\"\"\n    examples_data = [\n        data.ExampleData(\n            text=\"Patient has diabetes\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"condition\",\n                    extraction_text=\"diabetes\",\n                    attributes={\"severity\": \"moderate\"},\n                )\n            ],\n        )\n    ]\n\n    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n    provider_config = gemini_schema.to_provider_config()\n\n    self.assertIn(\n        \"response_schema\",\n        provider_config,\n        msg=\"GeminiSchema config should contain response_schema\",\n    )\n    self.assertIsInstance(\n        provider_config[\"response_schema\"],\n        dict,\n        msg=\"response_schema should 
be a dictionary\",\n    )\n    self.assertIn(\n        \"properties\",\n        provider_config[\"response_schema\"],\n        msg=\"response_schema should contain properties field\",\n    )\n\n    self.assertIn(\n        \"response_mime_type\",\n        provider_config,\n        msg=\"GeminiSchema config should contain response_mime_type\",\n    )\n    self.assertEqual(\n        provider_config[\"response_mime_type\"],\n        \"application/json\",\n        msg=\"response_mime_type should be application/json\",\n    )\n\n  def test_gemini_requires_raw_output(self):\n    \"\"\"Test that GeminiSchema requires raw output.\"\"\"\n    examples_data = []\n    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n    self.assertTrue(\n        gemini_schema.requires_raw_output,\n        msg=\"GeminiSchema should require raw output\",\n    )\n\n  def test_gemini_rejects_yaml_with_schema(self):\n    \"\"\"Test that Gemini raises error when YAML format is used with schema.\"\"\"\n\n    examples_data = [\n        data.ExampleData(\n            text=\"Test\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test\",\n                    extraction_text=\"test text\",\n                )\n            ],\n        )\n    ]\n    test_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n\n    with mock.patch(\"google.genai.Client\", autospec=True):\n      model = gemini.GeminiLanguageModel(\n          model_id=\"gemini-2.5-flash\",\n          api_key=\"test_key\",\n          format_type=data.FormatType.YAML,\n      )\n      model.apply_schema(test_schema)\n\n      prompt = \"Test prompt\"\n      config = {\"temperature\": 0.5}\n      with self.assertRaises(exceptions.InferenceRuntimeError) as cm:\n        _ = model._process_single_prompt(prompt, config)\n\n      self.assertIn(\n          \"only supports JSON format\",\n          str(cm.exception),\n          msg=\"Error should mention JSON-only 
constraint\",\n      )\n\n  def test_gemini_forwards_schema_to_genai_client(self):\n    \"\"\"Test that GeminiLanguageModel forwards schema config to genai client.\"\"\"\n\n    examples_data = [\n        data.ExampleData(\n            text=\"Test\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test\",\n                    extraction_text=\"test text\",\n                )\n            ],\n        )\n    ]\n    test_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n\n    with mock.patch(\"google.genai.Client\", autospec=True) as mock_client:\n      mock_model_instance = mock.Mock(spec=[\"return_value\"])\n      mock_client.return_value.models.generate_content = mock_model_instance\n      mock_model_instance.return_value.text = '{\"extractions\": []}'\n\n      model = gemini.GeminiLanguageModel(\n          model_id=\"gemini-2.5-flash\",\n          api_key=\"test_key\",\n          response_schema=test_schema.schema_dict,\n          response_mime_type=\"application/json\",\n      )\n\n      prompt = \"Test prompt\"\n      config = {\"temperature\": 0.5}\n      _ = model._process_single_prompt(prompt, config)\n\n      mock_model_instance.assert_called_once()\n      call_kwargs = mock_model_instance.call_args[1]\n      self.assertIn(\n          \"config\",\n          call_kwargs,\n          msg=\"genai.generate_content should receive config parameter\",\n      )\n      self.assertIn(\n          \"response_schema\",\n          call_kwargs[\"config\"],\n          msg=\"Config should contain response_schema from GeminiSchema\",\n      )\n      self.assertIn(\n          \"response_mime_type\",\n          call_kwargs[\"config\"],\n          msg=\"Config should contain response_mime_type\",\n      )\n      self.assertEqual(\n          call_kwargs[\"config\"][\"response_mime_type\"],\n          \"application/json\",\n          msg=\"response_mime_type should be application/json\",\n      )\n\n  def 
test_gemini_doesnt_forward_non_api_kwargs(self):\n    \"\"\"Test that GeminiLanguageModel doesn't forward non-API kwargs to genai.\"\"\"\n\n    with mock.patch(\"google.genai.Client\", autospec=True) as mock_client:\n      mock_model_instance = mock.Mock(spec=[\"return_value\"])\n      mock_client.return_value.models.generate_content = mock_model_instance\n      mock_model_instance.return_value.text = '{\"extractions\": []}'\n\n      model = gemini.GeminiLanguageModel(\n          model_id=\"gemini-2.5-flash\",\n          api_key=\"test_key\",\n          max_workers=5,\n          response_schema={\"test\": \"schema\"},  # API parameter\n      )\n\n      prompt = \"Test prompt\"\n      config = {\"temperature\": 0.5}\n      _ = model._process_single_prompt(prompt, config)\n\n      mock_model_instance.assert_called_once()\n      call_kwargs = mock_model_instance.call_args[1]\n\n      self.assertNotIn(\n          \"max_workers\",\n          call_kwargs[\"config\"],\n          msg=\"max_workers should not be forwarded to genai API config\",\n      )\n\n      self.assertIn(\n          \"response_schema\",\n          call_kwargs[\"config\"],\n          msg=\"response_schema should be forwarded to genai API config\",\n      )\n\n\nclass SchemaShimTest(absltest.TestCase):\n  \"\"\"Tests for backward compatibility shims in schema module.\"\"\"\n\n  def test_constraint_types_import(self):\n    \"\"\"Test that Constraint and ConstraintType can be imported.\"\"\"\n    from langextract import schema as lx_schema  # pylint: disable=reimported,import-outside-toplevel\n\n    constraint = lx_schema.Constraint()\n    self.assertEqual(\n        constraint.constraint_type,\n        lx_schema.ConstraintType.NONE,\n        msg=\"Default Constraint should have type NONE\",\n    )\n\n    self.assertEqual(\n        lx_schema.ConstraintType.NONE.value,\n        \"none\",\n        msg=\"ConstraintType.NONE should have value 'none'\",\n    )\n\n  def test_provider_schema_imports(self):\n    
\"\"\"Test that provider schemas can be imported from schema module.\"\"\"\n    from langextract import schema as lx_schema  # pylint: disable=reimported,import-outside-toplevel\n\n    # Backward compatibility: re-exported from providers.schemas.gemini\n    self.assertTrue(\n        hasattr(lx_schema, \"GeminiSchema\"),\n        msg=(\n            \"GeminiSchema should be importable from schema module for backward\"\n            \" compatibility\"\n        ),\n    )\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/registry_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for the provider registry module.\n\nNote: This file tests the deprecated registry module which is now an alias\nfor router. The no-name-in-module warning for providers.registry is expected.\nTest helper classes also intentionally have few public methods.\n\"\"\"\n# pylint: disable=no-name-in-module\n\nimport re\n\nfrom absl.testing import absltest\n\nfrom langextract import exceptions\nfrom langextract.core import base_model\nfrom langextract.core import types\nfrom langextract.providers import router\n\n\nclass FakeProvider(base_model.BaseLanguageModel):\n  \"\"\"Fake provider for testing.\"\"\"\n\n  def infer(self, batch_prompts, **kwargs):\n    return [[types.ScoredOutput(score=1.0, output=\"test\")]]\n\n  def infer_batch(self, prompts, batch_size=32):\n    return self.infer(prompts)\n\n\nclass AnotherFakeProvider(base_model.BaseLanguageModel):\n  \"\"\"Another fake provider for testing.\"\"\"\n\n  def infer(self, batch_prompts, **kwargs):\n    return [[types.ScoredOutput(score=1.0, output=\"another\")]]\n\n  def infer_batch(self, prompts, batch_size=32):\n    return self.infer(prompts)\n\n\nclass RegistryTest(absltest.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    router.clear()\n\n  def tearDown(self):\n    super().tearDown()\n    router.clear()\n\n  def test_register_decorator(self):\n    \"\"\"Test registering a provider using the 
decorator.\"\"\"\n\n    @router.register(r\"^test-model\")\n    class TestProvider(FakeProvider):\n      pass\n\n    resolved = router.resolve(\"test-model-v1\")\n    self.assertEqual(resolved, TestProvider)\n\n  def test_register_lazy(self):\n    \"\"\"Test lazy registration with string target.\"\"\"\n    # Use direct registration for test provider to avoid module path issues\n    router.register(r\"^fake-model\")(FakeProvider)\n\n    resolved = router.resolve(\"fake-model-v2\")\n    self.assertEqual(resolved, FakeProvider)\n\n  def test_multiple_patterns(self):\n    \"\"\"Test registering multiple patterns for one provider.\"\"\"\n    # Use direct registration to avoid module path issues in Bazel\n    router.register(r\"^gemini\", r\"^palm\")(FakeProvider)\n\n    self.assertEqual(router.resolve(\"gemini-pro\"), FakeProvider)\n    self.assertEqual(router.resolve(\"palm-2\"), FakeProvider)\n\n  def test_priority_resolution(self):\n    \"\"\"Test that higher priority wins on conflicts.\"\"\"\n    # Use direct registration to avoid module path issues in Bazel\n    router.register(r\"^model\", priority=0)(FakeProvider)\n    router.register(r\"^model\", priority=10)(AnotherFakeProvider)\n\n    resolved = router.resolve(\"model-v1\")\n    self.assertEqual(resolved, AnotherFakeProvider)\n\n  def test_no_provider_registered(self):\n    \"\"\"Test error when no provider matches.\"\"\"\n    with self.assertRaisesRegex(\n        exceptions.InferenceConfigError,\n        \"No provider registered for model_id='unknown-model'\",\n    ):\n      router.resolve(\"unknown-model\")\n\n  def test_caching(self):\n    \"\"\"Test that resolve results are cached.\"\"\"\n    # Use direct registration for test provider to avoid module path issues\n    router.register(r\"^cached\")(FakeProvider)\n\n    # First call\n    result1 = router.resolve(\"cached-model\")\n    # Second call should return cached result\n    result2 = router.resolve(\"cached-model\")\n\n    self.assertIs(result1, 
result2)\n\n  def test_clear_registry(self):\n    \"\"\"Test clearing the router.\"\"\"\n    # Use direct registration for test provider to avoid module path issues\n    router.register(r\"^temp\")(FakeProvider)\n\n    # Should resolve before clear\n    resolved = router.resolve(\"temp-model\")\n    self.assertEqual(resolved, FakeProvider)\n\n    # Clear registry\n    router.clear()\n\n    # Should fail after clear\n    with self.assertRaises(exceptions.InferenceConfigError):\n      router.resolve(\"temp-model\")\n\n  def test_list_entries(self):\n    \"\"\"Test listing registered entries.\"\"\"\n    router.register_lazy(r\"^test1\", target=\"fake:Target1\", priority=5)\n    router.register_lazy(\n        r\"^test2\", r\"^test3\", target=\"fake:Target2\", priority=10\n    )\n\n    entries = router.list_entries()\n    self.assertEqual(len(entries), 2)\n\n    patterns1, priority1 = entries[0]\n    self.assertEqual(patterns1, [\"^test1\"])\n    self.assertEqual(priority1, 5)\n\n    patterns2, priority2 = entries[1]\n    self.assertEqual(set(patterns2), {\"^test2\", \"^test3\"})\n    self.assertEqual(priority2, 10)\n\n  def test_lazy_loading_defers_import(self):\n    \"\"\"Test that lazy registration doesn't import until resolve.\"\"\"\n    # Register with a module that would fail if imported\n    router.register_lazy(r\"^lazy\", target=\"non.existent.module:Provider\")\n\n    # Registration should succeed without importing\n    entries = router.list_entries()\n    self.assertTrue(any(\"^lazy\" in patterns for patterns, _ in entries))\n\n    # Only on resolve should it try to import and fail\n    with self.assertRaises(ModuleNotFoundError):\n      router.resolve(\"lazy-model\")\n\n  def test_regex_pattern_objects(self):\n    \"\"\"Test using pre-compiled regex patterns.\"\"\"\n    pattern = re.compile(r\"^custom-\\d+\")\n\n    @router.register(pattern)\n    class CustomProvider(FakeProvider):\n      pass\n\n    self.assertEqual(router.resolve(\"custom-123\"), 
CustomProvider)\n\n    # Should not match without digits\n    with self.assertRaises(exceptions.InferenceConfigError):\n      router.resolve(\"custom-abc\")\n\n  def test_resolve_provider_by_name(self):\n    \"\"\"Test resolving provider by exact name.\"\"\"\n\n    @router.register(r\"^test-model\", r\"^TestProvider$\")\n    class TestProvider(FakeProvider):\n      pass\n\n    # Resolve by exact class name pattern\n    provider = router.resolve_provider(\"TestProvider\")\n    self.assertEqual(provider, TestProvider)\n\n    # Resolve by partial name match\n    provider = router.resolve_provider(\"test\")\n    self.assertEqual(provider, TestProvider)\n\n  def test_resolve_provider_not_found(self):\n    \"\"\"Test resolve_provider raises for unknown provider.\"\"\"\n    with self.assertRaises(exceptions.InferenceConfigError) as cm:\n      router.resolve_provider(\"UnknownProvider\")\n    self.assertIn(\"No provider found matching\", str(cm.exception))\n\n  def test_hf_style_model_id_patterns(self):\n    \"\"\"Test that Hugging Face style model ID patterns work.\n\n    This addresses issue #129 where HF-style model IDs like\n    'meta-llama/Llama-3.2-1B-Instruct' weren't being recognized.\n    \"\"\"\n\n    @router.register(\n        r\"^meta-llama/[Ll]lama\",\n        r\"^google/gemma\",\n        r\"^mistralai/[Mm]istral\",\n        r\"^microsoft/phi\",\n        r\"^Qwen/\",\n        r\"^TinyLlama/\",\n        priority=100,\n    )\n    class TestHFProvider(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n      def infer(self, batch_prompts, **kwargs):\n        return []\n\n    hf_model_ids = [\n        \"meta-llama/Llama-3.2-1B-Instruct\",\n        \"meta-llama/llama-2-7b\",\n        \"google/gemma-2b\",\n        \"mistralai/Mistral-7B-v0.1\",\n        \"microsoft/phi-3-mini\",\n        \"Qwen/Qwen2.5-7B\",\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n    ]\n\n    for model_id in hf_model_ids:\n      with 
self.subTest(model_id=model_id):\n        provider_class = router.resolve(model_id)\n        self.assertEqual(provider_class, TestHFProvider)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/resolver_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport textwrap\nfrom typing import Sequence\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract import chunking\nfrom langextract import resolver as resolver_lib\nfrom langextract.core import data\nfrom langextract.core import tokenizer\n\n\ndef assert_char_interval_match_source(\n    test_case: absltest.TestCase,\n    source_text: str,\n    extractions: Sequence[data.Extraction],\n):\n  \"\"\"Asserts that the char_interval of matched extractions matches the source text.\n\n  Args:\n    test_case: The TestCase instance.\n    source_text: The original source text.\n    extractions: A sequence of extractions to check.\n  \"\"\"\n  for extraction in extractions:\n    if extraction.alignment_status == data.AlignmentStatus.MATCH_EXACT:\n      assert (\n          extraction.char_interval is not None\n      ), \"char_interval should not be None for AlignmentStatus.MATCH_EXACT\"\n\n      char_int = extraction.char_interval\n      start = char_int.start_pos\n      end = char_int.end_pos\n      test_case.assertIsNotNone(start, \"start_pos should not be None\")\n      test_case.assertIsNotNone(end, \"end_pos should not be None\")\n      extracted = source_text[start:end]\n      test_case.assertEqual(\n          extracted.lower(),\n          extraction.extraction_text.lower(),\n          f\"Extraction '{extraction.extraction_text}' 
does not match extracted\"\n          f\" '{extracted}' using char_interval {char_int}\",\n      )\n\n\nclass ParserTest(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"json_invalid_input\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.JSON,\n              fence_output=True,\n              strict_fences=True,\n          ),\n          input_text=\"invalid input\",\n          expected_exception=resolver_lib.ResolverParsingError,\n          expected_regex=\".*fence markers.*\",\n      ),\n      dict(\n          testcase_name=\"json_missing_markers\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.JSON,\n              fence_output=True,\n              strict_fences=True,\n          ),\n          input_text='[{\"key\": \"value\"}]',\n          expected_exception=resolver_lib.ResolverParsingError,\n          expected_regex=\".*fence markers.*\",\n      ),\n      dict(\n          testcase_name=\"json_empty_string\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.JSON,\n              fence_output=True,\n          ),\n          input_text=\"\",\n          expected_exception=ValueError,\n          expected_regex=\".*must be a non-empty string.*\",\n      ),\n      dict(\n          testcase_name=\"json_partial_markers\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.JSON,\n              fence_output=True,\n              strict_fences=True,\n          ),\n          input_text='```json\\n{\"key\": \"value\"',\n          expected_exception=resolver_lib.ResolverParsingError,\n          expected_regex=\".*fence markers.*\",\n      ),\n      dict(\n          testcase_name=\"yaml_invalid_input\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.YAML,\n              fence_output=True,\n              strict_fences=True,\n          ),\n    
      input_text=\"invalid input\",\n          expected_exception=resolver_lib.ResolverParsingError,\n          expected_regex=\".*fence markers.*\",\n      ),\n      dict(\n          testcase_name=\"yaml_missing_markers\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.YAML,\n              fence_output=True,\n              strict_fences=True,\n          ),\n          input_text='[{\"key\": \"value\"}]',\n          expected_exception=resolver_lib.ResolverParsingError,\n          expected_regex=\".*fence markers.*\",\n      ),\n      dict(\n          testcase_name=\"yaml_empty_content\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.YAML,\n              fence_output=True,\n          ),\n          input_text=\"```yaml\\n```\",\n          expected_exception=resolver_lib.ResolverParsingError,\n          expected_regex=(\n              \".*Content must be a mapping with an\"\n              f\" '{data.EXTRACTIONS_KEY}' key.*\"\n          ),\n      ),\n  )\n  def test_parser_error_cases(\n      self, resolver, input_text, expected_exception, expected_regex\n  ):\n    with self.assertRaisesRegex(expected_exception, expected_regex):\n      resolver.string_to_extraction_data(input_text)\n\n\nclass ExtractOrderedEntitiesTest(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"valid_input\",\n          test_input=[\n              {\n                  \"medication\": \"Naprosyn\",\n                  \"medication_index\": 4,\n                  \"frequency\": \"as needed\",\n                  \"frequency_index\": 5,\n                  \"reason\": \"pain\",\n                  \"reason_index\": 8,\n              },\n              {\n                  \"medication\": \"prednisone\",\n                  \"medication_index\": 5,\n                  \"frequency\": \"daily\",\n                  \"frequency_index\": 1,\n              },\n          ],\n          
expected_output=[\n              data.Extraction(\n                  extraction_class=\"frequency\",\n                  extraction_text=\"daily\",\n                  extraction_index=1,\n                  group_index=1,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Naprosyn\",\n                  extraction_index=4,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"frequency\",\n                  extraction_text=\"as needed\",\n                  extraction_index=5,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"prednisone\",\n                  extraction_index=5,\n                  group_index=1,\n              ),\n              data.Extraction(\n                  extraction_class=\"reason\",\n                  extraction_text=\"pain\",\n                  extraction_index=8,\n                  group_index=0,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"empty_input\",\n          test_input=[],\n          expected_output=[],\n      ),\n      dict(\n          testcase_name=\"mixed_index_order\",\n          test_input=[\n              {\n                  \"medication\": \"Ibuprofen\",\n                  \"medication_index\": 2,\n                  \"dosage\": \"400mg\",\n                  \"dosage_index\": 1,\n              },\n              {\n                  \"medication\": \"Acetaminophen\",\n                  \"medication_index\": 1,\n                  \"duration\": \"7 days\",\n                  \"duration_index\": 2,\n              },\n          ],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"400mg\",\n                  extraction_index=1,\n 
                 group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Acetaminophen\",\n                  extraction_index=1,\n                  group_index=1,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Ibuprofen\",\n                  extraction_index=2,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"duration\",\n                  extraction_text=\"7 days\",\n                  extraction_index=2,\n                  group_index=1,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"missing_index_key\",\n          test_input=[{\n              \"medication\": \"Aspirin\",\n              \"dosage\": \"325mg\",\n              \"dosage_index\": 1,\n          }],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"325mg\",\n                  extraction_index=1,\n                  group_index=0,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"all_indices_missing\",\n          test_input=[\n              {\"medication\": \"Aspirin\", \"dosage\": \"325mg\"},\n              {\"medication\": \"Ibuprofen\", \"dosage\": \"400mg\"},\n          ],\n          expected_output=[],\n      ),\n      dict(\n          testcase_name=\"single_element_dictionaries\",\n          test_input=[\n              {\"medication\": \"Aspirin\", \"medication_index\": 1},\n              {\"medication\": \"Ibuprofen\", \"medication_index\": 2},\n          ],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Aspirin\",\n                  extraction_index=1,\n                  
group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Ibuprofen\",\n                  extraction_index=2,\n                  group_index=1,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"duplicate_indices_unchanged\",\n          test_input=[{\n              \"medication\": \"Aspirin\",\n              \"medication_index\": 1,\n              \"dosage\": \"325mg\",\n              \"dosage_index\": 1,\n              \"form\": \"tablet\",\n              \"form_index\": 1,\n          }],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Aspirin\",\n                  extraction_index=1,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"325mg\",\n                  extraction_index=1,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"form\",\n                  extraction_text=\"tablet\",\n                  extraction_index=1,\n                  group_index=0,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"negative_indices\",\n          test_input=[{\n              \"medication\": \"Aspirin\",\n              \"medication_index\": -1,\n              \"dosage\": \"325mg\",\n              \"dosage_index\": -2,\n          }],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"325mg\",\n                  extraction_index=-2,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Aspirin\",\n                  
extraction_index=-1,\n                  group_index=0,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"index_without_data_key_ignored\",\n          test_input=[{\n              \"medication_index\": 1,\n              \"dosage\": \"325mg\",\n              \"dosage_index\": 2,\n          }],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"325mg\",\n                  extraction_index=2,\n                  group_index=0,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"no_index_suffix\",\n          resolver=resolver_lib.Resolver(\n              extraction_index_suffix=None,\n              format_type=data.FormatType.JSON,\n          ),\n          test_input=[\n              {\"medication\": \"Aspirin\"},\n              {\"medication\": \"Ibuprofen\"},\n              {\"dosage\": \"325mg\"},\n              {\"dosage\": \"400mg\"},\n          ],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Aspirin\",\n                  extraction_index=1,\n                  group_index=0,\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Ibuprofen\",\n                  extraction_index=2,\n                  group_index=1,\n              ),\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"325mg\",\n                  extraction_index=3,\n                  group_index=2,\n              ),\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"400mg\",\n                  extraction_index=4,\n                  group_index=3,\n              ),\n          ],\n      ),\n      dict(\n          
testcase_name=\"attributes_suffix\",\n          resolver=resolver_lib.Resolver(\n              extraction_index_suffix=None,\n              format_type=data.FormatType.JSON,\n          ),\n          test_input=[\n              {\n                  \"patient\": \"Jane Doe\",\n                  \"patient_attributes\": {\n                      \"PERSON\": \"True\",\n                      \"IDENTIFIABLE\": \"True\",\n                  },\n              },\n              {\n                  \"medication\": \"Lisinopril\",\n                  \"medication_attributes\": {\n                      \"THERAPEUTIC\": \"True\",\n                      \"CLINICAL\": \"True\",\n                  },\n              },\n          ],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"patient\",\n                  extraction_text=\"Jane Doe\",\n                  extraction_index=1,\n                  group_index=0,\n                  attributes={\n                      \"PERSON\": \"True\",\n                      \"IDENTIFIABLE\": \"True\",\n                  },\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Lisinopril\",\n                  extraction_index=2,\n                  group_index=1,\n                  attributes={\n                      \"THERAPEUTIC\": \"True\",\n                      \"CLINICAL\": \"True\",\n                  },\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"indices_and_attributes\",\n          test_input=[\n              {\n                  \"patient\": \"John Doe\",\n                  \"patient_index\": 2,\n                  \"patient_attributes\": {\n                      \"IDENTIFIABLE\": \"True\",\n                  },\n                  \"condition\": \"hypertension\",\n                  \"condition_index\": 1,\n                  \"condition_attributes\": {\n                     
 \"CHRONIC_CONDITION\": \"True\",\n                      \"REQUIRES_MANAGEMENT\": \"True\",\n                  },\n              },\n              {\n                  \"medication\": \"Lisinopril\",\n                  \"medication_index\": 3,\n                  \"medication_attributes\": {\n                      \"ANTIHYPERTENSIVE_MEDICATION\": \"True\",\n                      \"DAILY_USE\": \"True\",\n                  },\n                  \"dosage\": \"10mg\",\n                  \"dosage_index\": 4,\n                  \"dosage_attributes\": {\n                      \"STANDARD_DAILY_DOSE\": \"True\",\n                  },\n              },\n          ],\n          expected_output=[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"hypertension\",\n                  extraction_index=1,\n                  group_index=0,\n                  attributes={\n                      \"CHRONIC_CONDITION\": \"True\",\n                      \"REQUIRES_MANAGEMENT\": \"True\",\n                  },\n              ),\n              data.Extraction(\n                  extraction_class=\"patient\",\n                  extraction_text=\"John Doe\",\n                  extraction_index=2,\n                  group_index=0,\n                  attributes={\n                      \"IDENTIFIABLE\": \"True\",\n                  },\n              ),\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Lisinopril\",\n                  extraction_index=3,\n                  group_index=1,\n                  attributes={\n                      \"ANTIHYPERTENSIVE_MEDICATION\": \"True\",\n                      \"DAILY_USE\": \"True\",\n                  },\n              ),\n              data.Extraction(\n                  extraction_class=\"dosage\",\n                  extraction_text=\"10mg\",\n                  extraction_index=4,\n                  
group_index=1,\n                  attributes={\n                      \"STANDARD_DAILY_DOSE\": \"True\",\n                  },\n              ),\n          ],\n      ),\n  )\n  def test_extract_ordered_extractions_success(\n      self,\n      test_input,\n      resolver=None,\n      expected_output=None,\n  ):\n    if resolver is None:\n      resolver = resolver_lib.Resolver(\n          extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX\n      )\n    actual_output = resolver.extract_ordered_extractions(test_input)\n    self.assertEqual(actual_output, expected_output)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"non_integer_indices\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.JSON,\n              extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n          ),\n          test_input=[{\n              \"medication\": \"Aspirin\",\n              \"medication_index\": \"first\",\n              \"dosage\": \"325mg\",\n              \"dosage_index\": \"second\",\n          }],\n          expected_exception=ValueError,\n          expected_regex=\".*must be an integer.*\",\n      ),\n      dict(\n          testcase_name=\"float_indices\",\n          resolver=resolver_lib.Resolver(\n              format_type=data.FormatType.JSON,\n              extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n          ),\n          test_input=[{\"medication\": \"Aspirin\", \"medication_index\": 1.0}],\n          expected_exception=ValueError,\n          expected_regex=\".*must be an integer.*\",\n      ),\n  )\n  def test_extract_ordered_extractions_exceptions(\n      self, resolver, test_input, expected_exception, expected_regex\n  ):\n    with self.assertRaisesRegex(expected_exception, expected_regex):\n      resolver.extract_ordered_extractions(test_input)\n\n\nclass AlignEntitiesTest(parameterized.TestCase):\n  _SOURCE_TEXT_TWO_MEDS = (\n      \"Patient is prescribed Naprosyn and 
prednisone for treatment.\"\n  )\n  _SOURCE_TEXT_THREE_CONDITIONS_AND_MEDS = (\n      \"Patient with arthritis, fever, and inflammation is prescribed\"\n      \" Naprosyn, prednisone, and ibuprofen.\"\n  )\n  _SOURCE_TEXT_MULTI_WORD_EXTRACTIONS = (\n      \"Pt was prescribed Naprosyn as needed for pain and prednisone for\"\n      \" one month.\"\n  )\n\n  def setUp(self):\n    super().setUp()\n    self.aligner = resolver_lib.WordAligner()\n    self.maxDiff = 10000\n\n  @parameterized.named_parameters(\n      (\n          \"basic_alignment\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_TWO_MEDS,\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=30),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=5, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=35, end_pos=45),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n             
 ],\n          ],\n      ),\n      (\n          \"shuffled_order_of_last_two_extractions\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\", extraction_text=\"arthritis\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\", extraction_text=\"fever\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"inflammation\",\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"ibuprofen\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_THREE_CONDITIONS_AND_MEDS,\n          # Indexes Aligned with Tokens\n          # --------------------------------------------------------------------\n          # Index    | 0        1      2         3      4      5     6\n          # Token    | Patient  with   arthritis ,     fever   ,     and\n          # --------------------------------------------------------------------\n          # Index    | 7              8        9\n          # Token    | inflammation  is       prescribed\n          # --------------------------------------------------------------------\n          # Index    | 10       11        12         13   14      15\n          # Token    | Naprosyn ,         prednisone ,    and     ibuprofen\n          # 
--------------------------------------------------------------------\n          # Index    | 16\n          # Token    | .\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"arthritis\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=2, end_index=3\n                      ),\n                      char_interval=data.CharInterval(start_pos=13, end_pos=22),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"fever\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=4, end_index=5\n                      ),\n                      char_interval=data.CharInterval(start_pos=24, end_pos=29),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"inflammation\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=7, end_index=8\n                      ),\n                      char_interval=data.CharInterval(start_pos=35, end_pos=47),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=10, end_index=11\n                      ),\n                      
char_interval=data.CharInterval(start_pos=62, end_pos=70),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"ibuprofen\",\n                      token_interval=None,\n                      char_interval=None,\n                      alignment_status=None,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=12, end_index=13\n                      ),\n                      char_interval=data.CharInterval(start_pos=72, end_pos=82),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"extraction_not_found\",\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\", extraction_text=\"aspirin\"\n              )\n          ]],\n          _SOURCE_TEXT_TWO_MEDS,\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"aspirin\",\n                  char_interval=None,\n              )\n          ]],\n      ),\n      (\n          \"multiple_word_extraction_partially_matched\",\n          [[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"high blood pressure\",\n              )\n          ]],\n          \"Patient is prescribed high glucose.\",\n          [[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"high blood pressure\",\n                  
token_interval=tokenizer.TokenInterval(\n                      start_index=3, end_index=4\n                  ),\n                  alignment_status=data.AlignmentStatus.MATCH_LESSER,\n                  char_interval=data.CharInterval(start_pos=22, end_pos=26),\n              )\n          ]],\n      ),\n      (\n          \"optimize_multiword_extractions_at_back\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn and prednisone\",\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_TWO_MEDS,\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=None,\n                      char_interval=None,\n                      alignment_status=None,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn and prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=45),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"optimize_multiword_extractions_at_front\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn and prednisone\",\n                  )\n              
],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_TWO_MEDS,\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn and prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=45),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      char_interval=None,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_en_dash_unicode_handling\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"word\", extraction_text=\"Separated\"\n                  )\n              ],\n              [data.Extraction(extraction_class=\"word\", extraction_text=\"by\")],\n              [\n                  data.Extraction(\n                      extraction_class=\"word\", extraction_text=\"en–dashes\"\n                  )\n              ],\n          ],\n          \"Separated–by–en–dashes.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"word\",\n                      extraction_text=\"Separated\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=0, end_index=1\n                      ),\n                      char_interval=data.CharInterval(start_pos=0, end_pos=9),\n                      
alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"word\",\n                      extraction_text=\"by\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=2, end_index=3\n                      ),\n                      char_interval=data.CharInterval(start_pos=10, end_pos=12),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"word\",\n                      extraction_text=\"en–dashes\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=4, end_index=7\n                      ),\n                      char_interval=data.CharInterval(start_pos=13, end_pos=22),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"empty_source_text\",\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n              )\n          ]],\n          \"\",\n          ValueError,\n      ),\n      (\n          \"special_characters_in_extractions\",\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\", extraction_text=\"Napro-syn\"\n              )\n          ]],\n          \"Patient is prescribed Napro-syn.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Napro-syn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=6\n                      ),\n                      
char_interval=data.CharInterval(start_pos=22, end_pos=31),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_extraction_with_substring_of_another_not_matched\",\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\", extraction_text=\"Napro\"\n              )\n          ]],\n          _SOURCE_TEXT_TWO_MEDS,\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Napro\",\n                  char_interval=None,\n              )\n          ]],\n      ),\n      (\n          \"test_empty_extractions_list\",\n          [],\n          _SOURCE_TEXT_TWO_MEDS,\n          [],\n      ),\n      (\n          \"test_extractions_with_similar_words\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Napro\"\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_TWO_MEDS,\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=30),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Napro\",\n                      
char_interval=None,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_source_text_with_repeated_extractions\",\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n              )\n          ]],\n          \"Patient is prescribed Naprosyn and Naprosyn.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=30),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_interleaved_extractions\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\", extraction_text=\"arthritis\"\n                  )\n              ],\n          ],\n          \"Patient with arthritis is prescribed Naprosyn.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      char_interval=None,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"arthritis\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=2, end_index=3\n                      ),\n                      
char_interval=data.CharInterval(start_pos=13, end_pos=22),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"overlapping_extractions_different_types\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"Naprosyn allergy\",\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_TWO_MEDS,\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=30),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"Naprosyn allergy\",\n                      char_interval=None,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_overlapping_text_extractions_with_overlapping_source\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\", extraction_text=\"high blood\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"blood pressure\",\n                  )\n              ],\n          ],\n          
\"Patient has high blood pressure.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"high blood\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=2, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=12, end_pos=22),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"blood pressure\",\n                      char_interval=None,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_multiple_instances_same_extraction\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                  )\n              ],\n          ],\n          \"Naprosyn, prednisone, and again Naprosyn are prescribed.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=0, end_index=1\n                      ),\n                      char_interval=data.CharInterval(start_pos=0, end_pos=8),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n         
             extraction_text=\"prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=2, end_index=3\n                      ),\n                      char_interval=data.CharInterval(start_pos=10, end_pos=20),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_longer_extraction_spanning_multiple_words\",\n          [[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"rheumatoid arthritis\",\n              )\n          ]],\n          \"Patient diagnosed with rheumatoid arthritis.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"rheumatoid arthritis\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=5\n                      ),\n                      char_interval=data.CharInterval(start_pos=23, end_pos=43),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_case_insensitivity\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"PREDNISONE\",\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_TWO_MEDS.lower(),\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"naprosyn\",\n                      
token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=30),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"PREDNISONE\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=5, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=35, end_pos=45),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"numerical_extractions\",\n          [[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"Ibuprofen 600mg\",\n              )\n          ]],\n          \"Patient was given Ibuprofen 600mg twice daily.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Ibuprofen 600mg\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=18, end_pos=33),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_extractions_spanning_across_sentence_boundaries\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Ibuprofen\"\n                  )\n              ],\n              [\n                  
data.Extraction(\n                      extraction_class=\"instruction\",\n                      extraction_text=\"take with food\",\n                  )\n              ],\n          ],\n          \"Take Ibuprofen. Always take with food.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Ibuprofen\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=1, end_index=2\n                      ),\n                      char_interval=data.CharInterval(start_pos=5, end_pos=14),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"instruction\",\n                      extraction_text=\"take with food\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=4, end_index=7\n                      ),\n                      char_interval=data.CharInterval(start_pos=23, end_pos=37),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_multiple_multiword_extractions_multi_group\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\", extraction_text=\"Naprosyn\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"frequency\", extraction_text=\"as needed\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"reason\", extraction_text=\"pain\"\n                  )\n              ],\n              [\n                  data.Extraction(\n                      
extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"duration\",\n                      extraction_text=\"for one month\",\n                  )\n              ],\n          ],\n          _SOURCE_TEXT_MULTI_WORD_EXTRACTIONS,\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Naprosyn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=4\n                      ),\n                      char_interval=data.CharInterval(start_pos=18, end_pos=26),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"frequency\",\n                      extraction_text=\"as needed\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=4, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=27, end_pos=36),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"reason\",\n                      extraction_text=\"pain\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=7, end_index=8\n                      ),\n                      char_interval=data.CharInterval(start_pos=41, end_pos=45),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      
extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=9, end_index=10\n                      ),\n                      char_interval=data.CharInterval(start_pos=50, end_pos=60),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"duration\",\n                      extraction_text=\"for one month\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=10, end_index=13\n                      ),\n                      char_interval=data.CharInterval(start_pos=61, end_pos=74),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"extraction_with_tokenizing_pipe_delimiter\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Napro | syn\",\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                  )\n              ],\n          ],\n          \"Patient is prescribed Napro | syn and prednisone.\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"Napro | syn\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=6\n                      ),\n                      char_interval=data.CharInterval(start_pos=22, end_pos=33),\n                      
alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"prednisone\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=7, end_index=8\n                      ),\n                      char_interval=data.CharInterval(start_pos=38, end_pos=48),\n                      alignment_status=data.AlignmentStatus.MATCH_EXACT,\n                  )\n              ],\n          ],\n      ),\n      (\n          \"test_only_matching_end_does_not_align\",\n          [\n              [\n                  data.Extraction(\n                      extraction_class=\"some_class\",\n                      extraction_text=\"only matched end\",\n                  )\n              ],\n          ],\n          \"end\",\n          [[\n              data.Extraction(\n                  extraction_class=\"some_class\",\n                  extraction_text=\"only matched end\",\n                  char_interval=None,\n                  alignment_status=None,\n              )\n          ]],\n      ),\n      dict(\n          testcase_name=\"fuzzy_alignment_success\",\n          # Tests fuzzy alignment alongside exact matching. 
Shows different alignment statuses:\n          # \"heart problems\" gets fuzzy match, \"severe heart problems complications\" gets lesser match.\n          # Demonstrates both fuzzy and lesser matching working with 75% threshold.\n          extractions=[\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"heart problems\",\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"severe heart problems complications\",\n                  )\n              ],\n          ],\n          source_text=\"Patient has severe heart problems today.\",\n          expected_output=[\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"heart problems\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=3, end_index=5\n                      ),\n                      char_interval=data.CharInterval(start_pos=19, end_pos=33),\n                      alignment_status=data.AlignmentStatus.MATCH_FUZZY,\n                  )\n              ],\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"severe heart problems complications\",\n                      token_interval=tokenizer.TokenInterval(\n                          start_index=2, end_index=5\n                      ),\n                      char_interval=data.CharInterval(start_pos=12, end_pos=33),\n                      alignment_status=data.AlignmentStatus.MATCH_LESSER,\n                  )\n              ],\n          ],\n          enable_fuzzy_alignment=True,\n      ),\n      dict(\n          testcase_name=\"fuzzy_alignment_below_threshold\",\n          # Tests fuzzy 
alignment failure when overlap ratio < _FUZZY_ALIGNMENT_MIN_THRESHOLD (75%).\n          # No tokens overlap between \"completely different medicine\" and \"Patient takes aspirin daily.\"\n          extractions=[\n              [\n                  data.Extraction(\n                      extraction_class=\"medication\",\n                      extraction_text=\"completely different medicine\",\n                  )\n              ],\n          ],\n          source_text=\"Patient takes aspirin daily.\",\n          expected_output=[[\n              data.Extraction(\n                  extraction_class=\"medication\",\n                  extraction_text=\"completely different medicine\",\n                  char_interval=None,\n                  alignment_status=None,\n              )\n          ]],\n          enable_fuzzy_alignment=True,\n      ),\n      dict(\n          testcase_name=\"accept_match_lesser_disabled\",\n          # Tests accept_match_lesser=False with fuzzy fallback.\n          extractions=[\n              [\n                  data.Extraction(\n                      extraction_class=\"condition\",\n                      extraction_text=\"patient heart problems today\",\n                  )\n              ],\n          ],\n          source_text=\"Patient has heart problems today.\",\n          expected_output=[[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"patient heart problems today\",\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=0, end_index=5\n                  ),\n                  char_interval=data.CharInterval(start_pos=0, end_pos=32),\n                  alignment_status=data.AlignmentStatus.MATCH_FUZZY,\n              )\n          ]],\n          enable_fuzzy_alignment=True,\n          accept_match_lesser=False,\n      ),\n      dict(\n          testcase_name=\"fuzzy_alignment_subset_window\",\n          # Extraction is a subset 
of a longer source clause; ensures extra tokens do not penalise score.\n          extractions=[[\n              data.Extraction(\n                  extraction_class=\"tendon\",\n                  extraction_text=\"The iliopsoas tendon is intact\",\n              )\n          ]],\n          source_text=(\n              \"The iliopsoas and proximal hamstring tendons are intact.\"\n          ),\n          expected_output=[[\n              data.Extraction(\n                  extraction_class=\"tendon\",\n                  extraction_text=\"The iliopsoas tendon is intact\",\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=0, end_index=8\n                  ),\n                  char_interval=data.CharInterval(start_pos=0, end_pos=55),\n                  alignment_status=data.AlignmentStatus.MATCH_FUZZY,\n              )\n          ]],\n          enable_fuzzy_alignment=True,\n          accept_match_lesser=False,\n      ),\n      dict(\n          testcase_name=\"fuzzy_alignment_with_reordered_words\",\n          # Tests fuzzy alignment's ability to handle reordered words in the extraction.\n          extractions=[[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"problems heart\",  # Reordered words\n                  char_interval=data.CharInterval(start_pos=12, end_pos=33),\n                  alignment_status=data.AlignmentStatus.MATCH_FUZZY,\n              )\n          ]],\n          source_text=\"Patient has severe heart problems today.\",\n          expected_output=[[\n              data.Extraction(\n                  extraction_class=\"condition\",\n                  extraction_text=\"problems heart\",\n                  # The best matching window in the source is \"severe heart problems\"\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=2, end_index=5\n                  ),\n                  
char_interval=data.CharInterval(start_pos=12, end_pos=33),\n                  alignment_status=data.AlignmentStatus.MATCH_FUZZY,\n              )\n          ]],\n          enable_fuzzy_alignment=True,\n      ),\n      dict(\n          testcase_name=\"fuzzy_alignment_fails_low_ratio\",\n          # An extraction that partially overlaps but is below the fuzzy threshold should not be aligned.\n          extractions=[[\n              data.Extraction(\n                  extraction_class=\"symptom\",\n                  extraction_text=\"headache and fever\",\n              )\n          ]],\n          source_text=\"Patient reports back pain and a fever.\",\n          expected_output=[[\n              data.Extraction(\n                  extraction_class=\"symptom\",\n                  extraction_text=\"headache and fever\",\n                  char_interval=None,\n                  alignment_status=None,\n              )\n          ]],\n          enable_fuzzy_alignment=True,\n      ),\n      dict(\n          testcase_name=\"fuzzy_alignment_partial_overlap_success\",\n          # An extraction where the number of matched tokens divided by total extraction tokens\n          # is >= the threshold (3/4 = 0.75).\n          extractions=[[\n              data.Extraction(\n                  extraction_class=\"finding\",\n                  extraction_text=\"mild degenerative disc disease\",\n              )\n          ]],\n          source_text=(\n              \"Findings consistent with degenerative disc disease at L5-S1.\"\n          ),\n          expected_output=[[\n              data.Extraction(\n                  extraction_class=\"finding\",\n                  extraction_text=\"mild degenerative disc disease\",\n                  # The best window found is \"degenerative disc disease\"\n                  token_interval=tokenizer.TokenInterval(\n                      start_index=3, end_index=6\n                  ),\n                  
char_interval=data.CharInterval(start_pos=25, end_pos=50),\n                  alignment_status=data.AlignmentStatus.MATCH_FUZZY,\n              )\n          ]],\n          enable_fuzzy_alignment=True,\n      ),\n  )\n  def test_extraction_alignment(\n      self,\n      extractions: Sequence[Sequence[data.Extraction]],\n      source_text: str,\n      expected_output: Sequence[Sequence[data.Extraction]] | ValueError,\n      enable_fuzzy_alignment: bool = False,\n      accept_match_lesser: bool = True,\n  ):\n    if expected_output is ValueError:\n      with self.assertRaises(ValueError):\n        self.aligner.align_extractions(\n            extractions, source_text, enable_fuzzy_alignment=False\n        )\n    else:\n      aligned_extraction_groups = self.aligner.align_extractions(\n          extractions,\n          source_text,\n          enable_fuzzy_alignment=enable_fuzzy_alignment,\n          accept_match_lesser=accept_match_lesser,\n      )\n      flattened_extractions = []\n      for group in aligned_extraction_groups:\n        flattened_extractions.extend(group)\n      assert_char_interval_match_source(\n          self, source_text, flattened_extractions\n      )\n      self.assertEqual(aligned_extraction_groups, expected_output)\n\n\nclass ResolverTest(parameterized.TestCase):\n  _TWO_MEDICATIONS_JSON_UNDELIMITED = textwrap.dedent(f\"\"\"\\\n      {{\n        \"{data.EXTRACTIONS_KEY}\": [\n          {{\n            \"medication\": \"Naprosyn\",\n            \"medication_index\": 4,\n            \"frequency\": \"as needed\",\n            \"frequency_index\": 5,\n            \"reason\": \"pain\",\n            \"reason_index\": 8\n          }},\n          {{\n            \"medication\": \"prednisone\",\n            \"medication_index\": 9,\n            \"duration\": \"for one month\",\n            \"duration_index\": 10\n          }}\n        ]\n      }}\"\"\")\n\n  _TWO_MEDICATIONS_YAML_UNDELIMITED = textwrap.dedent(f\"\"\"\\\n  {data.EXTRACTIONS_KEY}:\n    - 
medication: \"Naprosyn\"\n      medication_index: 4\n      frequency: \"as needed\"\n      frequency_index: 5\n      reason: \"pain\"\n      reason_index: 8\n\n    - medication: \"prednisone\"\n      medication_index: 9\n      duration: \"for one month\"\n      duration_index: 10\n  \"\"\")\n\n  _EXPECTED_TWO_MEDICATIONS_ANNOTATED = [\n      data.Extraction(\n          extraction_class=\"medication\",\n          extraction_text=\"Naprosyn\",\n          extraction_index=4,\n          group_index=0,\n      ),\n      data.Extraction(\n          extraction_class=\"frequency\",\n          extraction_text=\"as needed\",\n          extraction_index=5,\n          group_index=0,\n      ),\n      data.Extraction(\n          extraction_class=\"reason\",\n          extraction_text=\"pain\",\n          extraction_index=8,\n          group_index=0,\n      ),\n      data.Extraction(\n          extraction_class=\"medication\",\n          extraction_text=\"prednisone\",\n          extraction_index=9,\n          group_index=1,\n      ),\n      data.Extraction(\n          extraction_class=\"duration\",\n          extraction_text=\"for one month\",\n          extraction_index=10,\n          group_index=1,\n      ),\n  ]\n\n  def setUp(self):\n    super().setUp()\n    self.default_resolver = resolver_lib.Resolver(\n        format_type=data.FormatType.JSON,\n        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n    )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"json_with_fence\",\n          resolver=resolver_lib.Resolver(\n              fence_output=True,\n              format_type=data.FormatType.JSON,\n              extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n          ),\n          input_text=textwrap.dedent(f\"\"\"\\\n            ```json\n            {{\n              \"{data.EXTRACTIONS_KEY}\": [\n                {{\n                  \"medication\": \"Naprosyn\",\n                  \"medication_index\": 4,\n            
      \"frequency\": \"as needed\",\n                  \"frequency_index\": 5,\n                  \"reason\": \"pain\",\n                  \"reason_index\": 8\n                }},\n                {{\n                  \"medication\": \"prednisone\",\n                  \"medication_index\": 9,\n                  \"duration\": \"for one month\",\n                  \"duration_index\": 10\n                }}\n              ]\n            }}\n            ```\"\"\"),\n          expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED,\n      ),\n      dict(\n          testcase_name=\"yaml_with_fence\",\n          resolver=resolver_lib.Resolver(\n              fence_output=True,\n              format_type=data.FormatType.YAML,\n              extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n          ),\n          input_text=textwrap.dedent(f\"\"\"\\\n            ```yaml\n            {data.EXTRACTIONS_KEY}:\n              - medication: \"Naprosyn\"\n                medication_index: 4\n                frequency: \"as needed\"\n                frequency_index: 5\n                reason: \"pain\"\n                reason_index: 8\n\n              - medication: \"prednisone\"\n                medication_index: 9\n                duration: \"for one month\"\n                duration_index: 10\n            ```\"\"\"),\n          expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED,\n      ),\n      dict(\n          testcase_name=\"json_no_fence\",\n          resolver=resolver_lib.Resolver(\n              fence_output=False,\n              format_type=data.FormatType.JSON,\n              extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n          ),\n          input_text=_TWO_MEDICATIONS_JSON_UNDELIMITED,\n          expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED,\n      ),\n      dict(\n          testcase_name=\"yaml_no_fence\",\n          resolver=resolver_lib.Resolver(\n              fence_output=False,\n              format_type=data.FormatType.YAML,\n     
         extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,\n          ),\n          input_text=_TWO_MEDICATIONS_YAML_UNDELIMITED,\n          expected_output=_EXPECTED_TWO_MEDICATIONS_ANNOTATED,\n      ),\n  )\n  def test_resolve_valid_inputs(self, resolver, input_text, expected_output):\n    actual_extractions = resolver.resolve(input_text)\n    self.assertCountEqual(expected_output, actual_extractions)\n    assert_char_interval_match_source(self, input_text, actual_extractions)\n\n  def test_handle_integer_extraction(self):\n    test_input = textwrap.dedent(f\"\"\"\\\n    ```json\n    {{\n      \"{data.EXTRACTIONS_KEY}\": [\n        {{\n          \"year\": 2006,\n          \"year_index\": 6\n        }}\n      ]\n    }}\n    ```\"\"\")\n    expected_extractions = [\n        data.Extraction(\n            extraction_class=\"year\",\n            extraction_text=\"2006\",\n            extraction_index=6,\n            group_index=0,\n        )\n    ]\n\n    actual_extractions = self.default_resolver.resolve(test_input)\n    self.assertEqual(expected_extractions, list(actual_extractions))\n\n  def test_resolve_empty_yaml(self):\n    test_input = \"```json\\n```\"\n    actual = self.default_resolver.resolve(\n        test_input, suppress_parse_errors=True\n    )\n    self.assertEmpty(actual)\n\n  def test_resolve_empty_yaml_without_suppress_parse_errors(self):\n    test_input = \"```json\\n```\"\n    with self.assertRaises(resolver_lib.ResolverParsingError):\n      self.default_resolver.resolve(test_input, suppress_parse_errors=False)\n\n  def test_align_with_valid_chunk(self):\n    text = \"This is a sample text with some extractions.\"\n    tokenized_text = tokenizer.tokenize(text)\n\n    chunk = tokenizer.TokenInterval(start_index=0, end_index=8)\n    annotated_extractions = [\n        data.Extraction(\n            extraction_class=\"medication\", extraction_text=\"sample\"\n        ),\n        data.Extraction(\n            extraction_class=\"condition\", 
extraction_text=\"extractions\"\n        ),\n    ]\n    expected_extractions = [\n        data.Extraction(\n            extraction_class=\"medication\",\n            extraction_text=\"sample\",\n            token_interval=tokenizer.TokenInterval(start_index=3, end_index=4),\n            char_interval=data.CharInterval(start_pos=10, end_pos=16),\n            alignment_status=data.AlignmentStatus.MATCH_EXACT,\n        ),\n        data.Extraction(\n            extraction_class=\"condition\",\n            extraction_text=\"extractions\",\n            token_interval=tokenizer.TokenInterval(start_index=7, end_index=8),\n            char_interval=data.CharInterval(start_pos=32, end_pos=43),\n            alignment_status=data.AlignmentStatus.MATCH_EXACT,\n        ),\n    ]\n\n    chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n    token_offset = chunk.start_index\n    aligned_extractions = list(\n        self.default_resolver.align(\n            extractions=annotated_extractions,\n            source_text=chunk_text,\n            token_offset=token_offset,\n            char_offset=0,\n            enable_fuzzy_alignment=False,\n        )\n    )\n\n    self.assertEqual(len(aligned_extractions), len(expected_extractions))\n    for expected, actual in zip(expected_extractions, aligned_extractions):\n      self.assertDataclassEqual(expected, actual)\n    assert_char_interval_match_source(self, text, aligned_extractions)\n\n  def test_align_with_chunk_starting_in_middle(self):\n    text = \"This is a sample text with some extractions.\"\n    tokenized_text = tokenizer.tokenize(text)\n\n    chunk = tokenizer.TokenInterval(start_index=3, end_index=8)\n    annotated_extractions = [\n        data.Extraction(\n            extraction_class=\"medication\", extraction_text=\"sample\"\n        ),\n        data.Extraction(\n            extraction_class=\"condition\", extraction_text=\"extractions\"\n        ),\n    ]\n    expected_extractions = [\n        
data.Extraction(\n            extraction_class=\"medication\",\n            extraction_text=\"sample\",\n            token_interval=tokenizer.TokenInterval(start_index=3, end_index=4),\n            char_interval=data.CharInterval(start_pos=10, end_pos=16),\n            alignment_status=data.AlignmentStatus.MATCH_EXACT,\n        ),\n        data.Extraction(\n            extraction_class=\"condition\",\n            extraction_text=\"extractions\",\n            token_interval=tokenizer.TokenInterval(start_index=7, end_index=8),\n            char_interval=data.CharInterval(start_pos=32, end_pos=43),\n            alignment_status=data.AlignmentStatus.MATCH_EXACT,\n        ),\n    ]\n\n    chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n    token_offset = chunk.start_index\n    # Compute global char offset from the token at chunk.start_index.\n    char_offset = tokenized_text.tokens[\n        chunk.start_index\n    ].char_interval.start_pos\n    aligned_extractions = list(\n        self.default_resolver.align(\n            extractions=annotated_extractions,\n            source_text=chunk_text,\n            token_offset=token_offset,\n            char_offset=char_offset,\n            enable_fuzzy_alignment=False,\n        )\n    )\n\n    self.assertEqual(len(aligned_extractions), len(expected_extractions))\n    for expected, actual in zip(expected_extractions, aligned_extractions):\n      self.assertDataclassEqual(expected, actual)\n\n    assert_char_interval_match_source(self, text, aligned_extractions)\n\n  def test_align_with_no_extractions_in_chunk(self):\n    tokenized_text = tokenizer.tokenize(\"No extractions here.\")\n\n    # Define a chunk that includes the entire text.\n    chunk = tokenizer.TokenInterval()\n    chunk.start_index = 0\n    chunk.end_index = 3\n    annotated_extractions = []\n\n    chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n    token_offset = chunk.start_index\n    aligned_extractions = list(\n      
  self.default_resolver.align(\n            extractions=annotated_extractions,\n            source_text=chunk_text,\n            token_offset=token_offset,\n            char_offset=0,\n            enable_fuzzy_alignment=False,\n        )\n    )\n\n    self.assertEmpty(aligned_extractions)\n\n  def test_align_successful(self):\n    tokenized_text = tokenizer.TokenizedText(\n        text=\"zero one two\",\n        tokens=[\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=0, end_pos=4),\n                index=0,\n            ),\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=5, end_pos=8),\n                index=1,\n            ),\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=9, end_pos=12),\n                index=2,\n            ),\n        ],\n    )\n\n    # Define a chunk that includes the entire text.\n    chunk = tokenizer.TokenInterval(start_index=0, end_index=3)\n    annotated_extractions = [\n        data.Extraction(extraction_class=\"foo\", extraction_text=\"zero\"),\n        data.Extraction(extraction_class=\"foo\", extraction_text=\"one\"),\n    ]\n\n    chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n    token_offset = chunk.start_index\n    aligned_extractions = list(\n        self.default_resolver.align(\n            extractions=annotated_extractions,\n            source_text=chunk_text,\n            token_offset=token_offset,\n            char_offset=0,\n            enable_fuzzy_alignment=False,\n        )\n    )\n\n    self.assertLen(aligned_extractions, 2)\n    assert_char_interval_match_source(\n        self, tokenized_text.text, aligned_extractions\n    )\n\n  def test_align_with_discontinuous_tokenized_text(self):\n    
tokenized_text = tokenizer.TokenizedText(\n        text=\"zero one five\",\n        tokens=[\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=0, end_pos=4),\n                index=0,\n            ),\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=5, end_pos=8),\n                index=1,\n            ),\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=9, end_pos=14),\n                index=5,\n            ),\n        ],\n    )\n\n    # Define a chunk that includes too many tokens.\n    chunk = tokenizer.TokenInterval(start_index=0, end_index=6)\n    annotated_extractions = [\n        data.Extraction(extraction_class=\"foo\", extraction_text=\"zero\"),\n        data.Extraction(extraction_class=\"foo\", extraction_text=\"one\"),\n    ]\n\n    with self.assertRaises(tokenizer.InvalidTokenIntervalError):\n      chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n      token_offset = chunk.start_index\n      list(\n          self.default_resolver.align(\n              annotated_extractions,\n              chunk_text,\n              token_offset,\n              enable_fuzzy_alignment=False,\n          )\n      )\n\n  def test_align_with_discontinuous_tokenized_text_but_right_chunk(self):\n    tokenized_text = tokenizer.TokenizedText(\n        text=\"zero one five\",\n        tokens=[\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=0, end_pos=4),\n                index=0,\n            ),\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=5, end_pos=8),\n                
index=1,\n            ),\n            tokenizer.Token(\n                token_type=tokenizer.TokenType.WORD,\n                char_interval=tokenizer.CharInterval(start_pos=9, end_pos=14),\n                index=5,\n            ),\n        ],\n    )\n\n    # Define a correct chunk.\n    chunk = tokenizer.TokenInterval(start_index=0, end_index=3)\n    annotated_extractions = [\n        data.Extraction(extraction_class=\"foo\", extraction_text=\"zero\"),\n        data.Extraction(extraction_class=\"foo\", extraction_text=\"one\"),\n    ]\n\n    chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n    token_offset = chunk.start_index\n    aligned_extractions = list(\n        self.default_resolver.align(\n            extractions=annotated_extractions,\n            source_text=chunk_text,\n            token_offset=token_offset,\n            char_offset=0,\n            enable_fuzzy_alignment=False,\n        )\n    )\n    self.assertLen(aligned_extractions, 2)\n    assert_char_interval_match_source(\n        self, tokenized_text.text, aligned_extractions\n    )\n\n  def test_align_with_empty_annotated_extractions(self):\n    \"\"\"Test align method with empty annotated_extractions sequence.\"\"\"\n    tokenized_text = tokenizer.tokenize(\"No extractions here.\")\n\n    # Define a chunk that includes the entire text.\n    chunk = tokenizer.TokenInterval()\n    chunk.start_index = 0\n    chunk.end_index = 3\n    annotated_extractions = []  # Empty sequence representing no extractions\n\n    chunk_text = chunking.get_token_interval_text(tokenized_text, chunk)\n    token_offset = chunk.start_index\n    aligned_extractions = list(\n        self.default_resolver.align(\n            extractions=annotated_extractions,\n            source_text=chunk_text,\n            token_offset=token_offset,\n            char_offset=0,\n            enable_fuzzy_alignment=False,\n        )\n    )\n\n    self.assertEmpty(aligned_extractions)\n\n\nclass 
FenceFallbackTest(parameterized.TestCase):\n  \"\"\"Tests for fence marker fallback behavior.\"\"\"\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"with_valid_fences\",\n          test_input=textwrap.dedent(\"\"\"\\\n              ```json\n              {\n                \"extractions\": [\n                  {\"person\": \"Marie Curie\", \"person_attributes\": {\"field\": \"physics\"}}\n                ]\n              }\n              ```\"\"\"),\n          fence_output=True,\n          strict_fences=False,\n          expected_key=\"person\",\n          expected_value=\"Marie Curie\",\n      ),\n      dict(\n          testcase_name=\"fallback_no_fences\",\n          test_input=textwrap.dedent(\"\"\"\\\n              {\n                \"extractions\": [\n                  {\"person\": \"Albert Einstein\", \"person_attributes\": {\"field\": \"physics\"}}\n                ]\n              }\"\"\"),\n          fence_output=True,\n          strict_fences=False,\n          expected_key=\"person\",\n          expected_value=\"Albert Einstein\",\n      ),\n      dict(\n          testcase_name=\"no_fence_expectation\",\n          test_input=textwrap.dedent(\"\"\"\\\n              {\n                \"extractions\": [\n                  {\"drug\": \"Aspirin\", \"drug_attributes\": {\"dosage\": \"100mg\"}}\n                ]\n              }\"\"\"),\n          fence_output=False,\n          strict_fences=False,\n          expected_key=\"drug\",\n          expected_value=\"Aspirin\",\n      ),\n  )\n  def test_parsing_scenarios(\n      self,\n      test_input,\n      fence_output,\n      strict_fences,\n      expected_key,\n      expected_value,\n  ):\n    resolver = resolver_lib.Resolver(\n        fence_output=fence_output,\n        format_type=data.FormatType.JSON,\n        strict_fences=strict_fences,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 1)\n    self.assertIn(expected_key, 
result[0])\n    self.assertEqual(result[0][expected_key], expected_value)\n\n  def test_fallback_preserves_content_integrity(self):\n    test_input = textwrap.dedent(\"\"\"\\\n        {\n          \"extractions\": [\n            {\n              \"medication\": \"Ibuprofen\",\n              \"medication_attributes\": {\n                \"dosage\": \"200mg\",\n                \"frequency\": \"twice daily\"\n              }\n            },\n            {\n              \"condition\": \"headache\",\n              \"condition_attributes\": {\n                \"severity\": \"mild\"\n              }\n            }\n          ]\n        }\"\"\")\n    resolver = resolver_lib.Resolver(\n        fence_output=True,\n        format_type=data.FormatType.JSON,\n        strict_fences=False,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 2, \"Should preserve all extractions during fallback\")\n\n    self.assertEqual(\n        result[0][\"medication\"],\n        \"Ibuprofen\",\n        \"First extraction should have correct medication\",\n    )\n    self.assertEqual(\n        result[0][\"medication_attributes\"][\"dosage\"],\n        \"200mg\",\n        \"Should preserve nested attributes in fallback\",\n    )\n\n    self.assertEqual(\n        result[1][\"condition\"],\n        \"headache\",\n        \"Second extraction should have correct condition\",\n    )\n    self.assertEqual(\n        result[1][\"condition_attributes\"][\"severity\"],\n        \"mild\",\n        \"Should preserve all nested attributes\",\n    )\n\n  def test_malformed_json_still_raises_error(self):\n    test_input = textwrap.dedent(\"\"\"\\\n        {\n          \"extractions\": [\n            {\"person\": \"Missing closing brace\"\n          ]\"\"\")\n    resolver = resolver_lib.Resolver(\n        fence_output=True,\n        format_type=data.FormatType.JSON,\n        strict_fences=False,\n    )\n    with self.assertRaises(resolver_lib.ResolverParsingError):\n  
    resolver.string_to_extraction_data(test_input)\n\n  def test_strict_fences_raises_on_missing_markers(self):\n    strict_resolver = resolver_lib.Resolver(\n        fence_output=True,\n        format_type=data.FormatType.JSON,\n        strict_fences=True,\n    )\n    test_input = textwrap.dedent(\"\"\"\\\n        {\"extractions\": [{\"person\": \"Test\"}]}\"\"\")\n\n    with self.assertRaisesRegex(\n        resolver_lib.ResolverParsingError, \".*fence markers.*\"\n    ):\n      strict_resolver.string_to_extraction_data(test_input)\n\n  def test_default_allows_fallback(self):\n    default_resolver = resolver_lib.Resolver(\n        fence_output=True,\n        format_type=data.FormatType.JSON,\n    )\n    test_input = textwrap.dedent(\"\"\"\\\n        {\"extractions\": [{\"person\": \"Default Test\"}]}\"\"\")\n\n    result = default_resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 1)\n    self.assertEqual(result[0][\"person\"], \"Default Test\")\n\n  def test_rejects_multiple_fenced_blocks(self):\n    test_input = textwrap.dedent(\"\"\"\\\n        preamble\n        ```json\n        {\"extractions\": [{\"item\": \"first\"}]}\n        ```\n        Some explanation text\n        ```json\n        {\"extractions\": [{\"item\": \"second\"}]}\n        ```\"\"\")\n    resolver = resolver_lib.Resolver(\n        fence_output=True,\n        format_type=data.FormatType.JSON,\n        strict_fences=False,\n    )\n    with self.assertRaisesRegex(\n        resolver_lib.ResolverParsingError, \"Multiple fenced blocks found\"\n    ):\n      resolver.string_to_extraction_data(test_input)\n\n\nclass FlexibleSchemaTest(parameterized.TestCase):\n  \"\"\"Tests for flexible schema formats without extractions key.\"\"\"\n\n  def test_direct_list_format(self):\n    test_input = textwrap.dedent(\"\"\"\\\n        [\n          {\"person\": \"Marie Curie\", \"field\": \"physics\"},\n          {\"person\": \"Albert Einstein\", \"field\": \"relativity\"}\n        
]\"\"\")\n    resolver = resolver_lib.Resolver(\n        fence_output=False,\n        format_type=data.FormatType.JSON,\n        require_extractions_key=False,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 2)\n    self.assertEqual(result[0][\"person\"], \"Marie Curie\")\n    self.assertEqual(result[1][\"person\"], \"Albert Einstein\")\n\n  def test_single_dict_as_extraction(self):\n    test_input = '{\"person\": \"Isaac Newton\", \"field\": \"gravity\"}'\n    resolver = resolver_lib.Resolver(\n        fence_output=False,\n        format_type=data.FormatType.JSON,\n        require_extractions_key=False,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 1)\n    self.assertEqual(result[0][\"person\"], \"Isaac Newton\")\n    self.assertEqual(result[0][\"field\"], \"gravity\")\n\n  def test_traditional_format_still_works(self):\n    test_input = textwrap.dedent(\"\"\"\\\n        {\n          \"extractions\": [\n            {\"person\": \"Charles Darwin\", \"field\": \"evolution\"}\n          ]\n        }\"\"\")\n    resolver = resolver_lib.Resolver(\n        fence_output=False,\n        format_type=data.FormatType.JSON,\n        require_extractions_key=False,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 1)\n    self.assertEqual(result[0][\"person\"], \"Charles Darwin\")\n\n  def test_lenient_mode_accepts_list(self):\n    # Some models return [...] 
instead of {\"extractions\": [...]}\n    # A top-level list is accepted even with require_extractions_key=True.\n    test_input = '[{\"person\": \"Test\"}]'\n    resolver = resolver_lib.Resolver(\n        fence_output=False,\n        format_type=data.FormatType.JSON,\n        require_extractions_key=True,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 1)\n    self.assertEqual(result[0][\"person\"], \"Test\")\n\n  def test_flexible_with_attributes(self):\n    test_input = textwrap.dedent(\"\"\"\\\n        [\n          {\n            \"medication\": \"Aspirin\",\n            \"medication_attributes\": {\"dosage\": \"100mg\", \"frequency\": \"daily\"}\n          },\n          {\n            \"medication\": \"Ibuprofen\",\n            \"medication_attributes\": {\"dosage\": \"200mg\"}\n          }\n        ]\"\"\")\n    resolver = resolver_lib.Resolver(\n        fence_output=False,\n        format_type=data.FormatType.JSON,\n        require_extractions_key=False,\n    )\n    result = resolver.string_to_extraction_data(test_input)\n    self.assertLen(result, 2)\n    self.assertEqual(result[0][\"medication\"], \"Aspirin\")\n    self.assertEqual(result[0][\"medication_attributes\"][\"dosage\"], \"100mg\")\n    self.assertEqual(result[1][\"medication\"], \"Ibuprofen\")\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/schema_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for the schema module.\n\nNote: This file contains test helper classes that intentionally have\nfew public methods. The too-few-public-methods warnings are expected.\n\"\"\"\n\nfrom unittest import mock\nimport warnings\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract.core import base_model\nfrom langextract.core import data\nfrom langextract.core import format_handler as fh\nfrom langextract.core import schema\nfrom langextract.providers import schemas\n\n\nclass BaseSchemaTest(absltest.TestCase):\n  \"\"\"Tests for BaseSchema abstract class.\"\"\"\n\n  def test_abstract_methods_required(self):\n    \"\"\"Test that BaseSchema cannot be instantiated directly.\"\"\"\n    with self.assertRaises(TypeError):\n      schema.BaseSchema()  # pylint: disable=abstract-class-instantiated\n\n  def test_subclass_must_implement_all_methods(self):\n    \"\"\"Test that subclasses must implement all abstract methods.\"\"\"\n\n    class IncompleteSchema(schema.BaseSchema):  # pylint: disable=too-few-public-methods\n\n      @classmethod\n      def from_examples(cls, examples_data, attribute_suffix=\"_attributes\"):\n        return cls()\n\n    with self.assertRaises(TypeError):\n      IncompleteSchema()  # pylint: disable=abstract-class-instantiated\n\n\nclass BaseLanguageModelSchemaTest(absltest.TestCase):\n  \"\"\"Tests for 
BaseLanguageModel schema methods.\"\"\"\n\n  def test_get_schema_class_returns_none_by_default(self):\n    \"\"\"Test that get_schema_class returns None by default.\"\"\"\n\n    class TestModel(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n      def infer(self, batch_prompts, **kwargs):\n        yield []\n\n    self.assertIsNone(TestModel.get_schema_class())\n\n  def test_apply_schema_stores_instance(self):\n    \"\"\"Test that apply_schema stores the schema instance.\"\"\"\n\n    class TestModel(base_model.BaseLanguageModel):  # pylint: disable=too-few-public-methods\n\n      def infer(self, batch_prompts, **kwargs):\n        yield []\n\n    model = TestModel()\n\n    mock_schema = mock.Mock(spec=schema.BaseSchema)\n\n    model.apply_schema(mock_schema)\n\n    self.assertEqual(model._schema, mock_schema)\n\n    model.apply_schema(None)\n    self.assertIsNone(model._schema)\n\n\nclass GeminiSchemaTest(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"empty_extractions\",\n          examples_data=[],\n          expected_schema={\n              \"type\": \"object\",\n              \"properties\": {\n                  data.EXTRACTIONS_KEY: {\n                      \"type\": \"array\",\n                      \"items\": {\n                          \"type\": \"object\",\n                          \"properties\": {},\n                      },\n                  },\n              },\n              \"required\": [data.EXTRACTIONS_KEY],\n          },\n      ),\n      dict(\n          testcase_name=\"single_extraction_no_attributes\",\n          examples_data=[\n              data.ExampleData(\n                  text=\"Patient has diabetes.\",\n                  extractions=[\n                      data.Extraction(\n                          extraction_text=\"diabetes\",\n                          extraction_class=\"condition\",\n                      )\n                  ],\n              )\n 
         ],\n          expected_schema={\n              \"type\": \"object\",\n              \"properties\": {\n                  data.EXTRACTIONS_KEY: {\n                      \"type\": \"array\",\n                      \"items\": {\n                          \"type\": \"object\",\n                          \"properties\": {\n                              \"condition\": {\"type\": \"string\"},\n                              \"condition_attributes\": {\n                                  \"type\": \"object\",\n                                  \"properties\": {\n                                      \"_unused\": {\"type\": \"string\"},\n                                  },\n                                  \"nullable\": True,\n                              },\n                          },\n                      },\n                  },\n              },\n              \"required\": [data.EXTRACTIONS_KEY],\n          },\n      ),\n      dict(\n          testcase_name=\"single_extraction\",\n          examples_data=[\n              data.ExampleData(\n                  text=\"Patient has diabetes.\",\n                  extractions=[\n                      data.Extraction(\n                          extraction_text=\"diabetes\",\n                          extraction_class=\"condition\",\n                          attributes={\"chronicity\": \"chronic\"},\n                      )\n                  ],\n              )\n          ],\n          expected_schema={\n              \"type\": \"object\",\n              \"properties\": {\n                  data.EXTRACTIONS_KEY: {\n                      \"type\": \"array\",\n                      \"items\": {\n                          \"type\": \"object\",\n                          \"properties\": {\n                              \"condition\": {\"type\": \"string\"},\n                              \"condition_attributes\": {\n                                  \"type\": \"object\",\n                                  
\"properties\": {\n                                      \"chronicity\": {\"type\": \"string\"},\n                                  },\n                                  \"nullable\": True,\n                              },\n                          },\n                      },\n                  },\n              },\n              \"required\": [data.EXTRACTIONS_KEY],\n          },\n      ),\n      dict(\n          testcase_name=\"multiple_extraction_classes\",\n          examples_data=[\n              data.ExampleData(\n                  text=\"Patient has diabetes.\",\n                  extractions=[\n                      data.Extraction(\n                          extraction_text=\"diabetes\",\n                          extraction_class=\"condition\",\n                          attributes={\"chronicity\": \"chronic\"},\n                      )\n                  ],\n              ),\n              data.ExampleData(\n                  text=\"Patient is John Doe\",\n                  extractions=[\n                      data.Extraction(\n                          extraction_text=\"John Doe\",\n                          extraction_class=\"patient\",\n                          attributes={\"id\": \"12345\"},\n                      )\n                  ],\n              ),\n          ],\n          expected_schema={\n              \"type\": \"object\",\n              \"properties\": {\n                  data.EXTRACTIONS_KEY: {\n                      \"type\": \"array\",\n                      \"items\": {\n                          \"type\": \"object\",\n                          \"properties\": {\n                              \"condition\": {\"type\": \"string\"},\n                              \"condition_attributes\": {\n                                  \"type\": \"object\",\n                                  \"properties\": {\n                                      \"chronicity\": {\"type\": \"string\"}\n                                  },\n                   
               \"nullable\": True,\n                              },\n                              \"patient\": {\"type\": \"string\"},\n                              \"patient_attributes\": {\n                                  \"type\": \"object\",\n                                  \"properties\": {\n                                      \"id\": {\"type\": \"string\"},\n                                  },\n                                  \"nullable\": True,\n                              },\n                          },\n                      },\n                  },\n              },\n              \"required\": [data.EXTRACTIONS_KEY],\n          },\n      ),\n  )\n  def test_from_examples_constructs_expected_schema(\n      self, examples_data, expected_schema\n  ):\n    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n    actual_schema = gemini_schema.schema_dict\n    self.assertEqual(actual_schema, expected_schema)\n\n  def test_to_provider_config_returns_response_schema(self):\n    \"\"\"Test that to_provider_config returns the correct provider kwargs.\"\"\"\n    examples_data = [\n        data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"test_class\",\n                    extraction_text=\"test extraction\",\n                )\n            ],\n        )\n    ]\n\n    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n    provider_config = gemini_schema.to_provider_config()\n\n    self.assertIn(\"response_schema\", provider_config)\n    self.assertEqual(\n        provider_config[\"response_schema\"], gemini_schema.schema_dict\n    )\n\n  def test_requires_raw_output_returns_true(self):\n    \"\"\"Test that GeminiSchema requires raw output.\"\"\"\n    examples_data = [\n        data.ExampleData(\n            text=\"Test text\",\n            extractions=[\n                data.Extraction(\n                    
extraction_class=\"test_class\",\n                    extraction_text=\"test extraction\",\n                )\n            ],\n        )\n    ]\n\n    gemini_schema = schemas.gemini.GeminiSchema.from_examples(examples_data)\n    self.assertTrue(gemini_schema.requires_raw_output)\n\n\nclass SchemaValidationTest(parameterized.TestCase):\n  \"\"\"Tests for schema format validation.\"\"\"\n\n  def _create_test_schema(self):\n    \"\"\"Helper to create a test schema.\"\"\"\n    examples = [\n        data.ExampleData(\n            text=\"Test\",\n            extractions=[\n                data.Extraction(\n                    extraction_class=\"entity\",\n                    extraction_text=\"test\",\n                )\n            ],\n        )\n    ]\n    return schemas.gemini.GeminiSchema.from_examples(examples)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"warns_about_fences\",\n          use_fences=True,\n          use_wrapper=True,\n          wrapper_key=data.EXTRACTIONS_KEY,\n          expected_warning=\"fence_output=True may cause parsing issues\",\n      ),\n      dict(\n          testcase_name=\"warns_about_wrong_wrapper_key\",\n          use_fences=False,\n          use_wrapper=True,\n          wrapper_key=\"wrong_key\",\n          expected_warning=\"response_schema expects wrapper_key='extractions'\",\n      ),\n      dict(\n          testcase_name=\"no_warning_with_correct_settings\",\n          use_fences=False,\n          use_wrapper=True,\n          wrapper_key=data.EXTRACTIONS_KEY,\n          expected_warning=None,\n      ),\n  )\n  def test_gemini_validation(\n      self, use_fences, use_wrapper, wrapper_key, expected_warning\n  ):\n    \"\"\"Test GeminiSchema validation with various settings.\"\"\"\n    schema_obj = self._create_test_schema()\n    format_handler = fh.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_fences=use_fences,\n        use_wrapper=use_wrapper,\n        
wrapper_key=wrapper_key,\n    )\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter(\"always\")\n      schema_obj.validate_format(format_handler)\n\n      if expected_warning:\n        self.assertLen(\n            w,\n            1,\n            f\"Expected exactly one warning containing '{expected_warning}'\",\n        )\n        self.assertIn(\n            expected_warning,\n            str(w[0].message),\n            f\"Warning message should contain '{expected_warning}'\",\n        )\n      else:\n        self.assertEmpty(w, \"No warnings should be issued for correct settings\")\n\n  def test_base_schema_no_validation(self):\n    \"\"\"Test that base schema has no validation by default.\"\"\"\n    schema_obj = schema.FormatModeSchema()\n    format_handler = fh.FormatHandler(\n        format_type=data.FormatType.JSON,\n        use_fences=True,\n    )\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter(\"always\")\n      schema_obj.validate_format(format_handler)\n\n      self.assertEmpty(\n          w, \"FormatModeSchema should not issue validation warnings\"\n      )\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/test_gemini_batch_api.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for Gemini Batch API functionality.\"\"\"\n\nimport io\nimport json\nfrom unittest import mock\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nfrom google import genai\nfrom google.api_core import exceptions\n\nfrom langextract.providers import gemini\nfrom langextract.providers import gemini_batch as gb\nfrom langextract.providers import schemas\n\n\ndef create_mock_batch_job(\n    state=genai.types.JobState.JOB_STATE_SUCCEEDED,\n    gcs_uri=f\"gs://bucket/output/file{gb._EXT_JSONL}\",\n):\n  \"\"\"Create a mock BatchJob for testing.\"\"\"\n  job = mock.create_autospec(genai.types.BatchJob, instance=True)\n  job.name = \"batches/123\"\n  job.state = state\n  job.dest = mock.create_autospec(\n      genai.types.BatchJobDestination, instance=True\n  )\n  job.dest.gcs_uri = gcs_uri\n  return job\n\n\ndef _create_batch_response(idx, text_content):\n  \"\"\"Helper to create a batch output line with response.\"\"\"\n  if not isinstance(text_content, str):\n    text_content = json.dumps(text_content, separators=(\",\", \":\"))\n  return json.dumps({\n      \"key\": f\"{gb._KEY_IDX}{idx}\",\n      \"response\": {\n          \"candidates\": [{\"content\": {\"parts\": [{\"text\": text_content}]}}]\n      },\n  })\n\n\ndef _create_batch_error(idx, code, message):\n  \"\"\"Helper to create a batch output line with error.\"\"\"\n  return 
json.dumps({\n      \"key\": f\"{gb._KEY_IDX}{idx}\",\n      \"error\": {\"code\": code, \"message\": message},\n  })\n\n\nclass TestGeminiBatchAPI(absltest.TestCase):\n  \"\"\"Test Gemini Batch API routing and functionality.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self.mock_storage_cls = self.enter_context(\n        mock.patch.object(gb.storage, \"Client\", autospec=True)\n    )\n    self.mock_storage_client = self.mock_storage_cls.return_value\n    self.mock_bucket = self.mock_storage_client.bucket.return_value\n    self.mock_blob = self.mock_bucket.blob.return_value\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_batch_routing_vertex(self, mock_client_cls):\n    \"\"\"Test that batch API is used when enabled and threshold is met (Vertex).\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n\n    self.mock_storage_client.create_bucket.return_value = self.mock_bucket\n\n    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n    output_blob.name = \"output.jsonl\"\n    # Mock blob.open context manager\n    output_blob.open.return_value.__enter__.return_value = io.StringIO(\n        \"\\n\".join([\n            _create_batch_response(0, {\"ok\": 1}),\n            _create_batch_response(1, {\"ok\": 2}),\n        ])\n    )\n    self.mock_bucket.list_blobs.return_value = [output_blob]\n\n    mock_client.batches.create.return_value = create_mock_batch_job()\n    mock_client.batches.get.return_value = create_mock_batch_job()\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"test-project\",\n        location=gb._DEFAULT_LOCATION,\n        batch={\n            \"enabled\": True,\n            \"threshold\": 2,\n            \"poll_interval\": 1,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n    prompts = [\"p1\", \"p2\"]\n    outs = 
list(model.infer(prompts))\n\n    self.assertLen(outs, 2)\n    self.assertEqual(outs[0][0].output, '{\"ok\":1}')\n    self.assertEqual(outs[1][0].output, '{\"ok\":2}')\n\n    self.mock_blob.upload_from_filename.assert_called()\n\n    mock_client.batches.create.assert_called()\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_realtime_when_disabled(self, mock_client_cls):\n    \"\"\"Test that real-time API is used when batch is disabled.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n    mock_response = mock.create_autospec(\n        genai.types.GenerateContentResponse, instance=True\n    )\n    mock_response.text = '{\"ok\":1}'\n    mock_client.models.generate_content.return_value = mock_response\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        batch={\"enabled\": False},\n    )\n    outs = list(model.infer([\"hello\"]))\n\n    self.assertLen(outs, 1)\n    self.assertEqual(outs[0][0].output, '{\"ok\":1}')\n    mock_client.models.generate_content.assert_called()\n    mock_client.batches.create.assert_not_called()\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_realtime_when_below_threshold(self, mock_client_cls):\n    \"\"\"Test that real-time API is used when prompt count is below threshold.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n    mock_response = mock.create_autospec(\n        genai.types.GenerateContentResponse, instance=True\n    )\n    mock_response.text = '{\"ok\":1}'\n    mock_client.models.generate_content.return_value = mock_response\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        batch={\n            \"enabled\": True,\n            \"threshold\": 10,\n            \"enable_caching\": False,\n      
      \"retention_days\": None,\n        },\n    )\n    outs = list(model.infer([\"hello\"]))\n\n    self.assertLen(outs, 1)\n    self.assertEqual(outs[0][0].output, '{\"ok\":1}')\n    mock_client.models.generate_content.assert_called()\n    mock_client.batches.create.assert_not_called()\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_batch_with_schema(self, mock_client_cls):\n    \"\"\"Test that batch API properly includes schema when configured.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n\n    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n    output_blob.name = f\"output{gb._EXT_JSONL}\"\n    output_blob.open.return_value.__enter__.return_value = io.StringIO(\n        _create_batch_response(0, {\"name\": \"test\"})\n    )\n    self.mock_bucket.list_blobs.return_value = [output_blob]\n\n    mock_client.batches.create.return_value = create_mock_batch_job()\n    mock_client.batches.get.return_value = create_mock_batch_job()\n\n    mock_schema = mock.create_autospec(\n        schemas.gemini.GeminiSchema, instance=True\n    )\n    mock_schema.schema_dict = {\n        \"type\": \"object\",\n        \"properties\": {\"name\": {\"type\": \"string\"}},\n    }\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        gemini_schema=mock_schema,\n        batch={\n            \"enabled\": True,\n            \"threshold\": 1,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n\n    # Mock _submit_file to verify the request payload contains the schema.\n    with mock.patch.object(gb, \"_submit_file\", autospec=True) as mock_submit:\n      mock_submit.return_value = create_mock_batch_job()\n\n      outs = list(model.infer([\"test prompt\"]))\n\n      self.assertLen(outs, 1)\n      self.assertEqual(outs[0][0].output, 
'{\"name\":\"test\"}')\n\n      # Verify _submit_file was called with project and location parameters.\n      mock_submit.assert_called_with(\n          mock_client,\n          \"gemini-2.5-flash\",\n          [{\n              \"contents\": [\n                  {\"role\": \"user\", \"parts\": [{\"text\": \"test prompt\"}]}\n              ],\n              \"generationConfig\": {\n                  \"responseMimeType\": \"application/json\",\n                  \"responseSchema\": mock_schema.schema_dict,\n                  \"temperature\": 0.0,\n              },\n          }],\n          mock.ANY,  # Display name contains timestamp/random.\n          None,  # retention_days\n          \"p\",  # project\n          \"l\",  # location\n      )\n\n    self.assertEqual(model.gemini_schema.schema_dict, mock_schema.schema_dict)\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_batch_error_handling(self, mock_client_cls):\n    \"\"\"Test that batch errors are properly handled and raised.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n    mock_client.batches.create.side_effect = Exception(\"Batch API error\")\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        batch={\n            \"enabled\": True,\n            \"threshold\": 1,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n\n    with self.assertRaisesRegex(Exception, \"Gemini Batch API error\"):\n      list(model.infer([\"test prompt\"]))\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_file_based_ordering(self, mock_client_cls):\n    \"\"\"Test that file-based results are returned in correct order.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n\n    # Define inputs and expected outputs\n    prompts = [\"prompt 0\", \"prompt 1\", 
\"prompt 2\"]\n    # Simulate shuffled response in the file\n    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n    output_blob.name = f\"output{gb._EXT_JSONL}\"\n    output_blob.open.return_value.__enter__.return_value = io.StringIO(\n        \"\\n\".join([\n            _create_batch_response(2, \"response 2\"),\n            _create_batch_response(0, \"response 0\"),\n            _create_batch_response(1, \"response 1\"),\n        ])\n    )\n    self.mock_bucket.list_blobs.return_value = [output_blob]\n\n    job = create_mock_batch_job()\n    mock_client.batches.create.return_value = job\n    mock_client.batches.get.return_value = job\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        batch={\n            \"enabled\": True,\n            \"threshold\": 1,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n\n    results = list(model.infer(prompts))\n\n    # Verify results are in original order despite shuffled response\n    self.assertListEqual(\n        [r[0].output for r in results],\n        [\"response 0\", \"response 1\", \"response 2\"],\n    )\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_max_prompts_per_job(self, mock_client_cls):\n    \"\"\"Test that requests are split into multiple batch jobs when they exceed max_prompts_per_job.\n\n    This verifies that:\n    1. Large requests are chunked correctly based on the limit.\n    2. Multiple batch jobs are submitted.\n    3. 
Results are aggregated and returned in the correct order.\n    \"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n\n    # Define inputs and expected behavior\n    prompts = [\"p1\", \"p2\", \"p3\", \"p4\", \"p5\"]\n    max_prompts_per_job = 2\n    # Expected chunks: [\"p1\", \"p2\"], [\"p3\", \"p4\"], [\"p5\"]\n\n    # Setup mock storage and blobs for 3 separate jobs\n    blob0 = mock.create_autospec(gb.storage.Blob, instance=True)\n    blob0.name = f\"out0{gb._EXT_JSONL}\"\n    blob0.open.return_value.__enter__.return_value = io.StringIO(\n        \"\\n\".join([\n            _create_batch_response(0, \"r1\"),\n            _create_batch_response(1, \"r2\"),\n        ])\n    )\n\n    blob1 = mock.create_autospec(gb.storage.Blob, instance=True)\n    blob1.name = f\"out1{gb._EXT_JSONL}\"\n    blob1.open.return_value.__enter__.return_value = io.StringIO(\n        \"\\n\".join([\n            _create_batch_response(0, \"r3\"),\n            _create_batch_response(1, \"r4\"),\n        ])\n    )\n\n    blob2 = mock.create_autospec(gb.storage.Blob, instance=True)\n    blob2.name = f\"out2{gb._EXT_JSONL}\"\n    blob2.open.return_value.__enter__.return_value = io.StringIO(\n        _create_batch_response(0, \"r5\")\n    )\n\n    def list_blobs_side_effect(prefix=None):\n      if \"part-0\" in prefix:\n        return [blob0]\n      if \"part-1\" in prefix:\n        return [blob1]\n      if \"part-2\" in prefix:\n        return [blob2]\n      return []\n\n    self.mock_bucket.list_blobs.side_effect = list_blobs_side_effect\n\n    # Setup mock jobs\n    job0 = create_mock_batch_job(gcs_uri=\"gs://b/batch-input/part-0/out\")\n    job1 = create_mock_batch_job(gcs_uri=\"gs://b/batch-input/part-1/out\")\n    job2 = create_mock_batch_job(gcs_uri=\"gs://b/batch-input/part-2/out\")\n\n    mock_client.batches.create.side_effect = [job0, job1, job2]\n    mock_client.batches.get.side_effect = [job0, job1, job2]\n\n    model = 
gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        batch={\n            \"enabled\": True,\n            \"threshold\": 1,\n            \"max_prompts_per_job\": max_prompts_per_job,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n\n    results = list(model.infer(prompts))\n\n    self.assertEqual(mock_client.batches.create.call_count, 3)\n    self.assertListEqual(\n        [r[0].output for r in results], [\"r1\", \"r2\", \"r3\", \"r4\", \"r5\"]\n    )\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_batch_item_error(self, mock_client_cls):\n    \"\"\"Test that batch item errors raise exception.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n\n    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n    output_blob.name = f\"output{gb._EXT_JSONL}\"\n    output_blob.open.return_value.__enter__.return_value = io.StringIO(\n        _create_batch_error(0, 13, \"Internal error\")\n    )\n    self.mock_bucket.list_blobs.return_value = [output_blob]\n\n    job = create_mock_batch_job()\n    mock_client.batches.create.return_value = job\n    mock_client.batches.get.return_value = job\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=\"p\",\n        location=\"l\",\n        batch={\n            \"enabled\": True,\n            \"threshold\": 1,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n\n    with self.assertRaisesRegex(Exception, \"Batch item error\"):\n      list(model.infer([\"test\"]))\n\n\nclass BatchConfigValidationTest(parameterized.TestCase):\n  \"\"\"Test BatchConfig validation logic.\"\"\"\n\n  @parameterized.named_parameters(\n      dict(testcase_name=\"threshold_lt_1\", threshold=0),\n      
dict(testcase_name=\"poll_interval_le_0\", poll_interval=0),\n      dict(testcase_name=\"timeout_le_0\", timeout=0),\n      dict(testcase_name=\"max_prompts_per_job_le_0\", max_prompts_per_job=0),\n  )\n  def test_validation_errors(self, **overrides):\n    \"\"\"Verify validation errors for invalid config values.\"\"\"\n    with self.assertRaises(ValueError):\n      gb.BatchConfig(**overrides)\n\n\nclass EmptyAndPaddingTest(absltest.TestCase):\n  \"\"\"Test empty prompt handling and result padding/trimming.\"\"\"\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_empty_prompts_fast_path(self, mock_client_cls):\n    \"\"\"Verify empty prompts return immediately without API calls.\"\"\"\n    outs = gb.infer_batch(\n        client=mock_client_cls.return_value,\n        model_id=\"m\",\n        prompts=[],\n        schema_dict=None,\n        gen_config={},\n        cfg=gb.BatchConfig(\n            enabled=True,\n            poll_interval=1,\n            enable_caching=False,\n            retention_days=None,\n        ),\n    )\n    self.assertEqual(outs, [])\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_file_pad_to_expected_count(self, mock_client_cls):\n    \"\"\"Verify padding to maintain 1:1 alignment with input prompts.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n\n    with mock.patch.object(gb.storage, \"Client\", autospec=True) as mock_storage:\n      mock_bucket = mock_storage.return_value.bucket.return_value\n      output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n      output_blob.name = f\"output{gb._EXT_JSONL}\"\n      output_blob.open.return_value.__enter__.return_value = io.StringIO(\n          _create_batch_response(0, \"only_one\")\n      )\n      mock_bucket.list_blobs.return_value = [output_blob]\n\n      job = create_mock_batch_job()\n      mock_client.batches.create.return_value = job\n      mock_client.batches.get.return_value = job\n\n      cfg 
= gb.BatchConfig(\n          enabled=True,\n          threshold=1,\n          poll_interval=1,\n          enable_caching=False,\n          retention_days=None,\n      )\n      outs = gb.infer_batch(\n          client=mock_client,\n          model_id=\"m\",\n          prompts=[\"p1\", \"p2\"],\n          schema_dict=None,\n          gen_config={},\n          cfg=cfg,\n      )\n      self.assertEqual(outs, [\"only_one\", \"\"])  # padded\n\n\nclass GCSBatchCachingTest(absltest.TestCase):\n  \"\"\"Test GCS batch caching functionality.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self.mock_storage_cls = self.enter_context(\n        mock.patch.object(gb.storage, \"Client\", autospec=True)\n    )\n    self.mock_storage_client = self.mock_storage_cls.return_value\n    self.mock_bucket = self.mock_storage_client.bucket.return_value\n    self.mock_blob = self.mock_bucket.blob.return_value\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_cache_hit_skips_inference(self, mock_client_cls):\n    \"\"\"Test that fully cached prompts skip inference.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n    mock_client.project = \"p\"\n    mock_client.location = \"l\"\n\n    self.mock_blob.download_as_text.return_value = '{\"text\": \"cached_response\"}'\n\n    cfg = gb.BatchConfig(\n        enabled=True,\n        threshold=1,\n        enable_caching=True,\n        retention_days=None,\n    )\n\n    outs = gb.infer_batch(\n        client=mock_client,\n        model_id=\"m\",\n        prompts=[\"p1\"],\n        schema_dict=None,\n        gen_config={},\n        cfg=cfg,\n    )\n\n    self.assertListEqual(outs, [\"cached_response\"])\n\n    mock_client.batches.create.assert_not_called()\n\n    self.mock_bucket.blob.assert_called()\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  def test_partial_cache_hit(self, mock_client_cls):\n    \"\"\"Test that partial cache hits only submit missing prompts.\"\"\"\n    
mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n    mock_client.project = \"p\"\n    mock_client.location = \"l\"\n\n    # Mock GCS cache: hit for \"cached_prompt\", miss for \"new_prompt\"\n    # We mock _compute_hash to avoid dealing with complex hashing in test\n    with mock.patch.object(gb.GCSBatchCache, \"_compute_hash\") as mock_hash:\n      mock_hash.side_effect = lambda k: f\"hash_{k['prompt']}\"\n\n      # Pre-configure blobs\n      blob_hit = mock.create_autospec(gb.storage.Blob, instance=True)\n      blob_hit.download_as_text.return_value = '{\"text\": \"cached_response\"}'\n\n      blob_miss = mock.create_autospec(gb.storage.Blob, instance=True)\n      blob_miss.download_as_text.side_effect = exceptions.NotFound(\"Not found\")\n\n      def get_blob(name):\n        if \"hash_cached_prompt\" in name:\n          return blob_hit\n        return blob_miss\n\n      self.mock_bucket.blob.side_effect = get_blob\n\n      # Mock list_blobs to return the batch output file for the new prompt\n      output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n      output_blob.name = f\"output{gb._EXT_JSONL}\"\n      output_blob.open.return_value.__enter__.return_value = io.StringIO(\n          _create_batch_response(0, \"new_response\")\n      )\n      self.mock_bucket.list_blobs.return_value = [output_blob]\n\n      job = create_mock_batch_job()\n      mock_client.batches.create.return_value = job\n      mock_client.batches.get.return_value = job\n\n      cfg = gb.BatchConfig(\n          enabled=True,\n          threshold=1,\n          enable_caching=True,\n          retention_days=None,\n      )\n\n      outs = gb.infer_batch(\n          client=mock_client,\n          model_id=\"m\",\n          prompts=[\"cached_prompt\", \"new_prompt\"],\n          schema_dict=None,\n          gen_config={},\n          cfg=cfg,\n      )\n\n      self.assertListEqual(outs, [\"cached_response\", \"new_response\"])\n      
mock_client.batches.create.assert_called_once()\n\n      # Verify \"new_response\" was uploaded to cache (using the miss blob)\n      # The blob used for upload is blob_miss because it was returned for the miss key\n      upload_calls = [\n          call\n          for call in blob_miss.upload_from_string.mock_calls\n          if \"new_response\" in str(call)\n      ]\n      self.assertTrue(\n          upload_calls, \"Should have uploaded new_response to cache\"\n      )\n\n  @mock.patch.object(genai, \"Client\", autospec=True)\n  @mock.patch.dict(\"os.environ\", {}, clear=True)\n  def test_project_passed_to_storage_client(self, mock_client_cls):\n    \"\"\"Test that project parameter is passed to storage.Client constructor.\"\"\"\n    mock_client = mock_client_cls.return_value\n    mock_client.vertexai = True\n    if hasattr(mock_client, \"project\"):\n      del mock_client.project\n\n    self.mock_storage_client.create_bucket.return_value = self.mock_bucket\n\n    output_blob = mock.create_autospec(gb.storage.Blob, instance=True)\n    output_blob.name = f\"output{gb._EXT_JSONL}\"\n    output_blob.open.return_value.__enter__.return_value = io.StringIO(\n        _create_batch_response(0, {\"result\": \"ok\"})\n    )\n    self.mock_bucket.list_blobs.return_value = [output_blob]\n\n    mock_client.batches.create.return_value = create_mock_batch_job()\n    mock_client.batches.get.return_value = create_mock_batch_job()\n\n    # Create model with specific project and location\n    test_project = \"test-project-123\"\n    test_location = \"us-central1\"\n\n    model = gemini.GeminiLanguageModel(\n        model_id=\"gemini-2.5-flash\",\n        vertexai=True,\n        project=test_project,\n        location=test_location,\n        batch={\n            \"enabled\": True,\n            \"threshold\": 1,\n            \"poll_interval\": 0.1,\n            \"enable_caching\": False,\n            \"retention_days\": None,\n        },\n    )\n\n    list(model.infer([\"test 
prompt\"]))\n\n    # Verify storage.Client was called with the correct project parameter.\n    storage_calls = self.mock_storage_cls.call_args_list\n\n    project_calls = [\n        call\n        for call in storage_calls\n        if call.kwargs.get(\"project\") == test_project\n    ]\n\n    self.assertGreaterEqual(\n        len(project_calls),\n        1,\n        f\"storage.Client should be called with project={test_project}, \"\n        f\"but was called with: {[call.kwargs for call in storage_calls]}\",\n    )\n\n  def test_cache_hashing_stability(self):\n    \"\"\"Test that hash is stable for same inputs.\"\"\"\n    cache = gb.GCSBatchCache(\"b\")\n    data1 = {\"a\": 1, \"b\": 2}\n    data2 = {\"b\": 2, \"a\": 1}\n    self.assertEqual(cache._compute_hash(data1), cache._compute_hash(data2))\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/test_kwargs_passthrough.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for enhanced kwargs pass-through in providers.\"\"\"\n\nimport unittest\nfrom unittest import mock\nimport warnings\n\nfrom absl.testing import parameterized\n\nfrom langextract.providers import ollama\nfrom langextract.providers import openai\n\n\nclass TestOpenAIKwargsPassthrough(unittest.TestCase):\n  \"\"\"Test OpenAI provider's enhanced kwargs handling.\"\"\"\n\n  @mock.patch('openai.OpenAI')\n  def test_reasoning_effort_alias_normalization(self, mock_openai_class):\n    \"\"\"Reasoning_effort parameter should be normalized to {reasoning: {effort: ...}}.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-4o-mini',\n        api_key='test-key',\n        reasoning_effort='minimal',\n    )\n\n    list(model.infer(['test prompt']))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs.get('reasoning'), {'effort': 'minimal'})\n\n  @mock.patch('openai.OpenAI')\n  def test_reasoning_parameter_normalized(self, mock_openai_class):\n    \"\"\"Runtime reasoning_effort should normalize even without 
constructor param.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-5-nano',\n        api_key='test-key',\n    )\n\n    list(model.infer(['test prompt'], reasoning_effort='maximal'))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs.get('reasoning'), {'effort': 'maximal'})\n\n  @mock.patch('openai.OpenAI')\n  def test_runtime_kwargs_override_stored(self, mock_openai_class):\n    \"\"\"Runtime parameters should override constructor parameters.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-4o-mini',\n        api_key='test-key',\n        temperature=0.7,\n        top_p=0.9,\n    )\n\n    list(model.infer(['test prompt'], temperature=0.3, seed=42))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs.get('temperature'), 0.3)\n    self.assertEqual(call_args.kwargs.get('top_p'), 0.9)\n    self.assertEqual(call_args.kwargs.get('seed'), 42)\n\n  @mock.patch('openai.OpenAI')\n  def test_falsy_values_preserved(self, mock_openai_class):\n    \"\"\"Falsy values like 0 should be preserved, not filtered as None.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    
mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-4o',\n        api_key='test-key',\n        temperature=0,\n        top_logprobs=0,\n    )\n\n    list(model.infer(['test prompt']))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs.get('temperature'), 0)\n    self.assertEqual(call_args.kwargs.get('top_logprobs'), 0)\n\n  @mock.patch('openai.OpenAI')\n  def test_both_reasoning_forms_merge(self, mock_openai_class):\n    \"\"\"Both reasoning and reasoning_effort should merge without clobbering.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-5',\n        api_key='test-key',\n        reasoning={'other_field': 'value'},\n        reasoning_effort='maximal',\n    )\n\n    list(model.infer(['test prompt']))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(\n        call_args.kwargs.get('reasoning'),\n        {'other_field': 'value', 'effort': 'maximal'},\n    )\n\n  @mock.patch('openai.OpenAI')\n  def test_custom_response_format(self, mock_openai_class):\n    \"\"\"Custom response_format should override default JSON format.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-4o',\n        api_key='test-key',\n        format_type=openai.data.FormatType.JSON,\n    )\n\n    list(\n  
      model.infer(\n            ['test prompt'],\n            response_format={'type': 'text', 'schema': 'custom'},\n        )\n    )\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(\n        call_args.kwargs.get('response_format'),\n        {'type': 'text', 'schema': 'custom'},\n    )\n\n  @mock.patch('openai.OpenAI')\n  def test_direct_reasoning_parameter(self, mock_openai_class):\n    \"\"\"Direct reasoning parameter should pass through without modification.\"\"\"\n    mock_client = mock.Mock()\n    mock_openai_class.return_value = mock_client\n\n    mock_response = mock.Mock()\n    mock_response.choices = [\n        mock.Mock(message=mock.Mock(content='{\"result\": \"test\"}'))\n    ]\n    mock_client.chat.completions.create.return_value = mock_response\n\n    model = openai.OpenAILanguageModel(\n        model_id='gpt-5',\n        api_key='test-key',\n    )\n\n    list(model.infer(['test prompt'], reasoning={'effort': 'minimal'}))\n\n    call_args = mock_client.chat.completions.create.call_args\n    self.assertEqual(call_args.kwargs.get('reasoning'), {'effort': 'minimal'})\n\n\nclass TestOllamaAuthSupport(parameterized.TestCase):\n  \"\"\"Test Ollama provider's authentication support for proxied instances.\"\"\"\n\n  @mock.patch('requests.post')\n  def test_api_key_in_authorization_header(self, mock_post):\n    \"\"\"API key should be sent in Authorization header with Bearer scheme.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id='gemma2:2b',\n        model_url='https://proxy.example.com',\n        api_key='sk-test-key-123',\n    )\n\n    list(model.infer(['test prompt']))\n\n    mock_post.assert_called_once()\n    call_args = mock_post.call_args\n    headers = call_args.kwargs.get('headers', {})\n    
self.assertEqual(headers.get('Authorization'), 'Bearer sk-test-key-123')\n    self.assertEqual(headers.get('Content-Type'), 'application/json')\n\n  @mock.patch('requests.post')\n  def test_custom_auth_header_name(self, mock_post):\n    \"\"\"Custom auth header name (e.g. X-API-Key) should be supported.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id='gemma2:2b',\n        model_url='https://api.example.com',\n        api_key='abc123',\n        auth_header='X-API-Key',\n        auth_scheme='',\n    )\n\n    list(model.infer(['test prompt']))\n\n    headers = mock_post.call_args.kwargs.get('headers', {})\n    self.assertEqual(headers.get('X-API-Key'), 'abc123')\n    self.assertNotIn('Authorization', headers)\n\n  @mock.patch('requests.post')\n  def test_pass_through_kwargs(self, mock_post):\n    \"\"\"Future Ollama parameters should pass through without code changes.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id='mistral:7b',\n        temperature=0.5,\n        top_k=40,\n        repeat_penalty=1.1,\n        mirostat=2,\n    )\n\n    list(model.infer(['test prompt']))\n\n    call_args = mock_post.call_args\n    payload = call_args.kwargs['json']\n    options = payload['options']\n\n    self.assertEqual(options.get('temperature'), 0.5)\n    self.assertEqual(options.get('top_k'), 40)\n    self.assertEqual(options.get('repeat_penalty'), 1.1)\n    self.assertEqual(options.get('mirostat'), 2)\n\n  def test_api_key_redacted_in_repr(self):\n    \"\"\"API key should be redacted in string representation for security.\"\"\"\n    model = 
ollama.OllamaLanguageModel(\n        model_id='gemma2:2b',\n        api_key='super-secret-key',\n    )\n\n    repr_str = repr(model)\n    self.assertIn('[REDACTED]', repr_str, 'API key should be redacted')\n    self.assertNotIn(\n        'super-secret-key', repr_str, 'Actual API key should not appear'\n    )\n\n  @mock.patch('requests.post')\n  def test_localhost_auth_warning_but_still_works(self, mock_post):\n    \"\"\"Should warn about localhost auth but still send the auth header.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter('always')\n      model = ollama.OllamaLanguageModel(\n          model_id='gemma2:2b',\n          model_url='http://localhost:11434',\n          api_key='unnecessary-key',\n      )\n\n      self.assertTrue(\n          any('localhost' in str(warning.message) for warning in w),\n          'Expected warning about localhost auth',\n      )\n\n    # Verify auth header is still sent despite warning\n    list(model.infer(['test prompt']))\n    headers = mock_post.call_args.kwargs.get('headers', {})\n    self.assertEqual(headers.get('Authorization'), 'Bearer unnecessary-key')\n\n  @mock.patch('requests.post')\n  def test_runtime_kwargs_override(self, mock_post):\n    \"\"\"Runtime parameters should override constructor parameters.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id='gemma2:2b',\n        temperature=0.7,\n        timeout=60,\n    )\n\n    list(model.infer(['test prompt'], temperature=0.3, timeout=120))\n\n    call_args = mock_post.call_args\n    payload = call_args.kwargs['json']\n    options = 
payload['options']\n\n    self.assertEqual(options.get('temperature'), 0.3)\n    self.assertEqual(call_args.kwargs.get('timeout'), 120)\n\n  @parameterized.named_parameters(\n      ('https_localhost', 'https://localhost:11434', True),\n      ('ipv6_localhost', 'http://[::1]:11434', True),\n      ('ipv4_localhost', 'http://127.0.0.1:8080/', True),\n      ('remote_proxy', 'https://proxy.example.com', False),\n  )\n  @mock.patch('requests.post')\n  def test_localhost_detection(self, url, should_warn, mock_post):\n    \"\"\"Should detect localhost in various URL formats (IPv6, https, etc).\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter('always')\n      _ = ollama.OllamaLanguageModel(\n          model_id='gemma2:2b',\n          model_url=url,\n          api_key='test-key',\n      )\n\n      if should_warn:\n        self.assertTrue(\n            any('localhost' in str(warning.message) for warning in w),\n            f'Expected warning for {url}',\n        )\n      else:\n        self.assertFalse(\n            any('localhost' in str(warning.message) for warning in w),\n            f'Unexpected warning for {url}',\n        )\n\n  @mock.patch('requests.post')\n  def test_format_none_not_in_payload(self, mock_post):\n    \"\"\"Format key should be omitted from payload when None (not sent as null).\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': 'plain text'}\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id='gemma2:2b',\n    )\n\n    model.format_type = None\n\n    _ = model._ollama_query(\n        prompt='test prompt',\n        model='gemma2:2b',\n        structured_output_format=None,\n    )\n\n    call_args = 
mock_post.call_args\n    payload = call_args.kwargs['json']\n\n    self.assertNotIn('format', payload, 'format=None should not be in payload')\n\n  @mock.patch('requests.post')\n  def test_reserved_kwargs_not_in_options(self, mock_post):\n    \"\"\"Reserved top-level keys (stop, format) should not go into options dict.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    model = ollama.OllamaLanguageModel(\n        model_id='gemma2:2b',\n        stop=['END'],\n        temperature=0.5,\n        custom_param='value',\n    )\n\n    list(model.infer(['test prompt']))\n\n    call_args = mock_post.call_args\n    payload = call_args.kwargs['json']\n    options = payload['options']\n\n    self.assertEqual(payload.get('stop'), ['END'])\n    self.assertNotIn(\n        'stop', options, 'stop should be at top level, not in options'\n    )\n    self.assertEqual(options.get('temperature'), 0.5)\n    self.assertEqual(options.get('custom_param'), 'value')\n\n  @mock.patch('requests.post')\n  def test_api_key_without_localhost_warning(self, mock_post):\n    \"\"\"Should not warn when using auth with remote/proxied Ollama instances.\"\"\"\n    mock_response = mock.Mock()\n    mock_response.status_code = 200\n    mock_response.json.return_value = {'response': '{\"test\": \"value\"}'}\n    mock_post.return_value = mock_response\n\n    with warnings.catch_warnings(record=True) as w:\n      warnings.simplefilter('always')\n      model = ollama.OllamaLanguageModel(\n          model_id='gemma2:2b',\n          model_url='https://proxy.example.com',\n          api_key='necessary-key',\n      )\n\n      self.assertFalse(\n          any('localhost' in str(warning.message) for warning in w)\n      )\n\n    list(model.infer(['test prompt']))\n    headers = mock_post.call_args.kwargs.get('headers', {})\n    
self.assertEqual(headers.get('Authorization'), 'Bearer necessary-key')\n\n\nif __name__ == '__main__':\n  unittest.main()\n"
  },
  {
    "path": "tests/test_live_api.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Live API integration tests that require real API keys.\n\nThese tests are skipped if API keys are not available in the environment.\nThey should run in CI after all other tests pass.\n\"\"\"\n\nimport functools\nimport json\nimport os\nimport re\nimport textwrap\nimport time\nfrom typing import Any\nimport unittest\nfrom unittest import mock\nimport uuid\n\nimport dotenv\nimport google.auth\nimport google.auth.exceptions\nimport google.genai.errors\nimport pytest\n\nfrom langextract import data\nimport langextract as lx\nfrom langextract.core import tokenizer as tokenizer_lib\nfrom langextract.providers import gemini_batch as gb\n\ndotenv.load_dotenv(override=True)\n\nDEFAULT_GEMINI_MODEL = \"gemini-2.5-flash\"\nDEFAULT_OPENAI_MODEL = \"gpt-4o\"\n\nGEMINI_API_KEY = os.environ.get(\"GEMINI_API_KEY\") or os.environ.get(\n    \"LANGEXTRACT_API_KEY\"\n)\nOPENAI_API_KEY = os.environ.get(\"OPENAI_API_KEY\")\n\nVERTEX_PROJECT = os.environ.get(\"VERTEX_PROJECT\") or os.environ.get(\n    \"GOOGLE_CLOUD_PROJECT\"\n)\nVERTEX_LOCATION = os.environ.get(\"VERTEX_LOCATION\", \"us-central1\")\n\n\ndef has_vertex_ai_credentials():\n  \"\"\"Check if Vertex AI credentials are available.\"\"\"\n  if not VERTEX_PROJECT:\n    return False\n  try:\n    credentials, _ = google.auth.default()\n    return credentials is not None\n  except (ImportError, 
google.auth.exceptions.DefaultCredentialsError):\n    return False\n\n\nskip_if_no_gemini = pytest.mark.skipif(\n    not GEMINI_API_KEY,\n    reason=(\n        \"Gemini API key not available (set GEMINI_API_KEY or\"\n        \" LANGEXTRACT_API_KEY)\"\n    ),\n)\nskip_if_no_openai = pytest.mark.skipif(\n    not OPENAI_API_KEY,\n    reason=\"OpenAI API key not available (set OPENAI_API_KEY)\",\n)\nskip_if_no_vertex = pytest.mark.skipif(\n    not has_vertex_ai_credentials(),\n    reason=(\n        \"Vertex AI credentials not available (set GOOGLE_CLOUD_PROJECT and\"\n        \" configure gcloud auth)\"\n    ),\n)\n\nlive_api = pytest.mark.live_api\n\nGEMINI_MODEL_PARAMS = {\n    \"temperature\": 0.0,\n    \"top_p\": 0.0,\n    \"max_output_tokens\": 256,\n}\n\nOPENAI_MODEL_PARAMS = {\n    \"temperature\": 0.0,\n}\n\n# Extraction Classes\n_CLASS_MEDICATION = \"medication\"\n_CLASS_DOSAGE = \"dosage\"\n_CLASS_ROUTE = \"route\"\n_CLASS_FREQUENCY = \"frequency\"\n_CLASS_DURATION = \"duration\"\n_CLASS_CONDITION = \"condition\"\n\nINITIAL_RETRY_DELAY = 1.0\nMAX_RETRY_DELAY = 8.0\n\n\ndef retry_on_transient_errors(max_retries=3, backoff_factor=2.0):\n  \"\"\"Decorator to retry tests on transient API errors with exponential backoff.\n\n  Args:\n    max_retries (int): Maximum number of retry attempts\n    backoff_factor (float): Multiplier for exponential backoff (e.g., 2.0 = 1s, 2s, 4s)\n  \"\"\"\n\n  def decorator(func):\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n      last_exception = None\n      delay = INITIAL_RETRY_DELAY\n\n      for attempt in range(max_retries + 1):\n        try:\n          return func(*args, **kwargs)\n        except (\n            lx.exceptions.LangExtractError,\n            google.genai.errors.ClientError,\n            ConnectionError,\n            TimeoutError,\n            OSError,\n            RuntimeError,\n        ) as e:\n          last_exception = e\n          if attempt < max_retries:\n            print(\n                
f\"\\nRetryable error ({type(e).__name__}) on attempt\"\n                f\" {attempt + 1}/{max_retries + 1}: {e}\"\n            )\n            time.sleep(delay)\n            delay = min(delay * backoff_factor, MAX_RETRY_DELAY)\n            continue\n\n          raise\n\n      raise last_exception\n\n    return wrapper\n\n  return decorator\n\n\n@pytest.fixture(autouse=True)\ndef add_delay_between_tests():\n  \"\"\"Add a small delay between tests to avoid rate limiting.\"\"\"\n  yield\n  time.sleep(0.5)\n\n\ndef get_basic_medication_examples():\n  \"\"\"Get example data for basic medication extraction.\"\"\"\n  return [\n      lx.data.ExampleData(\n          text=\"Patient was given 250 mg IV Cefazolin TID for one week.\",\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=_CLASS_DOSAGE, extraction_text=\"250 mg\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_ROUTE, extraction_text=\"IV\"\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_MEDICATION,\n                  extraction_text=\"Cefazolin\",\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_FREQUENCY,\n                  extraction_text=\"TID\",  # TID = three times a day\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_DURATION,\n                  extraction_text=\"for one week\",\n              ),\n          ],\n      )\n  ]\n\n\ndef get_relationship_examples():\n  \"\"\"Get example data for medication relationship extraction.\"\"\"\n  return [\n      lx.data.ExampleData(\n          text=(\n              \"Patient takes Aspirin 100mg daily for heart health and\"\n              \" Simvastatin 20mg at bedtime.\"\n          ),\n          extractions=[\n              # First medication group\n              lx.data.Extraction(\n                  extraction_class=_CLASS_MEDICATION,\n 
                 extraction_text=\"Aspirin\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_DOSAGE,\n                  extraction_text=\"100mg\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_FREQUENCY,\n                  extraction_text=\"daily\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_CONDITION,\n                  extraction_text=\"heart health\",\n                  attributes={\"medication_group\": \"Aspirin\"},\n              ),\n              # Second medication group\n              lx.data.Extraction(\n                  extraction_class=_CLASS_MEDICATION,\n                  extraction_text=\"Simvastatin\",\n                  attributes={\"medication_group\": \"Simvastatin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_DOSAGE,\n                  extraction_text=\"20mg\",\n                  attributes={\"medication_group\": \"Simvastatin\"},\n              ),\n              lx.data.Extraction(\n                  extraction_class=_CLASS_FREQUENCY,\n                  extraction_text=\"at bedtime\",\n                  attributes={\"medication_group\": \"Simvastatin\"},\n              ),\n          ],\n      )\n  ]\n\n\ndef extract_by_class(result, extraction_class):\n  \"\"\"Helper to extract entities by class.\n\n  Returns a set of extraction texts for the given class.\n  \"\"\"\n  return {\n      e.extraction_text\n      for e in result.extractions\n      if e.extraction_class == extraction_class\n  }\n\n\ndef assert_extractions_contain(test_case, result, expected_classes):\n  \"\"\"Assert that result contains all expected extraction classes.\n\n  Uses 
unittest assertions for richer error messages.\n  \"\"\"\n  actual_classes = {e.extraction_class for e in result.extractions}\n  missing_classes = expected_classes - actual_classes\n  test_case.assertFalse(\n      missing_classes,\n      f\"Missing expected classes: {missing_classes}. Found extractions:\"\n      f\" {[f'{e.extraction_class}:{e.extraction_text}' for e in result.extractions]}\",\n  )\n\n\ndef assert_valid_char_intervals(test_case, result):\n  \"\"\"Assert that all extractions have valid char intervals and alignment status.\"\"\"\n  for extraction in result.extractions:\n    test_case.assertIsNotNone(\n        extraction.char_interval,\n        f\"Missing char_interval for extraction: {extraction.extraction_text}\",\n    )\n    test_case.assertIsNotNone(\n        extraction.alignment_status,\n        \"Missing alignment_status for extraction:\"\n        f\" {extraction.extraction_text}\",\n    )\n    if isinstance(result, lx.data.AnnotatedDocument) and result.text:\n      text_length = len(result.text)\n      test_case.assertGreaterEqual(\n          extraction.char_interval.start_pos,\n          0,\n          f\"Invalid start_pos for extraction: {extraction.extraction_text}\",\n      )\n      test_case.assertLessEqual(\n          extraction.char_interval.end_pos,\n          text_length,\n          f\"Invalid end_pos for extraction: {extraction.extraction_text}\",\n      )\n\n\nclass TestLiveAPIGemini(unittest.TestCase):\n  \"\"\"Tests using real Gemini API.\"\"\"\n\n  def _check_cached_result(self, result_json: dict[str, Any]) -> bool:\n    \"\"\"Check if cached result contains expected medication data.\n\n    Args:\n      result_json: The raw JSON dict from the cache file.\n                   Expected format: {\"text\": \"JSON_STRING_OF_RESULT\"}\n\n    Returns:\n      True if the result contains valid medication extractions, False otherwise.\n    \"\"\"\n    try:\n      text_content = result_json.get(\"text\")\n      if not isinstance(text_content, 
str):\n        return False\n\n      inner_json = json.loads(text_content)\n      if not isinstance(inner_json, dict):\n        return False\n\n      extractions_data = inner_json.get(data.EXTRACTIONS_KEY)\n      if not isinstance(extractions_data, list):\n        return False\n\n      extractions = []\n      for item in extractions_data:\n        if isinstance(item, dict):\n          clean_item = {k: v for k, v in item.items() if not k.startswith(\"_\")}\n          extractions.append(data.Extraction(**clean_item))\n\n      doc = data.AnnotatedDocument(\n          text=inner_json.get(\"text\"), extractions=extractions\n      )\n\n      if not doc.extractions:\n        return False\n\n      # Check for specific content\n      medication_texts = extract_by_class(doc, _CLASS_MEDICATION)\n      dosage_texts = extract_by_class(doc, _CLASS_DOSAGE)\n\n      has_lisinopril = any(\"Lisinopril\" in t for t in medication_texts)\n      has_10mg = any(\"10mg\" in t for t in dosage_texts)\n\n      return has_lisinopril and has_10mg\n\n    except (json.JSONDecodeError, TypeError, ValueError):\n      return False\n\n  def _verify_gcs_cache_content(self, bucket_name):\n    \"\"\"Verify that GCS cache contains expected structured results.\"\"\"\n    cache = gb.GCSBatchCache(bucket_name, project=VERTEX_PROJECT)\n    found_content = False\n\n    # Use iter_items() to check cache content\n    items = list(cache.iter_items())\n    self.assertTrue(len(items) > 0, \"No cache files found in GCS bucket\")\n\n    for _, text in items:\n      try:\n        result_json = json.loads(text)\n        if self._check_cached_result(result_json):\n          found_content = True\n          break\n      except (json.JSONDecodeError, TypeError, ValueError):\n        continue\n\n    self.assertTrue(\n        found_content,\n        \"Could not find expected structured result in GCS cache files\",\n    )\n\n  @skip_if_no_gemini\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def 
test_medication_extraction(self):\n    \"\"\"Test medication extraction with entities in order.\"\"\"\n    prompt = textwrap.dedent(\"\"\"\\\n        Extract medication information including medication name, dosage, route, frequency,\n        and duration in the order they appear in the text.\"\"\")\n\n    examples = get_basic_medication_examples()\n    input_text = \"Patient took 400 mg PO Ibuprofen q4h for two days.\"\n\n    result = lx.extract(\n        text_or_documents=input_text,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_GEMINI_MODEL,\n        api_key=GEMINI_API_KEY,\n        language_model_params=GEMINI_MODEL_PARAMS,\n    )\n\n    assert result is not None\n    self.assertIsInstance(result, lx.data.AnnotatedDocument)\n    assert len(result.extractions) > 0\n\n    expected_classes = {\n        _CLASS_DOSAGE,\n        _CLASS_ROUTE,\n        _CLASS_MEDICATION,\n        _CLASS_FREQUENCY,\n        _CLASS_DURATION,\n    }\n    assert_extractions_contain(self, result, expected_classes)\n    assert_valid_char_intervals(self, result)\n\n    # Using regex for precise matching to avoid false positives\n    medication_texts = extract_by_class(result, _CLASS_MEDICATION)\n    self.assertTrue(\n        any(\n            re.search(r\"\\bIbuprofen\\b\", text, re.IGNORECASE)\n            for text in medication_texts\n        ),\n        f\"No Ibuprofen found in: {medication_texts}\",\n    )\n\n    dosage_texts = extract_by_class(result, _CLASS_DOSAGE)\n    self.assertTrue(\n        any(\n            re.search(r\"\\b400\\s*mg\\b\", text, re.IGNORECASE)\n            for text in dosage_texts\n        ),\n        f\"No 400mg dosage found in: {dosage_texts}\",\n    )\n\n    route_texts = extract_by_class(result, _CLASS_ROUTE)\n    self.assertTrue(\n        any(\n            re.search(r\"\\b(PO|oral)\\b\", text, re.IGNORECASE)\n            for text in route_texts\n        ),\n        f\"No PO/oral route found in: {route_texts}\",\n    
)\n\n  @skip_if_no_gemini\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def test_multilingual_medication_extraction(self):\n    \"\"\"Test medication extraction with Japanese text.\"\"\"\n    text = (  # \"The patient takes 10 mg of medication daily.\"\n        \"患者は毎日10mgの薬を服用します。\"\n    )\n\n    prompt = \"Extract medication information including dosage and frequency.\"\n\n    examples = [\n        lx.data.ExampleData(\n            text=\"The patient takes 20mg of aspirin twice daily.\",\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=_CLASS_MEDICATION,\n                    extraction_text=\"aspirin\",\n                    attributes={\n                        _CLASS_DOSAGE: \"20mg\",\n                        _CLASS_FREQUENCY: \"twice daily\",\n                    },\n                ),\n            ],\n        )\n    ]\n\n    unicode_tokenizer = tokenizer_lib.UnicodeTokenizer()\n\n    result = lx.extract(\n        text_or_documents=text,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_GEMINI_MODEL,\n        api_key=GEMINI_API_KEY,\n        language_model_params=GEMINI_MODEL_PARAMS,\n        tokenizer=unicode_tokenizer,\n    )\n\n    assert result is not None\n    self.assertIsInstance(result, lx.data.AnnotatedDocument)\n    assert len(result.extractions) > 0\n\n    medication_extractions = [\n        e for e in result.extractions if e.extraction_class == _CLASS_MEDICATION\n    ]\n    assert (\n        len(medication_extractions) > 0\n    ), \"No medication entities found in Japanese text\"\n    assert_valid_char_intervals(self, result)\n\n  @skip_if_no_gemini\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def test_explicit_provider_gemini(self):\n    \"\"\"Test using explicit provider with Gemini.\"\"\"\n    config = lx.factory.ModelConfig(\n        model_id=DEFAULT_GEMINI_MODEL,\n        provider=\"GeminiLanguageModel\",\n        
provider_kwargs={\n            \"api_key\": GEMINI_API_KEY,\n            \"temperature\": 0.0,\n        },\n    )\n\n    model = lx.factory.create_model(config)\n    self.assertEqual(model.__class__.__name__, \"GeminiLanguageModel\")\n    self.assertEqual(model.model_id, DEFAULT_GEMINI_MODEL)\n\n    config2 = lx.factory.ModelConfig(\n        model_id=DEFAULT_GEMINI_MODEL,\n        provider=\"gemini\",\n        provider_kwargs={\n            \"api_key\": GEMINI_API_KEY,\n        },\n    )\n\n    model2 = lx.factory.create_model(config2)\n    self.assertEqual(model2.__class__.__name__, \"GeminiLanguageModel\")\n\n  @skip_if_no_gemini\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def test_medication_relationship_extraction(self):\n    \"\"\"Test relationship extraction for medications with Gemini.\"\"\"\n    input_text = \"\"\"\n    The patient was prescribed Lisinopril and Metformin last month.\n    He takes the Lisinopril 10mg daily for hypertension, but often misses\n    his Metformin 500mg dose which should be taken twice daily for diabetes.\n    \"\"\"\n\n    prompt = textwrap.dedent(\"\"\"\n        Extract medications with their details, using attributes to group related information:\n\n        1. Extract entities in the order they appear in the text\n        2. Each entity must have a 'medication_group' attribute linking it to its medication\n        3. 
All details about a medication should share the same medication_group value\n    \"\"\")\n\n    examples = get_relationship_examples()\n\n    result = lx.extract(\n        text_or_documents=input_text,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_GEMINI_MODEL,\n        api_key=GEMINI_API_KEY,\n        language_model_params=GEMINI_MODEL_PARAMS,\n    )\n\n    assert result is not None\n    assert len(result.extractions) > 0\n    assert_valid_char_intervals(self, result)\n\n    medication_groups = {}\n    for extraction in result.extractions:\n      assert (\n          extraction.attributes is not None\n      ), f\"Missing attributes for {extraction.extraction_text}\"\n      assert (\n          \"medication_group\" in extraction.attributes\n      ), f\"Missing medication_group for {extraction.extraction_text}\"\n\n      group_name = extraction.attributes[\"medication_group\"]\n      medication_groups.setdefault(group_name, []).append(extraction)\n\n    assert (\n        len(medication_groups) >= 2\n    ), f\"Expected at least 2 medications, found {len(medication_groups)}\"\n\n    # Allow flexible matching for dosage field (could be \"dosage\" or \"dose\")\n    for med_name, extractions in medication_groups.items():\n      extraction_classes = {e.extraction_class for e in extractions}\n      # At minimum, each group should have the medication itself\n      assert (\n          _CLASS_MEDICATION in extraction_classes\n      ), f\"{med_name} group missing medication entity\"\n      # Dosage is expected but might be formatted differently\n      assert any(\n          c in extraction_classes for c in [_CLASS_DOSAGE, \"dose\"]\n      ), f\"{med_name} group missing dosage\"\n\n  @skip_if_no_vertex\n  @live_api\n  @pytest.mark.vertex_ai\n  @mock.patch.object(gb, \"infer_batch\", wraps=gb.infer_batch, autospec=True)\n  def test_batch_extraction_vertex_gcs(self, mock_infer_batch):\n    \"\"\"Test extraction using Vertex AI Batch API 
with GCS.\n\n    This test runs a real Vertex AI Batch job and will take time to complete.\n    It is skipped unless VERTEX_PROJECT (or GOOGLE_CLOUD_PROJECT) is set and\n    default Google credentials are available.\n\n    We wrap `infer_batch` to verify that:\n    - Batch API is actually called (not falling back to real-time API)\n    - Schema dict is passed (non-None) to the batch function\n    \"\"\"\n\n    prompt = textwrap.dedent(\"\"\"\\\n        Extract medication information including medication name, dosage, route, frequency,\n        and duration in the order they appear in the text.\"\"\")\n\n    examples = get_basic_medication_examples()\n\n    documents = [\n        lx.data.Document(\n            document_id=\"vx_doc1\",\n            text=\"Patient took 400 mg PO Ibuprofen q4h for two days.\",\n        ),\n        lx.data.Document(\n            document_id=\"vx_doc2\",\n            text=\"Patient was given 250 mg IV Cefazolin TID for one week.\",\n        ),\n        lx.data.Document(\n            document_id=\"vx_doc3\",\n            text=\"Administered 2 mg IV Morphine once for acute pain.\",\n        ),\n        lx.data.Document(\n            document_id=\"vx_doc4\",\n            text=\"Prescribed 500 mg PO Amoxicillin BID for infection.\",\n        ),\n        lx.data.Document(\n            document_id=\"vx_doc5\",\n            text=\"Given 10 mg IM Haloperidol PRN for agitation.\",\n        ),\n    ]\n    expected_meds = [\n        \"Ibuprofen\",\n        \"Cefazolin\",\n        \"Morphine\",\n        \"Amoxicillin\",\n        \"Haloperidol\",\n    ]\n\n    language_model_params = dict(GEMINI_MODEL_PARAMS)\n    language_model_params[\"vertexai\"] = True\n    language_model_params[\"project\"] = VERTEX_PROJECT\n    language_model_params[\"location\"] = VERTEX_LOCATION\n    language_model_params[\"batch\"] = {\n        \"enabled\": True,\n        \"threshold\": 2,\n        \"poll_interval\": 1,  # Fast polling for test\n        \"timeout\": 900,  # 15 minutes for actual batch job completion\n    }\n\n    batch_result = 
lx.extract(\n        text_or_documents=documents,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_GEMINI_MODEL,\n        language_model_params=language_model_params,\n    )\n\n    mock_infer_batch.assert_called_once()\n    call_args = mock_infer_batch.call_args\n    schema_dict_arg = call_args.kwargs.get(\"schema_dict\")\n    self.assertIsNotNone(\n        schema_dict_arg,\n        \"schema_dict should be passed to batch API (not None)\",\n    )\n\n    self.assertIsInstance(batch_result, list)\n    self.assertEqual(\n        len(batch_result),\n        len(documents),\n        f\"Expected {len(documents)} results from Vertex batch API\",\n    )\n\n    for i, (res, med_name) in enumerate(zip(batch_result, expected_meds)):\n      self.assertIsInstance(\n          res,\n          lx.data.AnnotatedDocument,\n          f\"Result {i} should be an AnnotatedDocument, got {type(res)}\",\n      )\n      self.assertTrue(\n          res.extractions,\n          f\"No extractions for document {i}\",\n      )\n      for extraction in res.extractions:\n        self.assertIsInstance(\n            extraction,\n            lx.data.Extraction,\n            \"Extraction item should be Extraction object, got\"\n            f\" {type(extraction)}\",\n        )\n\n      meds = extract_by_class(res, _CLASS_MEDICATION)\n      self.assertTrue(\n          any(\n              re.search(rf\"\\b{re.escape(med_name)}\\b\", m, re.IGNORECASE)\n              for m in meds\n          ),\n          f\"Expected medication '{med_name}' not found in results: {meds}\",\n      )\n\n      dosages = extract_by_class(res, _CLASS_DOSAGE)\n      self.assertTrue(\n          dosages,\n          f\"No dosage extracted for medication '{med_name}'\",\n      )\n\n      assert_valid_char_intervals(self, res)\n\n  @skip_if_no_vertex\n  @live_api\n  @pytest.mark.vertex_ai\n  def test_batch_caching_live(self):\n    \"\"\"Test batch caching with real Vertex AI Batch API.\n\n    
Verifies that:\n    1. First run populates GCS cache\n    2. Second run uses cache (returns same results faster)\n    \"\"\"\n    prompt = \"Extract the medication: Patient takes 10mg Lisinopril.\"\n    examples = get_basic_medication_examples()\n\n    # Use unique IDs to ensure cache isolation between test runs.\n    run_id = uuid.uuid4().hex[:8]\n    documents = [\n        lx.data.Document(\n            document_id=f\"doc_{i}_{run_id}\",\n            text=f\"Patient takes 10mg Lisinopril {i} {run_id}.\",\n        )\n        for i in range(2)\n    ]\n\n    language_model_params = dict(GEMINI_MODEL_PARAMS)\n    language_model_params[\"vertexai\"] = True\n    language_model_params[\"project\"] = VERTEX_PROJECT\n    language_model_params[\"location\"] = VERTEX_LOCATION\n    language_model_params[\"batch\"] = {\n        \"enabled\": True,\n        \"threshold\": 2,\n        \"poll_interval\": 1,\n        \"timeout\": 900,\n        \"enable_caching\": True,\n    }\n\n    print(\"\\nStarting first batch run (API)...\")\n    start_time = time.time()\n    results1 = list(\n        lx.extract(\n            text_or_documents=documents,\n            prompt_description=prompt,\n            examples=examples,\n            model_id=DEFAULT_GEMINI_MODEL,\n            language_model_params=language_model_params,\n        )\n    )\n    duration1 = time.time() - start_time\n    print(f\"First run took {duration1:.2f}s\")\n\n    print(\"Starting second batch run (Cache)...\")\n    start_time = time.time()\n    results2 = list(\n        lx.extract(\n            text_or_documents=documents,\n            prompt_description=prompt,\n            examples=examples,\n            model_id=DEFAULT_GEMINI_MODEL,\n            language_model_params=language_model_params,\n        )\n    )\n    duration2 = time.time() - start_time\n    print(f\"Second run took {duration2:.2f}s\")\n\n    self.assertEqual(len(results1), len(results2))\n    for r1, r2 in zip(results1, results2):\n      
self.assertEqual(r1.text, r2.text)\n      self.assertEqual(len(r1.extractions), len(r2.extractions))\n\n    self.assertLess(duration2, 10.0, \"Second run took too long for cache hit\")\n\n    print(\"\\nVerifying GCS cache content...\")\n    bucket_name = gb._get_bucket_name(VERTEX_PROJECT, VERTEX_LOCATION)\n    print(f\"Checking bucket: {bucket_name}\")\n    self._verify_gcs_cache_content(bucket_name)\n\n\nclass TestCrossChunkContext(unittest.TestCase):\n  \"\"\"Tests for cross-chunk context feature with real API.\"\"\"\n\n  @skip_if_no_gemini\n  @live_api\n  @retry_on_transient_errors(max_retries=3)\n  def test_context_window_extracts_from_both_chunks(self):\n    \"\"\"Verify context_window_chars enables extraction across chunk boundaries.\"\"\"\n    input_text = (\n        \"Dr. Sarah Chen is the lead researcher at the institute. \"\n        \"She published groundbreaking work on neural networks last year.\"\n    )\n    prompt = textwrap.dedent(\n        \"\"\"\\\n        Extract all person names, roles, and achievements mentioned in the text.\n        Include both explicit names and information associated with pronouns.\"\"\"\n    )\n    examples = [\n        lx.data.ExampleData(\n            text=(\n                \"Professor James Miller leads the physics department. 
\"\n                \"He won the Nobel Prize in 2020.\"\n            ),\n            extractions=[\n                lx.data.Extraction(\n                    extraction_class=\"person\",\n                    extraction_text=\"Professor James Miller\",\n                    attributes={\"role\": \"leads the physics department\"},\n                ),\n                lx.data.Extraction(\n                    extraction_class=\"achievement\",\n                    extraction_text=\"won the Nobel Prize in 2020\",\n                ),\n            ],\n        )\n    ]\n\n    result = lx.extract(\n        text_or_documents=input_text,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_GEMINI_MODEL,\n        api_key=GEMINI_API_KEY,\n        language_model_params=GEMINI_MODEL_PARAMS,\n        max_char_buffer=60,\n        context_window_chars=50,\n    )\n\n    self.assertIsNotNone(result)\n    self.assertGreater(len(result.extractions), 0)\n\n    all_extraction_text = \" \".join(\n        str(e.extraction_text) + \" \" + str(e.attributes)\n        for e in result.extractions\n    ).lower()\n\n    has_chunk1_content = any(\n        term in all_extraction_text\n        for term in (\"sarah\", \"chen\", \"researcher\", \"lead\")\n    )\n    has_chunk2_content = any(\n        term in all_extraction_text\n        for term in (\"published\", \"groundbreaking\", \"neural\", \"networks\")\n    )\n\n    self.assertTrue(\n        has_chunk1_content,\n        f\"Expected chunk 1 content (Sarah Chen). Got: {result.extractions}\",\n    )\n    self.assertTrue(\n        has_chunk2_content,\n        f\"Expected chunk 2 content (publication). 
Got: {result.extractions}\",\n    )\n\n\nclass TestLiveAPIOpenAI(unittest.TestCase):\n  \"\"\"Tests using real OpenAI API.\"\"\"\n\n  @skip_if_no_openai\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def test_medication_extraction(self):\n    \"\"\"Test medication extraction with OpenAI models.\"\"\"\n    prompt = textwrap.dedent(\"\"\"\\\n        Extract medication information including medication name, dosage, route, frequency,\n        and duration in the order they appear in the text.\"\"\")\n\n    examples = get_basic_medication_examples()\n    input_text = \"Patient took 400 mg PO Ibuprofen q4h for two days.\"\n\n    result = lx.extract(\n        text_or_documents=input_text,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_OPENAI_MODEL,\n        api_key=OPENAI_API_KEY,\n        use_schema_constraints=False,\n        language_model_params=OPENAI_MODEL_PARAMS,\n    )\n\n    assert result is not None\n    self.assertIsInstance(result, lx.data.AnnotatedDocument)\n    assert len(result.extractions) > 0\n\n    expected_classes = {\n        _CLASS_DOSAGE,\n        _CLASS_ROUTE,\n        _CLASS_MEDICATION,\n        _CLASS_FREQUENCY,\n        _CLASS_DURATION,\n    }\n    assert_extractions_contain(self, result, expected_classes)\n    assert_valid_char_intervals(self, result)\n\n    # Using regex for precise matching to avoid false positives\n    medication_texts = extract_by_class(result, _CLASS_MEDICATION)\n    self.assertTrue(\n        any(\n            re.search(r\"\\bIbuprofen\\b\", text, re.IGNORECASE)\n            for text in medication_texts\n        ),\n        f\"No Ibuprofen found in: {medication_texts}\",\n    )\n\n    dosage_texts = extract_by_class(result, _CLASS_DOSAGE)\n    self.assertTrue(\n        any(\n            re.search(r\"\\b400\\s*mg\\b\", text, re.IGNORECASE)\n            for text in dosage_texts\n        ),\n        f\"No 400mg dosage found in: {dosage_texts}\",\n    )\n\n    route_texts 
= extract_by_class(result, _CLASS_ROUTE)\n    self.assertTrue(\n        any(\n            re.search(r\"\\b(PO|oral)\\b\", text, re.IGNORECASE)\n            for text in route_texts\n        ),\n        f\"No PO/oral route found in: {route_texts}\",\n    )\n\n  @skip_if_no_openai\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def test_explicit_provider_selection(self):\n    \"\"\"Test using explicit provider parameter for disambiguation.\"\"\"\n    # Test with explicit model_id and provider\n    config = lx.factory.ModelConfig(\n        model_id=DEFAULT_OPENAI_MODEL,\n        provider=\"OpenAILanguageModel\",  # Explicit provider selection\n        provider_kwargs={\n            \"api_key\": OPENAI_API_KEY,\n            \"temperature\": 0.0,\n        },\n    )\n\n    model = lx.factory.create_model(config)\n\n    self.assertIsInstance(model, lx.providers.openai.OpenAILanguageModel)\n    self.assertEqual(model.model_id, DEFAULT_OPENAI_MODEL)\n\n    # Also test using provider without model_id (uses default)\n    config_default = lx.factory.ModelConfig(\n        provider=\"OpenAILanguageModel\",\n        provider_kwargs={\n            \"api_key\": OPENAI_API_KEY,\n        },\n    )\n\n    model_default = lx.factory.create_model(config_default)\n    self.assertEqual(model_default.__class__.__name__, \"OpenAILanguageModel\")\n    # Should use the default model_id from the provider\n    self.assertEqual(model_default.model_id, \"gpt-4o-mini\")\n\n  @skip_if_no_openai\n  @live_api\n  @retry_on_transient_errors(max_retries=2)\n  def test_medication_relationship_extraction(self):\n    \"\"\"Test relationship extraction for medications with OpenAI.\"\"\"\n    input_text = \"\"\"\n    The patient was prescribed Lisinopril and Metformin last month.\n    He takes the Lisinopril 10mg daily for hypertension, but often misses\n    his Metformin 500mg dose which should be taken twice daily for diabetes.\n    \"\"\"\n\n    prompt = textwrap.dedent(\"\"\"\n        Extract 
medications with their details, using attributes to group related information:\n\n        1. Extract entities in the order they appear in the text\n        2. Each entity must have a 'medication_group' attribute linking it to its medication\n        3. All details about a medication should share the same medication_group value\n    \"\"\")\n\n    examples = get_relationship_examples()\n\n    result = lx.extract(\n        text_or_documents=input_text,\n        prompt_description=prompt,\n        examples=examples,\n        model_id=DEFAULT_OPENAI_MODEL,\n        api_key=OPENAI_API_KEY,\n        use_schema_constraints=False,\n        language_model_params=OPENAI_MODEL_PARAMS,\n    )\n\n    assert result is not None\n    assert len(result.extractions) > 0\n    assert_valid_char_intervals(self, result)\n\n    medication_groups = {}\n    for extraction in result.extractions:\n      assert (\n          extraction.attributes is not None\n      ), f\"Missing attributes for {extraction.extraction_text}\"\n      assert (\n          \"medication_group\" in extraction.attributes\n      ), f\"Missing medication_group for {extraction.extraction_text}\"\n\n      group_name = extraction.attributes[\"medication_group\"]\n      medication_groups.setdefault(group_name, []).append(extraction)\n\n    assert (\n        len(medication_groups) >= 2\n    ), f\"Expected at least 2 medications, found {len(medication_groups)}\"\n\n    # Allow flexible matching for dosage field (could be \"dosage\" or \"dose\")\n    for med_name, extractions in medication_groups.items():\n      extraction_classes = {e.extraction_class for e in extractions}\n      # At minimum, each group should have the medication itself\n      assert (\n          _CLASS_MEDICATION in extraction_classes\n      ), f\"{med_name} group missing medication entity\"\n      # Dosage is expected but might be formatted differently\n      assert any(\n          c in extraction_classes for c in [_CLASS_DOSAGE, \"dose\"]\n      ), 
f\"{med_name} group missing dosage\"\n"
  },
  {
    "path": "tests/test_ollama_integration.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Integration tests for Ollama functionality.\"\"\"\nimport socket\n\nimport pytest\nimport requests\n\nimport langextract as lx\n\n\ndef _ollama_available():\n  \"\"\"Check if Ollama is running on localhost:11434.\"\"\"\n  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:\n    result = sock.connect_ex((\"localhost\", 11434))\n    return result == 0\n\n\n@pytest.mark.skipif(not _ollama_available(), reason=\"Ollama not running\")\ndef test_ollama_extraction():\n  input_text = \"Isaac Asimov was a prolific science fiction writer.\"\n  prompt = \"Extract the author's full name and their primary literary genre.\"\n\n  examples = [\n      lx.data.ExampleData(\n          text=(\n              \"J.R.R. Tolkien was an English writer, best known for\"\n              \" high-fantasy.\"\n          ),\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"author_details\",\n                  extraction_text=\"J.R.R. Tolkien was an English writer...\",\n                  attributes={\n                      \"name\": \"J.R.R. 
Tolkien\",\n                      \"genre\": \"high-fantasy\",\n                  },\n              )\n          ],\n      )\n  ]\n\n  model_id = \"gemma2:2b\"\n\n  result = lx.extract(\n      text_or_documents=input_text,\n      prompt_description=prompt,\n      examples=examples,\n      model_id=model_id,\n      model_url=\"http://localhost:11434\",\n      temperature=0.3,\n      fence_output=False,\n      use_schema_constraints=False,\n  )\n\n  assert len(result.extractions) > 0\n  extraction = result.extractions[0]\n  assert extraction.extraction_class == \"author_details\"\n  if extraction.attributes:\n    assert \"asimov\" in extraction.attributes.get(\"name\", \"\").lower()\n\n\n@pytest.mark.skipif(not _ollama_available(), reason=\"Ollama not running\")\ndef test_ollama_extraction_with_fence_fallback():\n  input_text = \"Marie Curie was a physicist who won two Nobel prizes.\"\n  prompt = \"Extract information about people and their achievements.\"\n\n  examples = [\n      lx.data.ExampleData(\n          text=\"Albert Einstein developed the theory of relativity.\",\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"person\",\n                  extraction_text=\"Albert Einstein\",\n                  attributes={\"achievement\": \"theory of relativity\"},\n              )\n          ],\n      )\n  ]\n\n  model_id = \"gemma2:2b\"\n\n  result = lx.extract(\n      text_or_documents=input_text,\n      prompt_description=prompt,\n      examples=examples,\n      model_id=model_id,\n      model_url=\"http://localhost:11434\",\n      temperature=0.3,\n      fence_output=True,  # Testing that fallback works\n      use_schema_constraints=False,\n  )\n\n  assert len(result.extractions) > 0\n  extraction = result.extractions[0]\n  assert extraction.extraction_class == \"person\"\n  assert (\n      \"marie\" in extraction.extraction_text.lower()\n      or \"curie\" in extraction.extraction_text.lower()\n  )\n\n\ndef 
_model_available(model_name):\n  \"\"\"Check if a specific model is available in Ollama.\"\"\"\n  if not _ollama_available():\n    return False\n  try:\n    response = requests.get(\"http://localhost:11434/api/tags\", timeout=5)\n    models = [m[\"name\"] for m in response.json().get(\"models\", [])]\n    return any(model_name in m for m in models)\n  except (requests.RequestException, KeyError, TypeError):\n    return False\n\n\n@pytest.mark.skipif(\n    not _model_available(\"deepseek-r1\"),\n    reason=\"DeepSeek-R1 not available in Ollama\",\n)\ndef test_deepseek_r1_extraction():\n  \"\"\"Test extraction with DeepSeek-R1 reasoning model.\n\n  DeepSeek-R1 outputs <think> tags before JSON when not using format:json.\n  This test verifies the model works correctly with langextract.\n  \"\"\"\n  input_text = \"John Smith is a software engineer at Google.\"\n  prompt = \"Extract people and their roles.\"\n\n  examples = [\n      lx.data.ExampleData(\n          text=\"Alice works as a designer at Apple.\",\n          extractions=[\n              lx.data.Extraction(\n                  extraction_class=\"person\",\n                  extraction_text=\"Alice\",\n                  attributes={\"role\": \"designer\", \"company\": \"Apple\"},\n              )\n          ],\n      )\n  ]\n\n  result = lx.extract(\n      text_or_documents=input_text,\n      prompt_description=prompt,\n      examples=examples,\n      model_id=\"deepseek-r1:1.5b\",\n      model_url=\"http://localhost:11434\",\n      temperature=0.3,\n  )\n\n  assert len(result.extractions) > 0\n  extraction = result.extractions[0]\n  assert extraction.extraction_class == \"person\"\n  assert \"john\" in extraction.extraction_text.lower()\n"
  },
  {
    "path": "tests/tokenizer_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport textwrap\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\n\nfrom langextract.core import tokenizer\n\n\nclass TokenizerTest(parameterized.TestCase):\n  # pylint: disable=too-many-public-methods\n\n  def assertTokenListEqual(self, actual_tokens, expected_tokens, msg=None):\n    self.assertLen(actual_tokens, len(expected_tokens), msg=msg)\n    for i, (expected, actual) in enumerate(zip(expected_tokens, actual_tokens)):\n      expected = tokenizer.Token(\n          index=expected.index,\n          token_type=expected.token_type,\n          first_token_after_newline=expected.first_token_after_newline,\n      )\n      actual = tokenizer.Token(\n          index=actual.index,\n          token_type=actual.token_type,\n          first_token_after_newline=actual.first_token_after_newline,\n      )\n      self.assertDataclassEqual(\n          expected,\n          actual,\n          msg=f\"Token mismatch at index {i}\",\n      )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"basic_text\",\n          input_text=\"Hello, world!\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=1, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(index=2, 
token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=3, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"multiple_spaces_and_numbers\",\n          input_text=\"Age:   25\\nWeight=70kg.\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=1, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(index=2, token_type=tokenizer.TokenType.NUMBER),\n              tokenizer.Token(\n                  index=3,\n                  token_type=tokenizer.TokenType.WORD,\n                  first_token_after_newline=True,\n              ),\n              tokenizer.Token(\n                  index=4, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(index=5, token_type=tokenizer.TokenType.NUMBER),\n              tokenizer.Token(index=6, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=7, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"multi_line_input\",\n          input_text=\"Line1\\nLine2\\nLine3\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=1, token_type=tokenizer.TokenType.NUMBER),\n              tokenizer.Token(\n                  index=2,\n                  token_type=tokenizer.TokenType.WORD,\n                  first_token_after_newline=True,\n              ),\n              tokenizer.Token(index=3, token_type=tokenizer.TokenType.NUMBER),\n              tokenizer.Token(\n                  index=4,\n                  token_type=tokenizer.TokenType.WORD,\n                  first_token_after_newline=True,\n              ),\n              
tokenizer.Token(index=5, token_type=tokenizer.TokenType.NUMBER),\n          ],\n      ),\n      dict(\n          testcase_name=\"only_symbols\",\n          input_text=\"!!!@#   $$$%\",\n          expected_tokens=[\n              tokenizer.Token(\n                  index=0, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(\n                  index=1, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(\n                  index=2, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(\n                  index=3, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(\n                  index=4, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"empty_string\",\n          input_text=\"\",\n          expected_tokens=[],\n      ),\n      dict(\n          testcase_name=\"non_ascii_text\",\n          input_text=\"café\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n          ],\n      ),\n      dict(\n          testcase_name=\"mixed_punctuation\",\n          input_text=\"?!\",\n          expected_tokens=[\n              tokenizer.Token(\n                  index=0, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(\n                  index=1, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n          ],\n      ),\n  )\n  def test_tokenize_various_inputs(self, input_text, expected_tokens):\n    tokenized = tokenizer.tokenize(input_text)\n    self.assertTokenListEqual(\n        tokenized.tokens,\n        expected_tokens,\n        msg=f\"Tokens mismatch for input: {input_text!r}\",\n    )\n\n  def test_first_token_after_newline_flag(self):\n    input_text = \"Line1\\nLine2\\nLine3\"\n    tokenized = 
tokenizer.tokenize(input_text)\n\n    expected_tokens = [\n        tokenizer.Token(\n            index=0,\n            token_type=tokenizer.TokenType.WORD,\n        ),\n        tokenizer.Token(\n            index=1,\n            token_type=tokenizer.TokenType.NUMBER,\n        ),\n        tokenizer.Token(\n            index=2,\n            token_type=tokenizer.TokenType.WORD,\n            first_token_after_newline=True,\n        ),\n        tokenizer.Token(\n            index=3,\n            token_type=tokenizer.TokenType.NUMBER,\n        ),\n        tokenizer.Token(\n            index=4,\n            token_type=tokenizer.TokenType.WORD,\n            first_token_after_newline=True,\n        ),\n        tokenizer.Token(\n            index=5,\n            token_type=tokenizer.TokenType.NUMBER,\n        ),\n    ]\n\n    self.assertTokenListEqual(\n        tokenized.tokens,\n        expected_tokens,\n        msg=\"Newline flags mismatch\",\n    )\n\n  def test_performance_optimization_no_crash(self):\n    \"\"\"Verify that tokenization handles empty strings and newlines without error.\"\"\"\n    tok = tokenizer.RegexTokenizer()\n    text = \"\"\n    tokenized = tok.tokenize(text)\n    self.assertEmpty(tokenized.tokens)\n\n    text = \"\\n\"\n    tokenized = tok.tokenize(text)\n    self.assertEmpty(tokenized.tokens)\n\n    text = \"A\\nB\"\n    tokenized = tok.tokenize(text)\n    self.assertLen(tokenized.tokens, 2)\n    self.assertTrue(tokenized.tokens[1].first_token_after_newline)\n\n  def test_underscore_handling(self):\n    \"\"\"Verify that underscores are preserved as punctuation/symbols.\"\"\"\n    # RegexTokenizer should now capture underscores explicitly.\n    tok = tokenizer.RegexTokenizer()\n    text = \"user_id\"\n    tokenized = tok.tokenize(text)\n    # Expecting: \"user\", \"_\", \"id\"\n    self.assertLen(tokenized.tokens, 3)\n    self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD)\n    self.assertEqual(\n        
tokenized.tokens[1].token_type, tokenizer.TokenType.PUNCTUATION\n    )\n    self.assertEqual(tokenized.tokens[2].token_type, tokenizer.TokenType.WORD)\n\n\nclass UnicodeTokenizerTest(parameterized.TestCase):\n  # pylint: disable=too-many-public-methods\n\n  def assertTokenListEqual(self, actual_tokens, expected_tokens, msg=None):\n    self.assertLen(actual_tokens, len(expected_tokens), msg=msg)\n    for i, (expected, actual) in enumerate(zip(expected_tokens, actual_tokens)):\n      expected_tok = tokenizer.Token(\n          index=expected.index,\n          token_type=expected.token_type,\n          first_token_after_newline=expected.first_token_after_newline,\n      )\n      actual_tok = tokenizer.Token(\n          index=actual.index,\n          token_type=actual.token_type,\n          first_token_after_newline=actual.first_token_after_newline,\n      )\n      self.assertDataclassEqual(\n          expected_tok,\n          actual_tok,\n          msg=f\"Token mismatch at index {i}\",\n      )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"japanese_text\",\n          input_text=\"こんにちは、世界！\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=1, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=3, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=4, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=5, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(index=6, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=7, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=8, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n          ],\n      ),\n      dict(\n       
   testcase_name=\"english_text\",\n          input_text=\"Hello, world!\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=1, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n              tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(\n                  index=3, token_type=tokenizer.TokenType.PUNCTUATION\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"mixed_text\",\n          input_text=\"Hello 世界 123\",\n          expected_tokens=[\n              tokenizer.Token(index=0, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=1, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=2, token_type=tokenizer.TokenType.WORD),\n              tokenizer.Token(index=3, token_type=tokenizer.TokenType.NUMBER),\n          ],\n      ),\n  )\n  def test_tokenize_various_inputs(self, input_text, expected_tokens):\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(input_text)\n    self.assertTokenListEqual(\n        tokenized.tokens,\n        expected_tokens,\n        msg=f\"Tokens mismatch for input: {input_text!r}\",\n    )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"mixed_digit_han_same_type_grouping\",\n          input_text=\"10毫克\",  # \"10 milligrams\"\n          expected_tokens=[\n              (\"10\", tokenizer.TokenType.NUMBER),\n              (\"毫\", tokenizer.TokenType.WORD),\n              (\"克\", tokenizer.TokenType.WORD),\n          ],\n          expected_first_after_newline=[False, False, False],\n      ),\n      dict(\n          testcase_name=\"underscore_word_separator\",\n          input_text=\"hello_world\",\n          expected_tokens=[\n              (\"hello\", tokenizer.TokenType.WORD),\n              (\"_\", tokenizer.TokenType.PUNCTUATION),\n             
 (\"world\", tokenizer.TokenType.WORD),\n          ],\n          expected_first_after_newline=[False, False, False],\n      ),\n      dict(\n          testcase_name=\"leading_trailing_underscores\",\n          input_text=\"_test_case_\",\n          expected_tokens=[\n              (\"_\", tokenizer.TokenType.PUNCTUATION),\n              (\"test\", tokenizer.TokenType.WORD),\n              (\"_\", tokenizer.TokenType.PUNCTUATION),\n              (\"case\", tokenizer.TokenType.WORD),\n              (\"_\", tokenizer.TokenType.PUNCTUATION),\n          ],\n          expected_first_after_newline=[False, False, False, False, False],\n      ),\n  )\n  def test_special_unicode_and_punctuation_handling(\n      self, input_text, expected_tokens, expected_first_after_newline\n  ):\n    \"\"\"Test special Unicode sequences, punctuation grouping, and script handling edge cases.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(input_text)\n    self.assertLen(\n        tokenized.tokens,\n        len(expected_tokens),\n        f\"Expected {len(expected_tokens)} tokens for edge case test, but got\"\n        f\" {len(tokenized.tokens)}\",\n    )\n\n    for i, (\n        token,\n        (expected_text, expected_type),\n        expected_newline,\n    ) in enumerate(\n        zip(tokenized.tokens, expected_tokens, expected_first_after_newline)\n    ):\n      actual_text = input_text[\n          token.char_interval.start_pos : token.char_interval.end_pos\n      ]\n      self.assertEqual(\n          actual_text,\n          expected_text,\n          msg=f\"Token {i} text mismatch.\",\n      )\n      self.assertEqual(\n          token.token_type,\n          expected_type,\n          msg=f\"Token {i} type mismatch.\",\n      )\n      self.assertEqual(\n          token.first_token_after_newline,\n          expected_newline,\n          msg=f\"Token {i} newline flag mismatch.\",\n      )\n\n  def test_first_token_after_newline_parity(self):\n    \"\"\"Test that 
UnicodeTokenizer matches RegexTokenizer for newline detection.\"\"\"\n    input_text = \"a\\n b\"\n    regex_tok = tokenizer.RegexTokenizer()\n    regex_tokens = regex_tok.tokenize(input_text).tokens\n    self.assertTrue(regex_tokens[1].first_token_after_newline)\n\n    unicode_tok = tokenizer.UnicodeTokenizer()\n    unicode_tokens = unicode_tok.tokenize(input_text).tokens\n    self.assertTrue(\n        unicode_tokens[1].first_token_after_newline,\n        \"UnicodeTokenizer failed to detect newline in gap 'a\\\\n b'\",\n    )\n\n  def test_expanded_cjk_detection(self):\n    \"\"\"Test detection of CJK characters in extended ranges.\"\"\"\n    input_text = \"\\u4e00\\u3400\\U00020000\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(input_text)\n\n    self.assertLen(tokenized.tokens, 3)\n    for token in tokenized.tokens:\n      self.assertEqual(token.token_type, tokenizer.TokenType.WORD)\n\n  def test_mixed_script_and_emoji(self):\n    \"\"\"Test mixed script and emoji handling.\"\"\"\n    input_text = \"Hello👋🏼世界123\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(input_text)\n\n    expected_tokens = [\n        (\"Hello\", tokenizer.TokenType.WORD),\n        (\n            \"👋🏼\",\n            tokenizer.TokenType.PUNCTUATION,\n        ),\n        (\"世\", tokenizer.TokenType.WORD),\n        (\"界\", tokenizer.TokenType.WORD),\n        (\"123\", tokenizer.TokenType.NUMBER),\n    ]\n\n    self.assertLen(tokenized.tokens, len(expected_tokens))\n    for i, (expected_text, expected_type) in enumerate(expected_tokens):\n      token = tokenized.tokens[i]\n      actual_text = tokenized.text[\n          token.char_interval.start_pos : token.char_interval.end_pos\n      ]\n      self.assertEqual(actual_text, expected_text)\n      self.assertEqual(token.token_type, expected_type)\n\n  def test_script_boundary_grouping(self):\n    \"\"\"Test that we do NOT group characters from different scripts.\"\"\"\n    tok = 
tokenizer.UnicodeTokenizer()\n    text = \"HelloПривет\"\n    tokenized = tok.tokenize(text)\n\n    self.assertLen(tokenized.tokens, 2, \"Should be split into 2 tokens\")\n    self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD)\n    self.assertEqual(tokenized.tokens[1].token_type, tokenizer.TokenType.WORD)\n\n    t1_text = text[\n        tokenized.tokens[0]\n        .char_interval.start_pos : tokenized.tokens[0]\n        .char_interval.end_pos\n    ]\n    t2_text = text[\n        tokenized.tokens[1]\n        .char_interval.start_pos : tokenized.tokens[1]\n        .char_interval.end_pos\n    ]\n\n    self.assertEqual(t1_text, \"Hello\")\n    self.assertEqual(t2_text, \"Привет\")\n\n  def test_non_spaced_scripts_no_grouping(self):\n    \"\"\"Test that non-spaced scripts (Thai, Lao, etc.) are NOT grouped into a single word.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    text = \"สวัสดี\"\n    tokenized = tok.tokenize(text)\n\n    self.assertGreater(\n        len(tokenized.tokens), 1, \"Should not be grouped into a single token\"\n    )\n    self.assertLen(tokenized.tokens, 4)\n\n  def test_cjk_detection_regex(self):\n    \"\"\"Test that CJK characters are detected and not grouped.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    text = \"你好\"\n    tokenized = tok.tokenize(text)\n\n    self.assertLen(tokenized.tokens, 2)\n    self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD)\n    self.assertEqual(tokenized.tokens[1].token_type, tokenizer.TokenType.WORD)\n\n  def test_newline_simplification(self):\n    \"\"\"Test that newline handling works correctly with the simplified logic.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    text = \"LineA\\nLineB\"\n    tokenized = tok.tokenize(text)\n\n    self.assertLen(tokenized.tokens, 2)\n    self.assertEqual(tokenized.tokens[0].first_token_after_newline, False)\n    self.assertTrue(tokenized.tokens[1].first_token_after_newline)\n\n  def 
test_newline_simplification_start(self):\n    \"\"\"Test newline at start of text.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    text = \"\\nLineA\"\n    tokenized = tok.tokenize(text)\n\n    self.assertLen(tokenized.tokens, 1)\n    self.assertTrue(tokenized.tokens[0].first_token_after_newline)\n\n  def test_mixed_line_endings(self):\n    \"\"\"Test mixed line endings (\\\\r\\\\n).\"\"\"\n    # \\\\r\\\\n should be treated as a single newline for the purpose of the flag,\n    # or at least trigger it.\n    text = \"LineOne\\r\\nLineTwo\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(text)\n    self.assertLen(tokenized.tokens, 2)\n    self.assertTrue(tokenized.tokens[1].first_token_after_newline)\n\n  def test_mixed_uncommon_scripts_no_grouping(self):\n    \"\"\"Test that adjacent unknown scripts are NOT merged.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    # Armenian \"Բարև\" + Georgian \"გამარჯობა\".\n    # Both are \"unknown\" to _COMMON_SCRIPTS, so should not be grouped together.\n    text = \"Բարևგამარჯობა\"\n    tokenized = tok.tokenize(text)\n\n    # Unknown scripts are fragmented into characters for safety.\n    self.assertLen(\n        tokenized.tokens,\n        13,\n        \"Should be fragmented into characters for safety (13 tokens)\",\n    )\n    self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD)\n    self.assertEqual(tokenized.tokens[1].token_type, tokenizer.TokenType.WORD)\n\n  def test_unknown_script_merging_edge_case(self):\n    # Verify that adjacent IDENTICAL unknown scripts are fragmented for safety.\n    # Armenian \"Բարև\" + Armenian \"Բարև\".\n    tok = tokenizer.UnicodeTokenizer()\n    text = \"ԲարևԲարև\"\n    tokenized = tok.tokenize(text)\n    # Should be fragmented into 8 characters\n    self.assertLen(tokenized.tokens, 8)\n    self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD)\n\n  def test_find_sentence_range_empty_input(self):\n    # Ensure robustness 
against empty input, which previously caused a crash.\n    interval = tokenizer.find_sentence_range(\"\", [], 0)\n    self.assertEqual(interval, tokenizer.TokenInterval(0, 0))\n\n  def test_normalization_indices_match_input(self):\n    \"\"\"Test that token indices match the ORIGINAL input, not normalized text.\"\"\"\n    # \"e\" + combining acute accent (2 chars) -> NFC \"é\" (1 char)\n    nfd_text = \"e\\u0301\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(nfd_text)\n\n    # We want indices to match input, so CharInterval should be [0, 2).\n    self.assertEqual(tokenized.text, nfd_text)\n    self.assertLen(tokenized.tokens, 1)\n    self.assertEqual(tokenized.tokens[0].char_interval.start_pos, 0)\n    self.assertEqual(tokenized.tokens[0].char_interval.end_pos, 2)\n\n  def test_acronym_inconsistency(self):\n    \"\"\"Test that RegexTokenizer does NOT produce ACRONYM tokens (standardization).\"\"\"\n    tok = tokenizer.RegexTokenizer()\n    text = \"A/B\"\n    tokenized = tok.tokenize(text)\n    # Ensure parity with UnicodeTokenizer by splitting acronyms into constituent parts.\n    self.assertLen(tokenized.tokens, 3)\n    self.assertEqual(tokenized.tokens[0].token_type, tokenizer.TokenType.WORD)\n    self.assertEqual(\n        tokenized.tokens[1].token_type, tokenizer.TokenType.PUNCTUATION\n    )\n    self.assertEqual(tokenized.tokens[2].token_type, tokenizer.TokenType.WORD)\n\n  def test_consecutive_punctuation_grouping(self):\n    \"\"\"Test that consecutive punctuation is grouped into a single token.\"\"\"\n    input_text = \"Hello!! 
World...\"\n    expected_tokens = [\"Hello\", \"!!\", \"World\", \"...\"]\n    tokens = tokenizer.UnicodeTokenizer().tokenize(input_text).tokens\n    self.assertEqual(\n        [\n            input_text[t.char_interval.start_pos : t.char_interval.end_pos]\n            for t in tokens\n        ],\n        expected_tokens,\n    )\n\n  def test_punctuation_merging_identical_only(self):\n    \"\"\"Test that only identical punctuation is merged.\"\"\"\n    input_text = \"Hello!! World...\"\n    expected_tokens = [\"Hello\", \"!!\", \"World\", \"...\"]\n    tokens = tokenizer.UnicodeTokenizer().tokenize(input_text).tokens\n    self.assertEqual(\n        [\n            input_text[t.char_interval.start_pos : t.char_interval.end_pos]\n            for t in tokens\n        ],\n        expected_tokens,\n    )\n\n    input_text_mixed = 'End.\"'\n    expected_tokens_mixed = [\"End\", \".\", '\"']\n    tokens_mixed = (\n        tokenizer.UnicodeTokenizer().tokenize(input_text_mixed).tokens\n    )\n    self.assertEqual(\n        [\n            input_text_mixed[\n                t.char_interval.start_pos : t.char_interval.end_pos\n            ]\n            for t in tokens_mixed\n        ],\n        expected_tokens_mixed,\n    )\n\n  def test_distinct_unknown_scripts_do_not_merge(self):\n    \"\"\"Verify that distinct unknown scripts (e.g. 
Bengali vs Devanagari) are not merged.\"\"\"\n    # Bengali \"অ\" (U+0985) and Devanagari \"अ\" (U+0905)\n    text = \"অअ\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(text)\n\n    # Should be 2 tokens because scripts are different\n    self.assertLen(tokenized.tokens, 2)\n    self.assertEqual(tokenized.tokens[0].char_interval.start_pos, 0)\n    self.assertEqual(tokenized.tokens[0].char_interval.end_pos, 1)\n    self.assertEqual(tokenized.tokens[1].char_interval.start_pos, 1)\n    self.assertEqual(tokenized.tokens[1].char_interval.end_pos, 2)\n\n  def test_identical_unknown_scripts_do_not_merge(self):\n    \"\"\"Verify that identical unknown scripts do not merge into one token.\"\"\"\n    # Bengali \"অ\" (U+0985) and Bengali \"আ\" (U+0986)\n    text = \"অআ\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(text)\n\n    # Identical unknown scripts are not merged to avoid expensive lookups.\n    self.assertLen(tokenized.tokens, 2)\n    self.assertEqual(tokenized.tokens[0].char_interval.start_pos, 0)\n    self.assertEqual(tokenized.tokens[0].char_interval.end_pos, 1)\n    self.assertEqual(tokenized.tokens[1].char_interval.start_pos, 1)\n    self.assertEqual(tokenized.tokens[1].char_interval.end_pos, 2)\n\n\nclass ExceptionTest(absltest.TestCase):\n  \"\"\"Test custom exception types and error conditions.\"\"\"\n\n  def test_invalid_token_interval_errors(self):\n    \"\"\"Test that InvalidTokenIntervalError is raised for invalid intervals.\"\"\"\n    text = \"Hello, world!\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(text)\n\n    with self.assertRaisesRegex(\n        tokenizer.InvalidTokenIntervalError,\n        \"Invalid token interval.*start_index=-1\",\n    ):\n      tokenizer.tokens_text(\n          tokenized, tokenizer.TokenInterval(start_index=-1, end_index=1)\n      )\n\n    with self.assertRaisesRegex(\n        tokenizer.InvalidTokenIntervalError,\n        \"Invalid token 
interval.*end_index=999\",\n    ):\n      tokenizer.tokens_text(\n          tokenized, tokenizer.TokenInterval(start_index=0, end_index=999)\n      )\n\n    with self.assertRaisesRegex(\n        tokenizer.InvalidTokenIntervalError,\n        \"Invalid token interval.*start_index=2.*end_index=1\",\n    ):\n      tokenizer.tokens_text(\n          tokenized, tokenizer.TokenInterval(start_index=2, end_index=1)\n      )\n\n  def test_sentence_range_errors(self):\n    \"\"\"Test that SentenceRangeError is raised for invalid start positions.\"\"\"\n    text = \"Hello world.\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokens = tok.tokenize(text).tokens\n\n    with self.assertRaisesRegex(\n        tokenizer.SentenceRangeError, \"start_token_index=-1 out of range\"\n    ):\n      tokenizer.find_sentence_range(text, tokens, -1)\n\n    with self.assertRaisesRegex(\n        tokenizer.SentenceRangeError,\n        \"start_token_index=999 out of range.*Total tokens: 3\",\n    ):\n      tokenizer.find_sentence_range(text, tokens, 999)\n\n    # Empty input should NOT raise SentenceRangeError.\n    interval = tokenizer.find_sentence_range(\"\", [], 0)\n    self.assertEqual(interval, tokenizer.TokenInterval(0, 0))\n\n\nclass NegativeTestCases(parameterized.TestCase):\n  \"\"\"Test cases for invalid input and edge cases.\"\"\"\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"invalid_utf8_sequence\",\n          input_text=\"Invalid \\ufffd sequence\",\n          expected_tokens=[\n              (\"Invalid\", tokenizer.TokenType.WORD),\n              (\n                  \"\\ufffd\",\n                  tokenizer.TokenType.PUNCTUATION,\n              ),\n              (\"sequence\", tokenizer.TokenType.WORD),\n          ],\n      ),\n      dict(\n          testcase_name=\"extremely_long_grapheme_cluster\",\n          input_text=\"e\" + \"\\u0301\" * 10,\n          expected_tokens=[\n              (\n                  \"e\" + 
\"\\u0301\" * 10,\n                  tokenizer.TokenType.WORD,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"mixed_valid_invalid_unicode\",\n          input_text=\"Valid текст \\ufffd 中文\",\n          expected_tokens=[\n              (\"Valid\", tokenizer.TokenType.WORD),\n              (\"текст\", tokenizer.TokenType.WORD),\n              (\"\\ufffd\", tokenizer.TokenType.PUNCTUATION),\n              (\"中\", tokenizer.TokenType.WORD),\n              (\"文\", tokenizer.TokenType.WORD),\n          ],\n      ),\n      dict(\n          testcase_name=\"zero_width_joiners\",\n          input_text=\"Family: 👨‍👩‍👧‍👦\",\n          expected_tokens=[\n              (\"Family\", tokenizer.TokenType.WORD),\n              (\":\", tokenizer.TokenType.PUNCTUATION),\n              (\n                  \"👨‍👩‍👧‍👦\",\n                  tokenizer.TokenType.PUNCTUATION,\n              ),\n          ],\n      ),\n      dict(\n          testcase_name=\"isolated_combining_marks\",\n          input_text=\"\\u0301\\u0302\\u0303 test\",\n          expected_tokens=[\n              (\n                  \"\\u0301\\u0302\\u0303\",\n                  tokenizer.TokenType.PUNCTUATION,\n              ),\n              (\"test\", tokenizer.TokenType.WORD),\n          ],\n      ),\n  )\n  def test_invalid_and_edge_case_unicode(self, input_text, expected_tokens):\n    \"\"\"Test handling of invalid Unicode sequences and edge cases.\"\"\"\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(input_text)\n    self.assertLen(\n        tokenized.tokens,\n        len(expected_tokens),\n        f\"Expected {len(expected_tokens)} tokens for edge case '{input_text}',\"\n        f\" but got {len(tokenized.tokens)}\",\n    )\n\n    for i, (token, (expected_text, expected_type)) in enumerate(\n        zip(tokenized.tokens, expected_tokens)\n    ):\n      # UPDATE: Tokenizer no longer normalizes to NFC, so we expect original text.\n      # expected_text = 
unicodedata.normalize(\"NFC\", expected_text)\n      actual_text = tokenized.text[\n          token.char_interval.start_pos : token.char_interval.end_pos\n      ]\n      self.assertEqual(\n          actual_text,\n          expected_text,\n          f\"Token {i} text mismatch. Expected '{expected_text}', got\"\n          f\" '{actual_text}'\",\n      )\n      self.assertEqual(\n          token.token_type,\n          expected_type,\n          f\"Token {i} type mismatch. Expected {expected_type}, got\"\n          f\" {token.token_type}\",\n      )\n\n  def test_empty_string_edge_case(self):\n    tok = tokenizer.UnicodeTokenizer()\n    tokenized = tok.tokenize(\"\")\n    self.assertEmpty(tokenized.tokens, \"Empty string should produce no tokens\")\n    self.assertEqual(\n        tokenized.text, \"\", \"Tokenized text should preserve empty string\"\n    )\n\n  def test_whitespace_only_string(self):\n    tok = tokenizer.UnicodeTokenizer()\n    test_cases = [\n        \"   \",  # Spaces\n        \"\\t\\t\",  # Tabs\n        \"\\n\\n\",  # Newlines\n        \" \\t\\n\\r \",  # Mixed whitespace\n    ]\n    for whitespace in test_cases:\n      tokenized = tok.tokenize(whitespace)\n      self.assertEmpty(\n          tokenized.tokens,\n          f\"Whitespace-only string '{repr(whitespace)}' should produce no\"\n          \" tokens\",\n      )\n\n\nclass TokensTextTest(parameterized.TestCase):\n\n  _SENTENCE_WITH_ONE_LINE = \"Patient Jane Doe, ID 67890, received 10mg daily.\"\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"substring_jane_doe\",\n          input_text=_SENTENCE_WITH_ONE_LINE,\n          start_index=1,\n          end_index=3,\n          expected_substring=\"Jane Doe\",\n      ),\n      dict(\n          testcase_name=\"substring_with_punctuation\",\n          input_text=_SENTENCE_WITH_ONE_LINE,\n          start_index=0,\n          end_index=4,\n          expected_substring=\"Patient Jane Doe,\",\n      ),\n      dict(\n          
testcase_name=\"numeric_tokens\",\n          input_text=_SENTENCE_WITH_ONE_LINE,\n          start_index=5,\n          end_index=6,\n          expected_substring=\"67890\",\n      ),\n  )\n  def test_valid_intervals(\n      self, input_text, start_index, end_index, expected_substring\n  ):\n    input_tokenized = tokenizer.tokenize(input_text)\n    interval = tokenizer.TokenInterval(\n        start_index=start_index, end_index=end_index\n    )\n    result_str = tokenizer.tokens_text(input_tokenized, interval)\n    self.assertEqual(\n        result_str,\n        expected_substring,\n        msg=f\"Wrong substring for interval {start_index}..{end_index}\",\n    )\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"start_index_negative\",\n          input_text=_SENTENCE_WITH_ONE_LINE,\n          start_index=-1,\n          end_index=2,\n      ),\n      dict(\n          testcase_name=\"end_index_out_of_bounds\",\n          input_text=_SENTENCE_WITH_ONE_LINE,\n          start_index=0,\n          end_index=999,\n      ),\n      dict(\n          testcase_name=\"start_index_gt_end_index\",\n          input_text=_SENTENCE_WITH_ONE_LINE,\n          start_index=5,\n          end_index=4,\n      ),\n  )\n  def test_invalid_intervals(self, input_text, start_index, end_index):\n    input_tokenized = tokenizer.tokenize(input_text)\n    interval = tokenizer.TokenInterval(\n        start_index=start_index, end_index=end_index\n    )\n    with self.assertRaises(tokenizer.InvalidTokenIntervalError):\n      _ = tokenizer.tokens_text(input_tokenized, interval)\n\n\nclass SentenceRangeTest(parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"simple_sentence\",\n          input_text=\"This is one sentence. Then another?\",\n          start_pos=0,\n          expected_interval=(0, 5),\n      ),\n      dict(\n          testcase_name=\"abbreviation_not_boundary\",\n          input_text=\"Dr. John visited. 
Then left.\",\n          start_pos=0,\n          expected_interval=(0, 5),\n      ),\n      dict(\n          testcase_name=\"second_line_capital_letter_terminates_sentence\",\n          input_text=textwrap.dedent(\"\"\"\\\n              Blood pressure was 160/90 and patient was recommended to\n              Atenolol 50 mg daily.\"\"\"),\n          start_pos=0,\n          # \"160/90\" is now 3 tokens: \"160\", \"/\", \"90\".\n          # Tokens: Blood, pressure, was, 160, /, 90, and, patient, was, recommended, to (11 tokens)\n          expected_interval=(0, 11),\n      ),\n  )\n  def test_partial_sentence_range(\n      self, input_text, start_pos, expected_interval\n  ):\n    tokenized = tokenizer.tokenize(input_text)\n    tokens = tokenized.tokens\n\n    interval = tokenizer.find_sentence_range(input_text, tokens, start_pos)\n    expected_start, expected_end = expected_interval\n    self.assertEqual(interval.start_index, expected_start)\n    self.assertEqual(interval.end_index, expected_end)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"end_of_text\",\n          input_text=\"Only one sentence here\",\n          start_pos=0,\n      ),\n  )\n  def test_full_sentence_range(self, input_text, start_pos):\n    tokenized = tokenizer.tokenize(input_text)\n    tokens = tokenized.tokens\n\n    interval = tokenizer.find_sentence_range(input_text, tokens, start_pos)\n    self.assertEqual(interval.start_index, 0)\n    self.assertLen(tokens, interval.end_index)\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name=\"out_of_range_negative_start\",\n          input_text=\"Hello world.\",\n          start_pos=-1,\n      ),\n      dict(\n          testcase_name=\"out_of_range_exceeding_length\",\n          input_text=\"Hello world.\",\n          start_pos=999,\n      ),\n  )\n  def test_invalid_start_pos(self, input_text, start_pos):\n    tokenized = tokenizer.tokenize(input_text)\n    tokens = tokenized.tokens\n    with 
self.assertRaises(tokenizer.SentenceRangeError):\n      tokenizer.find_sentence_range(input_text, tokens, start_pos)\n\n  def test_sentence_boundary_with_quote(self):\n    \"\"\"Test that sentence boundary detection works with trailing quotes.\"\"\"\n    text = 'He said \"Hello.\"'\n    tokens = tokenizer.UnicodeTokenizer().tokenize(text).tokens\n    interval = tokenizer.find_sentence_range(text, tokens, 0)\n    self.assertEqual(interval.end_index, len(tokens))\n\n  def test_sentence_splitting_permissive(self):\n    \"\"\"Test permissive sentence splitting (quotes, numbers, \\\\r).\"\"\"\n    # Quote-initiated sentence.\n    text_quote = '\"The time is now.\" Next sentence.'\n    tokens = tokenizer.UnicodeTokenizer().tokenize(text_quote).tokens\n    interval = tokenizer.find_sentence_range(text_quote, tokens, 0)\n    self.assertEqual(interval.end_index, 7)\n\n    # Number-initiated sentence.\n    text_number = \"2025 will be good. Really.\"\n    tokens = tokenizer.tokenize(text_number).tokens\n    interval = tokenizer.find_sentence_range(text_number, tokens, 0)\n    self.assertEqual(interval.end_index, 5)\n\n    # Carriage return support.\n    text_cr = \"Line one.\\rLine two.\"\n    tokens = tokenizer.tokenize(text_cr).tokens\n    interval = tokenizer.find_sentence_range(text_cr, tokens, 0)\n    self.assertEqual(interval.end_index, 3)\n\n  def test_unicode_sentence_boundaries(self):\n    \"\"\"Verify that Unicode sentence terminators are respected.\"\"\"\n    # Japanese full stop\n    text_jp = \"こんにちは。世界。\"\n    tokens = tokenizer.UnicodeTokenizer().tokenize(text_jp).tokens\n    interval = tokenizer.find_sentence_range(text_jp, tokens, 0)\n    # \"こんにちは\" (5 tokens due to CJK fragmentation) + \"。\" (1 token) = 6 tokens\n    self.assertEqual(interval.end_index, 6)\n\n    # Hindi Danda\n    text_hi = \"नमस्ते। दुनिया।\"\n    tokens = tokenizer.UnicodeTokenizer().tokenize(text_hi).tokens\n    interval = tokenizer.find_sentence_range(text_hi, tokens, 0)\n    # 
\"नमस्ते\" (1 token, Devanagari is grouped) + \"।\" (1 token) = 2 tokens\n    self.assertEqual(interval.end_index, 2)\n\n  def test_configurable_sentence_splitting(self):\n    \"\"\"Verify that custom abbreviations prevent sentence splitting.\"\"\"\n    tok = tokenizer.RegexTokenizer()\n    text_french = \"M. Smith est ici.\"\n    tokenized_french = tok.tokenize(text_french)\n    # \"M.\" is not in default _KNOWN_ABBREVIATIONS (\"Mr.\", \"Mrs.\", etc.)\n\n    # Default: \"M.\" ends the sentence.\n    sentence1 = tokenizer.find_sentence_range(\n        text_french, tokenized_french.tokens, 0\n    )\n    self.assertEqual(sentence1.end_index, 2)\n\n    # Now with \"M.\" passed as a custom abbreviation.\n    custom_abbrevs = {\"M.\"}\n    sentence2 = tokenizer.find_sentence_range(\n        text_french,\n        tokenized_french.tokens,\n        0,\n        known_abbreviations=custom_abbrevs,\n    )\n\n    # Should NOT split at \"M.\"\n    self.assertEqual(sentence2.end_index, 6)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tests/visualization_test.py",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for langextract.visualization.\"\"\"\n\nfrom unittest import mock\n\nfrom absl.testing import absltest\n\nfrom langextract import visualization\nfrom langextract.core import data\n\n_PALETTE = visualization._PALETTE\n_VISUALIZATION_CSS = visualization._VISUALIZATION_CSS\n\n\nclass VisualizationTest(absltest.TestCase):\n\n  def test_assign_colors_basic_assignment(self):\n\n    extractions = [\n        data.Extraction(\n            extraction_class=\"CLASS_A\",\n            extraction_text=\"text_a\",\n            char_interval=data.CharInterval(start_pos=0, end_pos=1),\n        ),\n        data.Extraction(\n            extraction_class=\"CLASS_B\",\n            extraction_text=\"text_b\",\n            char_interval=data.CharInterval(start_pos=1, end_pos=2),\n        ),\n    ]\n    # Classes are sorted alphabetically before color assignment.\n    expected_color_map = {\n        \"CLASS_A\": _PALETTE[0],\n        \"CLASS_B\": _PALETTE[1],\n    }\n\n    actual_color_map = visualization._assign_colors(extractions)\n\n    self.assertDictEqual(actual_color_map, expected_color_map)\n\n  def test_build_highlighted_text_single_span_correct_html(self):\n\n    text = \"Hello world\"\n    extraction = data.Extraction(\n        extraction_class=\"GREETING\",\n        extraction_text=\"Hello\",\n        char_interval=data.CharInterval(start_pos=0, end_pos=5),\n    )\n   
 extractions = [extraction]\n    color_map = {\"GREETING\": \"#ff0000\"}\n    expected_html = (\n        '<span class=\"lx-highlight lx-current-highlight\" data-idx=\"0\" '\n        'style=\"background-color:#ff0000;\">Hello</span> world'\n    )\n\n    actual_html = visualization._build_highlighted_text(\n        text, extractions, color_map\n    )\n\n    self.assertEqual(actual_html, expected_html)\n\n  def test_build_highlighted_text_escapes_html_in_text_and_tooltip(self):\n\n    text = \"Text with <unsafe> content & ampersand.\"\n    extraction = data.Extraction(\n        extraction_class=\"UNSAFE_CLASS\",\n        extraction_text=\"<unsafe> content & ampersand.\",\n        char_interval=data.CharInterval(start_pos=10, end_pos=39),\n        attributes={\"detail\": \"Attribute with <tag> & 'quote'\"},\n    )\n    # Highlighting \"<unsafe> content & ampersand\"\n    extractions = [extraction]\n    color_map = {\"UNSAFE_CLASS\": \"#00ff00\"}\n    expected_highlighted_segment = \"&lt;unsafe&gt; content &amp; ampersand.\"\n    expected_html = (\n        'Text with <span class=\"lx-highlight lx-current-highlight\"'\n        ' data-idx=\"0\" '\n        f'style=\"background-color:#00ff00;\">{expected_highlighted_segment}</span>'\n    )\n\n    actual_html = visualization._build_highlighted_text(\n        text, extractions, color_map\n    )\n\n    self.assertEqual(actual_html, expected_html)\n\n  @mock.patch.object(\n      visualization, \"HTML\", new=None\n  )  # Ensures visualize returns str\n  def test_visualize_basic_document_renders_correctly(self):\n\n    doc = data.AnnotatedDocument(\n        text=\"Patient needs Aspirin.\",\n        extractions=[\n            data.Extraction(\n                extraction_class=\"MEDICATION\",\n                extraction_text=\"Aspirin\",\n                char_interval=data.CharInterval(\n                    start_pos=14, end_pos=21\n                ),  # \"Aspirin\"\n            )\n        ],\n    )\n    # Predictable color based 
on sorted class name \"MEDICATION\"\n    med_color = _PALETTE[0]\n    body_html = (\n        'Patient needs <span class=\"lx-highlight lx-current-highlight\"'\n        f' data-idx=\"0\" style=\"background-color:{med_color};\">Aspirin</span>.'\n    )\n    legend_html = (\n        '<div class=\"lx-legend\">Highlights Legend: <span class=\"lx-label\" '\n        f'style=\"background-color:{med_color};\">MEDICATION</span></div>'\n    )\n    css_html = _VISUALIZATION_CSS\n    expected_components = [\n        css_html,\n        \"lx-animated-wrapper\",\n        body_html,\n        legend_html,\n    ]\n\n    actual_html = visualization.visualize(doc)\n\n    # Verify expected components appear in output\n    for component in expected_components:\n      self.assertIn(component, actual_html)\n\n  @mock.patch.object(\n      visualization, \"HTML\", new=None\n  )  # Ensures visualize returns str\n  def test_visualize_no_extractions_renders_text_and_empty_legend(self):\n\n    doc = data.AnnotatedDocument(text=\"No entities here.\", extractions=[])\n    body_html = (\n        '<div class=\"lx-animated-wrapper\"><p>No valid extractions to'\n        \" animate.</p></div>\"\n    )\n    css_html = _VISUALIZATION_CSS\n    expected_html = css_html + body_html\n\n    actual_html = visualization.visualize(doc)\n\n    self.assertEqual(actual_html, expected_html)\n\n\nif __name__ == \"__main__\":\n  absltest.main()\n"
  },
  {
    "path": "tox.ini",
    "content": "# Copyright 2025 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n[tox]\nenvlist = py310, py311, py312, format, lint-src, lint-tests\nskip_missing_interpreters = True\n\n[testenv]\nsetenv =\n    PYTHONWARNINGS = ignore\ndeps =\n    .[openai,dev,test]\ncommands =\n    pytest -ra -m \"not live_api and not requires_pip\"\n\n[testenv:format]\nskip_install = true\ndeps =\n    isort>=5.13.2\n    pyink~=24.3.0\ncommands =\n    isort langextract tests --check-only --diff\n    pyink langextract tests --check --diff --config pyproject.toml\n\n[testenv:lint-src]\ndeps =\n    pylint>=3.0.0\ncommands =\n    pylint --rcfile=.pylintrc langextract\n\n[testenv:lint-tests]\ndeps =\n    pylint>=3.0.0\ncommands =\n    pylint --rcfile=tests/.pylintrc tests\n\n[testenv:live-api]\nbasepython = python3.11\npassenv =\n    GEMINI_API_KEY\n    LANGEXTRACT_API_KEY\n    OPENAI_API_KEY\n    GOOGLE_APPLICATION_CREDENTIALS\n    GOOGLE_CLOUD_PROJECT\ndeps = .[all,dev,test]\ncommands =\n    pytest tests/test_live_api.py -v -m live_api --maxfail=1\n\n[testenv:ollama-integration]\nbasepython = python3.11\ndeps =\n    .[openai,dev,test]\n    requests>=2.25.0\ncommands =\n    pytest tests/test_ollama_integration.py -v --tb=short\n\n[testenv:plugin-integration]\nbasepython = python3.11\nsetenv =\n    PIP_NO_INPUT = 1\n    PIP_DISABLE_PIP_VERSION_CHECK = 1\ndeps =\n    .[dev,test]\ncommands =\n    pytest tests/provider_plugin_test.py::PluginE2ETest -v -m 
\"requires_pip\"\n\n[testenv:plugin-smoke]\nbasepython = python3.11\ndeps =\n    .[dev,test]\ncommands =\n    pytest tests/provider_plugin_test.py::PluginSmokeTest -v\n"
  }
]