[
  {
    "path": ".github/workflows/release.yaml",
    "content": "name: Release GLiClass to PyPI\n\non:\n  push:\n    tags:\n      - 'v*'  # Trigger on version tags (e.g., v1.0.0, v2.1.3)\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  build:\n    name: Build distribution 📦\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v6\n      with:\n        persist-credentials: false\n    - name: Set up Python\n      uses: actions/setup-python@v6\n      with:\n        python-version: \"3.x\"\n    - name: Install pypa/build\n      run: >-\n        python3 -m\n        pip install\n        build\n        --user\n    - name: Build a binary wheel and a source tarball\n      run: python3 -m build\n    - name: Store the distribution packages\n      uses: actions/upload-artifact@v5\n      with:\n        name: python-package-distributions\n        path: dist/\n\n  publish-to-pypi:\n    name: >-\n      Publish Python 🐍 distribution 📦 to PyPI\n    if: startsWith(github.ref, 'refs/tags/')  # only publish to PyPI on tag pushes\n    needs:\n    - build\n    runs-on: ubuntu-latest\n    environment:\n      name: pypi\n      url: https://pypi.org/project/gliclass/  # PyPI project page for gliclass\n    permissions:\n      id-token: write  # IMPORTANT: mandatory for trusted publishing\n\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 0  # Fetch full history so the branch check below can see origin/main\n    - name: Verify tag is on main branch\n      run: |\n        if ! 
git branch -r --contains ${{ github.ref_name }} | grep -q 'origin/main'; then\n          echo \"Error: Tag ${{ github.ref_name }} is not on the main branch\"\n          exit 1\n        fi\n        echo \"✓ Tag ${{ github.ref_name }} is on main branch\"\n    - name: Download all the dists\n      uses: actions/download-artifact@v6\n      with:\n        name: python-package-distributions\n        path: dist/\n    - name: Publish distribution 📦 to PyPI\n      uses: pypa/gh-action-pypi-publish@release/v1"
  },
  {
    "path": ".github/workflows/tests.yml",
    "content": "name: Tests\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n  workflow_dispatch:\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  test:\n    name: pytest (Python ${{ matrix.python-version }})\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: [\"3.10\", \"3.11\", \"3.12\"]\n\n    steps:\n      - name: Check out repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Cache pip\n        uses: actions/cache@v4\n        with:\n          path: ~/.cache/pip\n          key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('pyproject.toml') }}\n          restore-keys: |\n            ${{ runner.os }}-py${{ matrix.python-version }}-pip-\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install -e .\n          pip install pytest pytest-asyncio\n\n      - name: Run pytest\n        run: pytest -v --tb=short\n\n  lint:\n    name: ruff\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check out repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Install ruff\n        run: pip install ruff\n\n      - name: ruff check\n        run: ruff check gliclass\n\n      - name: ruff format --check\n        run: ruff format --check gliclass\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n#custom\nmodels/\nwandb/\ngradio_cached_examples/\ntest.ipynb\ndemo1.py\n.gradio/\nuv.lock\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n.ruff_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for 
libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# ⭐ GLiClass: Generalist and Lightweight Model for Sequence Classification\n\n**GLiClass** is an efficient, zero-shot sequence classification model inspired by the [GLiNER](https://github.com/urchade/GLiNER/tree/main) framework. It achieves comparable performance to traditional cross-encoder models while being significantly more computationally efficient, offering classification results approximately **10 times faster** by performing classification in a single forward pass.\n\n<p align=\"center\">\n    <a href=\"https://medium.com/@knowledgrator/pushing-zero-shot-classification-to-the-limit-696a2403032f\">📄 Blog</a>\n    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>\n    <a href=\"https://discord.gg/dkyeAgs9DG\">📢 Discord</a>\n    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>\n    <a href=\"https://huggingface.co/spaces/knowledgator/GLiClass_SandBox\">📺 Demo</a>\n    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>\n    <a href=\"https://huggingface.co/models?sort=trending&search=gliclass\">🤗 Available models</a>\n    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>\n    <a href=\"https://colab.research.google.com/github/Knowledgator/GLiClass/blob/main/finetuning.ipynb\">\n        <img align=\"center\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" />\n    </a>\n</p>\n\n### 🚀 Quick Start\n\nInstall GLiClass easily using pip:\n\n```bash\npip install gliclass\n```\n\n#### Install from Source\n\nClone and install directly from GitHub:\n\n```bash\ngit clone https://github.com/Knowledgator/GLiClass\ncd GLiClass\n\npython -m venv venv\nsource venv/bin/activate  # Windows: venv\\Scripts\\activate\n\npip install -r requirements.txt\npip install .\n```\n\nVerify your installation:\n\n```python\nimport gliclass\nprint(gliclass.__version__)\n```\n\n### 🧑‍💻 Usage Example\n\n```python\nfrom gliclass import GLiClassModel, ZeroShotClassificationPipeline\nfrom transformers import AutoTokenizer\n\nmodel = 
GLiClassModel.from_pretrained(\"knowledgator/gliclass-small-v1.0\")\ntokenizer = AutoTokenizer.from_pretrained(\"knowledgator/gliclass-small-v1.0\")\n\npipeline = ZeroShotClassificationPipeline(\n    model, tokenizer, classification_type='multi-label', device='cuda:0'\n)\n\ntext = \"One day I will see the world!\"\nlabels = [\"travel\", \"dreams\", \"sport\", \"science\", \"politics\"]\nresults = pipeline(text, labels, threshold=0.5)[0]\n\nfor result in results:\n    print(f\"{result['label']} => {result['score']:.3f}\")\n```\n\n### 🔥 New Features\n\n#### Hierarchical Labels\n\nGLiClass now supports hierarchical labels: pass a dict that maps each group name to its labels, and flat results are reported in dot notation (`group.label`):\n\n```python\nhierarchical_labels = {\n    \"sentiment\": [\"positive\", \"negative\", \"neutral\"],\n    \"topic\": [\"product\", \"service\", \"shipping\"]\n}\n\ntext = \"The product quality is amazing but delivery was slow\"\nresults = pipeline(text, hierarchical_labels, threshold=0.5)[0]\n\nfor result in results:\n    print(f\"{result['label']} => {result['score']:.3f}\")\n# Output:\n# sentiment.positive => 0.892\n# topic.product => 0.921\n# topic.shipping => 0.763\n```\n\nGet hierarchical output matching your input structure:\n\n```python\nresults = pipeline(text, hierarchical_labels, return_hierarchical=True)[0]\nprint(results)\n# Output:\n# {\n#     \"sentiment\": {\"positive\": 0.892, \"negative\": 0.051, \"neutral\": 0.124},\n#     \"topic\": {\"product\": 0.921, \"service\": 0.153, \"shipping\": 0.763}\n# }\n```\n\n#### Few-Shot Examples\n\nImprove classification accuracy with in-context examples using the `<<EXAMPLE>>` token:\n\n```python\nexamples = [\n    {\n        \"text\": \"Love this item, great quality!\",\n        \"labels\": [\"positive\", \"product\"]\n    },\n    {\n        \"text\": \"Customer support was unhelpful\",\n        \"labels\": [\"negative\", \"service\"]\n    }\n]\n\ntext = \"Fast delivery and the item works perfectly!\"\nlabels = [\"positive\", \"negative\", \"product\", \"service\", 
\"shipping\"]\n\nresults = pipeline(text, labels, examples=examples, threshold=0.5)[0]\n\nfor result in results:\n    print(f\"{result['label']} => {result['score']:.3f}\")\n```\n\n#### Task Description Prompts\n\nAdd custom prompts to guide the classification task:\n\n```python\ntext = \"The battery life on this phone is incredible\"\nlabels = [\"positive\", \"negative\", \"neutral\"]\n\nresults = pipeline(\n    text,\n    labels,\n    prompt=\"Classify the sentiment of this product review:\",\n    threshold=0.5\n)[0]\n```\n\nUse per-text prompts for batch processing:\n\n```python\ntexts = [\"Review about electronics\", \"Review about clothing\"]\nprompts = [\n    \"Analyze this electronics review:\",\n    \"Analyze this clothing review:\"\n]\n\nresults = pipeline(texts, labels, prompt=prompts)\n```\n\n#### Long Document Classification\n\nProcess long documents with automatic text chunking:\n\n```python\nfrom gliclass import ZeroShotClassificationWithChunkingPipeline\n\nchunking_pipeline = ZeroShotClassificationWithChunkingPipeline(\n    model,\n    tokenizer,\n    text_chunk_size=8192,\n    text_chunk_overlap=256,\n    labels_chunk_size=8\n)\n\nlong_document = \"...\" # Very long text\nlabels = [\"category1\", \"category2\", \"category3\"]\n\nresults = chunking_pipeline(long_document, labels, threshold=0.5)\n```\n\n### 🌟 Retrieval-Augmented Classification (RAC)\n\nWith newer models trained for retrieval-augmented classification, such as [this model](https://huggingface.co/knowledgator/gliclass-base-v2.0-rac-init), you can pass labeled examples via `rac_examples` to improve classification accuracy:\n\n```python\nexample = {\n    \"text\": \"A new machine learning platform automates complex data workflows but faces integration issues.\",\n    \"all_labels\": [\"AI\", \"automation\", \"data_analysis\", \"usability\", \"integration\"],\n    \"true_labels\": [\"AI\", \"integration\", \"automation\"]\n}\n\ntext = \"The new AI-powered tool streamlines data analysis but has limited integration 
capabilities.\"\nlabels = [\"AI\", \"automation\", \"data_analysis\", \"usability\", \"integration\"]\n\nresults = pipeline(text, labels, threshold=0.1, rac_examples=[example])[0]\n\nfor predict in results:\n    print(f\"{predict['label']} => {predict['score']:.3f}\")\n```\n\n### 🚀 Production Serving\n\nDeploy GLiClass with Ray Serve for production workloads with dynamic batching and memory-aware processing.\n\n#### Installation\n\n```bash\npip install gliclass[serve]\n```\n\n#### Quick Start\n\n```bash\n# Default model\npython -m gliclass.serve\n\n# Specify model and port\npython -m gliclass.serve --model knowledgator/gliclass-edge-v3.0 --port 8000\n\n# With config file\npython -m gliclass.serve --config serve_configs/serve_config.yaml\n```\n\n#### Python Client\n\n```python\nfrom gliclass.serve import GLiClassClient\n\nclient = GLiClassClient(url=\"http://localhost:8000/gliclass\")\n\nresult = client.classify(\n    text=\"This is a great product!\",\n    labels=[\"positive\", \"negative\", \"neutral\"],\n    threshold=0.3,\n)\nprint(result)  # [{\"label\": \"positive\", \"score\": 0.95}, ...]\n```\n\n#### HTTP API\n\nThe HTTP endpoint processes one text per request.\n\n```bash\ncurl -X POST http://localhost:8000/gliclass \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"texts\": \"This is a great product!\",\n    \"labels\": [\"positive\", \"negative\", \"neutral\"],\n    \"threshold\": 0.3\n  }'\n\n# Response: [{\"label\": \"positive\", \"score\": 0.95}, ...]\n```\n\n**Note:** For batch processing multiple texts, use the `ZeroShotClassificationPipeline` directly instead of the serving API.\n\nSee `serve_configs/serve_config.yaml` for full configuration options.\n\n### 🎯 Key Use Cases\n\n- **Sentiment Analysis:** Rapidly classify texts as positive, negative, or neutral.\n- **Document Classification:** Efficiently organize and categorize large document collections.\n- **Search Results Re-ranking:** Improve relevance and precision by reranking search 
outputs.\n- **News Categorization:** Automatically tag and organize news articles into predefined categories.\n- **Fact Checking:** Quickly validate and categorize statements based on factual accuracy.\n\n### 🛠️ How to Train\n\nPrepare your training data as follows:\n\n```json\n[\n  {\"text\": \"Sample text.\", \"all_labels\": [\"sports\", \"science\", \"business\"], \"true_labels\": [\"sports\"]},\n  ...\n]\n```\n\nOptionally, provide `true_labels` as a mapping from each label to an explicit confidence score:\n\n```json\n[\n  {\"text\": \"Sample text.\", \"all_labels\": [\"sports\", \"science\"], \"true_labels\": {\"sports\": 0.9}},\n  ...\n]\n```\n\nRefer to the `train.py` script to train a model from scratch or to fine-tune an existing one.\n\n### ⚙️ Advanced Configuration\n\n#### Architecture Types\n\nGLiClass supports multiple architecture types:\n\n- **uni-encoder**: Single encoder for both text and labels (default, most efficient)\n- **bi-encoder**: Separate encoders for text and labels\n- **bi-encoder-fused**: Bi-encoder with label embeddings fused into text encoding\n- **encoder-decoder**: Encoder-decoder architecture for sequence-to-sequence tasks\n\n```python\nfrom gliclass import GLiClassBiEncoder\n\n# Load a bi-encoder model\nmodel = GLiClassBiEncoder.from_pretrained(\"knowledgator/gliclass-biencoder-v1.0\")\n```\n\n#### Pooling Strategies\n\nConfigure how token embeddings are pooled:\n\n- `first`: First token (CLS token)\n- `avg`: Average pooling\n- `max`: Max pooling\n- `last`: Last token\n- `sum`: Sum pooling\n- `rms`: Root mean square pooling\n- `abs_max`: Max of absolute values\n- `abs_avg`: Average of absolute values\n\n```python\nfrom gliclass import GLiClassModelConfig\n\nconfig = GLiClassModelConfig(\n    pooling_strategy='avg',\n    class_token_pooling='average'  # or 'first'\n)\n```\n\n#### Scoring Mechanisms\n\nChoose different scoring mechanisms for classification:\n\n- `simple`: Dot product (fastest)\n- `weighted-dot`: Weighted dot product with learned projections\n- 
`mlp`: Multi-layer perceptron scorer\n- `hopfield`: Hopfield network-based scorer\n\n```python\nconfig = GLiClassModelConfig(\n    scorer_type='mlp'\n)\n```\n\n---\n\n### Flash Attention Backends\n\nGLiClass supports optional flash attention backends for faster inference.\n\n#### Install\n\n```bash\npip install flashdeberta   # DeBERTa v2\npip install turbot5        # T5 / mT5\n```\n\n---\n\n#### FlashDeBERTa (DeBERTa v2)\n\nEnable it via an environment variable:\n\n```bash\nexport USE_FLASHDEBERTA=1\n```\n\nIf `flashdeberta` is installed, DeBERTa v2 models will use `FlashDebertaV2Model`.\nOtherwise, GLiClass falls back to `DebertaV2Model`.\n\n---\n\n#### TurboT5 (T5 / mT5)\n\nEnable it via an environment variable:\n\n```bash\nexport TURBOT5_ATTN_TYPE=triton-basic\n```\n\nIf `turbot5` is installed, T5 / mT5 models will use `FlashT5EncoderModel`.\nOtherwise, GLiClass falls back to `T5EncoderModel`.\n\nNotes:\n* Flash backends are **optional**\n* A backend is used only when its package is installed and its environment variable is set; otherwise the standard model is used\n* No code changes required\n\n## 📚 Citations\n\nIf you find GLiClass useful in your research or project, please cite our paper:\n\n\n```bibtex\n@misc{stepanov2025gliclassgeneralistlightweightmodel,\n      title={GLiClass: Generalist Lightweight Model for Sequence Classification Tasks}, \n      author={Ihor Stepanov and Mykhailo Shtopko and Dmytro Vodianytskyi and Oleksandr Lukashov and Alexander Yavorskyi and Mykyta Yaroshenko},\n      year={2025},\n      eprint={2508.07662},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      url={https://arxiv.org/abs/2508.07662}, \n}\n```\n"
  },
  {
    "path": "demo.py",
    "content": "\"\"\"\nGLiClass Enhanced Demo with Advanced Features\n\nFeatures:\n- Task description prompts\n- Hierarchical label inputs (JSON format)\n- Few-shot examples\n- Hierarchical output structure\n- Label descriptions\n\"\"\"\n\nimport json\nfrom typing import Dict, List, Any, Union, Optional\nimport gradio as gr\nimport torch\nfrom transformers import AutoTokenizer\n\nfrom gliclass import GLiClassModel, ZeroShotClassificationPipeline\n\n# Initialize model and pipeline\ndevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n\nmodel_path = \"knowledgator/gliclass-small-v1.0\"\nmodel = GLiClassModel.from_pretrained(model_path)\ntokenizer = AutoTokenizer.from_pretrained(model_path)\n\npipeline = ZeroShotClassificationPipeline(\n    model, tokenizer, \n    classification_type='multi-label', \n    device=device\n)\n\n# ============== Example Texts ==============\n\nTEXT_PRODUCT_REVIEW = \"\"\"\nI recently purchased the Sony WH-1000XM4 Wireless Noise-Canceling Headphones from Amazon and I must say, I'm thoroughly impressed. The package arrived in New York within 2 days, thanks to Amazon Prime's expedited shipping.\n\nThe headphones themselves are remarkable. The noise-canceling feature works like a charm in the bustling city environment, and the 30-hour battery life means I don't have to charge them every day. Connecting them to my Samsung Galaxy S21 was a breeze, and the sound quality is second to none.\n\nI also appreciated the customer service from Amazon when I had a question about the warranty. They responded within an hour and provided all the information I needed.\n\nHowever, the headphones did not come with a hard case, which was listed in the product description. 
I contacted Amazon, and they offered a 10% discount on my next purchase as an apology.\n\nOverall, I'd give these headphones a 4.5/5 rating and highly recommend them to anyone looking for top-notch quality in both product and service.\n\"\"\"\n\nTEXT_TECH_COMPANIES = \"\"\"\nApple Inc. is an American multinational technology company headquartered in Cupertino, California. Apple is the world's largest technology company by revenue, with US$394.3 billion in 2022 revenue. As of March 2023, Apple is the world's biggest company by market capitalization.\n\nMicrosoft was founded by Bill Gates and Paul Allen on April 4, 1975 to develop and sell BASIC interpreters for the Altair 8800. During his career at Microsoft, Gates held the positions of chairman, chief executive officer, president and chief software architect.\n\nApple was founded as Apple Computer Company on April 1, 1976, by Steve Wozniak, Steve Jobs (1955–2011) and Ronald Wayne to develop and sell Wozniak's Apple I personal computer.\n\"\"\"\n\nTEXT_SCIENTIFIC = \"\"\"\nSeveral studies have reported its pharmacological activities, including anti-inflammatory, antimicrobial, and antitumoral effects. \nThe effect of E-anethole was studied in the osteosarcoma MG-63 cell line, and the antiproliferative activity was evaluated by an MTT assay. \nIt showed a GI50 value of 60.25 μM with apoptosis induction through the mitochondrial-mediated pathway. Additionally, it induced cell cycle arrest at the G0/G1 phase, up-regulated the expression of p53, caspase-3, and caspase-9, and down-regulated Bcl-xL expression.\n\"\"\"\n\nTEXT_RESTAURANT_REVIEW = \"\"\"\nWe visited La Maison last Friday for our anniversary dinner. The ambiance was absolutely stunning - dim lighting, soft jazz music, and elegant table settings. Our waiter, Marcus, was incredibly attentive without being intrusive.\n\nFor appetizers, we had the truffle bruschetta and the soup of the day. Both were divine! 
The main courses - filet mignon for me and lobster risotto for my wife - were cooked to perfection. \n\nThe only downside was the wait time for our desserts, which took about 25 minutes. However, the chocolate soufflé was worth the wait!\n\nPrice was on the higher side ($180 for two), but the quality justified the cost. Will definitely return!\n\"\"\"\n\nTEXT_NEWS_POLITICS = \"\"\"\nThe Senate passed a landmark bipartisan infrastructure bill late Thursday night, allocating $1.2 trillion for roads, bridges, broadband internet, and clean energy initiatives. The vote was 69-30, with 19 Republican senators joining all Democrats in support.\n\nPresident Biden called the passage \"a historic investment in America's future\" and urged the House to act quickly. However, progressive Democrats have signaled they won't vote for the infrastructure bill unless it's paired with a larger social spending package.\n\nSenate Minority Leader criticized portions of the bill related to climate spending, calling them \"unnecessary green new deal provisions,\" while environmental groups praised the clean energy investments as \"a step in the right direction, but not nearly enough.\"\n\"\"\"\n\nTEXT_SPORTS = \"\"\"\nIn a thrilling overtime finish, the Lakers defeated the Celtics 118-112 in Game 7 of the NBA Finals. LeBron James delivered a historic performance with 42 points, 16 rebounds, and 10 assists, securing his fifth championship ring and fourth Finals MVP award.\n\nThe game was tied at 102 with 30 seconds remaining in regulation when Marcus Smart hit a contested three-pointer. However, James answered with a driving layup at the buzzer to force overtime.\n\nIn the extra period, the Lakers outscored Boston 16-10, with Anthony Davis contributing two crucial blocks in the final minute. \"This is what you dream about as a kid,\" James said in the post-game interview. 
\"Playing against the Celtics, Game 7, everything on the line.\"\n\"\"\"\n\nTEXT_MOVIE_REVIEW = \"\"\"\nChristopher Nolan's \"Oppenheimer\" is a masterwork of biographical cinema that demands to be seen on the largest screen possible. Cillian Murphy delivers a career-defining performance as J. Robert Oppenheimer, capturing both the brilliance and moral anguish of the father of the atomic bomb.\n\nThe film's nonlinear structure, weaving between the Manhattan Project, the 1954 security hearing, and the 1959 Lewis Strauss confirmation hearing, could have been confusing. Instead, Nolan crafts a compelling narrative that builds to a devastating emotional climax.\n\nAt three hours, some viewers may find the pacing challenging, particularly in the courtroom sequences. However, the technical achievements - Ludwig Göransson's haunting score, Hoyte van Hoytema's IMAX cinematography - make this an unmissable theatrical experience. Rating: 9/10\n\"\"\"\n\nTEXT_TECH_STARTUP = \"\"\"\nSan Francisco-based AI startup Anthropic announced today it has raised $450 million in Series C funding, valuing the company at $5 billion. The round was led by Spark Capital, with participation from Google and existing investors.\n\nFounded in 2021 by former OpenAI researchers Dario and Daniela Amodei, Anthropic has positioned itself as a leader in AI safety research. The company's Claude assistant has gained significant market share in the enterprise segment.\n\n\"This funding will accelerate our research into interpretable and steerable AI systems,\" said CEO Dario Amodei. \"We believe safety and capability go hand in hand.\" The company plans to double its research team and expand internationally, with offices planned in London and Tokyo.\n\"\"\"\n\nTEXT_HEALTH_WELLNESS = \"\"\"\nA new study published in the Journal of the American Medical Association suggests that intermittent fasting may offer significant benefits beyond weight loss. 
Researchers followed 500 participants over two years and found improvements in cardiovascular health markers, insulin sensitivity, and cognitive function.\n\nParticipants who followed a 16:8 fasting protocol (eating within an 8-hour window) showed a 15% reduction in LDL cholesterol and a 20% improvement in fasting glucose levels compared to the control group.\n\nHowever, experts caution that intermittent fasting isn't suitable for everyone. \"Pregnant women, people with a history of eating disorders, and those with certain medical conditions should consult their doctor first,\" said Dr. Sarah Chen, the study's lead author. \"It's not a magic solution, but for many people, it can be a sustainable approach to improving metabolic health.\"\n\"\"\"\n\nTEXT_TRAVEL = \"\"\"\nHidden among the limestone karsts of Ha Long Bay, Cat Ba Island offers travelers an authentic Vietnamese experience away from the tourist crowds. We spent five days exploring this gem and discovered why it's becoming a favorite among backpackers and adventure seekers.\n\nThe island's national park features challenging hikes through tropical rainforest, with the trek to the peak of Ngu Lam offering panoramic views of the bay. We also kayaked through hidden lagoons and explored caves that few tourists ever see.\n\nAccommodation ranges from basic hostels ($8/night) to comfortable eco-resorts ($60/night). The seafood is incredibly fresh - we had the best grilled squid of our lives at a family-run restaurant in Cat Ba Town for just $5. Pro tip: rent a motorbike to explore the quieter beaches on the island's east side.\n\"\"\"\n\nTEXT_COOKING_RECIPE = \"\"\"\nThis Thai green curry comes together in just 30 minutes and tastes better than takeout. 
The secret is making your own curry paste - it takes an extra 10 minutes but the flavor difference is remarkable.\n\nFor the paste, blend together: 10 green chilies, 4 garlic cloves, 2 shallots, 1 stalk lemongrass, 1 inch galangal, handful of cilantro stems, 1 tsp cumin, 1 tsp coriander, zest of 1 lime, and 2 tbsp fish sauce. \n\nHeat coconut oil in a wok, fry the paste for 2 minutes until fragrant. Add chicken (or tofu), cook until browned. Pour in coconut milk, add bamboo shoots, Thai eggplant, and basil. Simmer for 15 minutes. Season with palm sugar and more fish sauce to taste.\n\nServe over jasmine rice with extra chilies on the side. This recipe serves 4 and can be made ahead - the flavors actually improve overnight.\n\"\"\"\n\nTEXT_FINANCIAL_ADVICE = \"\"\"\nWith inflation running at 4.2% and the Fed signaling more rate hikes, many investors are wondering how to position their portfolios. Here's what our analysis suggests for Q4 2024.\n\nFixed income is finally attractive again. With 10-year Treasury yields above 4.5%, bonds offer meaningful real returns for the first time in years. We recommend increasing allocation to investment-grade corporate bonds and TIPS for inflation protection.\n\nFor equities, we're cautiously optimistic on value stocks, particularly in the energy and financial sectors. Tech valuations remain stretched despite recent pullbacks. International developed markets, especially Japan and Europe, offer better risk-reward at current levels.\n\nRemember: past performance doesn't guarantee future results. This is general information, not personalized advice. Consult a financial advisor before making investment decisions.\n\"\"\"\n\nTEXT_ENVIRONMENTAL = \"\"\"\nThe Great Barrier Reef experienced its sixth mass bleaching event in a decade this summer, with aerial surveys showing 91% of reefs affected. 
Scientists warn that without dramatic action on climate change, the world's largest coral ecosystem may not survive beyond 2050.\n\n\"We're witnessing the collapse of one of Earth's most biodiverse ecosystems in real time,\" said Dr. Terry Hughes of James Cook University. Water temperatures reached 2°C above the February average, causing corals to expel the symbiotic algae that give them color and nutrients.\n\nSome researchers are experimenting with heat-resistant coral varieties and cloud-brightening technology to shade reefs. However, most scientists agree these are stopgap measures. \"The only real solution is rapid decarbonization,\" Hughes said. \"Everything else is just buying time.\"\n\"\"\"\n\nTEXT_EDUCATION = \"\"\"\nThe debate over standardized testing in American schools has intensified following a new report showing significant post-pandemic learning gaps. The National Assessment of Educational Progress found that fourth-grade math scores dropped to levels not seen since 2005.\n\nProponents of testing argue that standardized assessments are essential for identifying struggling students and holding schools accountable. \"Without data, we're flying blind,\" said Education Secretary Miguel Cardona. \"Tests help us direct resources where they're needed most.\"\n\nCritics counter that high-stakes testing narrows the curriculum and increases student stress without improving outcomes. \"We're testing kids more than ever, but educational outcomes aren't improving,\" said education researcher Dr. Pasi Sahlberg. \"Countries like Finland, which use minimal standardized testing, consistently outperform the US.\"\n\"\"\"\n\nTEXT_FASHION = \"\"\"\nMilan Fashion Week wrapped up yesterday with several surprising trends that will likely dominate fall/winter 2025. 
After years of quiet luxury and minimalism, designers are embracing bold maximalism - think dramatic volumes, clashing prints, and unapologetic color.\n\nPrada's collection featured oversized coats with exaggerated shoulders paired with flowing silk pants, while Gucci returned to its pattern-mixing roots under new creative direction. Versace went full baroque with gold-embroidered gowns that would feel at home in a Renaissance painting.\n\nSustainability remained a talking point, with Stella McCartney showcasing a collection made entirely from recycled ocean plastic. However, critics noted that the industry still has far to go. \"One sustainable collection doesn't offset the environmental impact of fast fashion,\" noted fashion journalist Vanessa Friedman. \"The industry needs systemic change, not just good PR.\"\n\"\"\"\n\nTEXT_LEGAL_CASE = \"\"\"\nThe Supreme Court agreed Monday to hear a case that could reshape the boundaries of free speech on social media platforms. The case, NetChoice v. Paxton, challenges Texas and Florida laws that prohibit large social media companies from removing certain political content.\n\nTech companies argue that the First Amendment protects their right to moderate content as they see fit, similar to how newspapers decide what to publish. \"Forcing platforms to host speech they find objectionable is compelled speech, which the Constitution forbids,\" said NetChoice counsel Paul Clement.\n\nTexas and Florida counter that social media platforms function as common carriers or public utilities and should be subject to similar non-discrimination requirements. \"These companies have become the modern public square,\" said Texas Attorney General Ken Paxton. \"They shouldn't be able to silence voices based on political viewpoint.\"\n\"\"\"\n\nTEXT_GAMING = \"\"\"\nAfter three years in development hell, \"Hollow Eclipse\" has finally launched - and it's everything fans hoped for. 
This action RPG from indie studio Moonlight Games delivers a haunting 40-hour adventure that rivals titles from studios with ten times the budget.\n\nThe combat system strikes a perfect balance between accessibility and depth. Basic attacks and dodges are simple to execute, but mastering the \"shadow merge\" mechanic - which lets you temporarily possess enemies - adds layers of strategy. Boss fights are challenging without feeling unfair, though the final boss may take even experienced players dozens of attempts.\n\nWhere the game truly shines is its atmosphere. The decaying gothic city of Velmoor is rendered in stunning hand-drawn art, and the ambient soundtrack creates constant unease. The story tackles themes of grief and memory with surprising emotional maturity. Minor technical issues (occasional frame drops, one softlock) can't diminish this achievement. Score: 9.5/10\n\"\"\"\n\nTEXT_REAL_ESTATE = \"\"\"\nThe housing market is sending mixed signals as we enter 2025. Existing home sales fell for the third consecutive month, down 4.1% in November, yet prices continue to climb in most metropolitan areas. The median home price hit $416,000, up 3.8% year-over-year.\n\nLow inventory remains the central issue. Many homeowners are reluctant to sell because they've locked in sub-3% mortgage rates and don't want to trade up to today's 7% rates. This \"lock-in effect\" has created a severe shortage of listings, particularly in the starter home category.\n\n\"We're seeing bidding wars even in this high-rate environment because there's simply nothing to buy,\" said economist Lawrence Yun. First-time buyers are particularly squeezed, with affordability at its worst level since 1984. 
Some markets, including Austin and Phoenix, are showing price corrections, but coastal cities remain stubbornly expensive.\n\"\"\"\n\nTEXT_MENTAL_HEALTH = \"\"\"\nWorkplace burnout has reached epidemic proportions, with a new Gallup survey finding that 76% of employees experience burnout at least sometimes. But recognizing burnout isn't always straightforward - it often manifests differently than simple exhaustion.\n\nThe three hallmarks of burnout are: emotional exhaustion (feeling drained and unable to cope), depersonalization (becoming cynical and detached from work), and reduced personal accomplishment (feeling ineffective regardless of actual performance).\n\nRecovery requires more than a vacation. \"You can't just rest your way out of burnout,\" says psychologist Dr. Christina Maslach, who pioneered burnout research. \"You need to address the root causes - usually workload, lack of control, insufficient recognition, or values conflicts.\" Strategies include setting firm boundaries, delegating tasks, and having honest conversations with managers about sustainable workloads. In severe cases, professional support from a therapist can help.\n\"\"\"\n\nTEXT_ASTRONOMY = \"\"\"\nNASA's James Webb Space Telescope has detected what may be signs of biological activity in the atmosphere of K2-18b, an exoplanet 120 light-years away. The discovery has electrified the scientific community, though researchers caution against jumping to conclusions.\n\nThe telescope's spectrometers identified dimethyl sulfide (DMS), a molecule produced almost exclusively by living organisms on Earth. Webb also confirmed the presence of methane and carbon dioxide, consistent with a water-rich atmosphere.\n\n\"This is tantalizing, but not definitive proof of life,\" said lead researcher Dr. Nikku Madhusudhan. \"DMS could potentially be produced by unknown geological processes. 
We need more observations.\" K2-18b is a \"Hycean\" world - a planet with a hydrogen-rich atmosphere and potentially a liquid water ocean beneath. If confirmed, this would be humanity's first detection of a potential biosignature beyond our solar system.\n\"\"\"\n\n\ndef parse_labels_input(labels_input: str) -> Union[List[str], Dict[str, Any]]:\n    \"\"\"\n    Parse labels input - supports both comma-separated and JSON hierarchical format.\n    \n    Examples:\n    - \"positive, negative, neutral\" -> [\"positive\", \"negative\", \"neutral\"]\n    - '{\"sentiment\": [\"positive\", \"negative\"], \"topic\": [\"food\", \"service\"]}' -> dict\n    \"\"\"\n    labels_input = labels_input.strip()\n    \n    # Try parsing as JSON first (for hierarchical labels)\n    if labels_input.startswith('{'):\n        try:\n            return json.loads(labels_input)\n        except json.JSONDecodeError as e:\n            raise ValueError(f\"Invalid JSON format for hierarchical labels: {e}\")\n    \n    # Otherwise, treat as comma-separated flat labels\n    labels = [label.strip() for label in labels_input.split(',') if label.strip()]\n    return labels\n\n\ndef parse_examples_input(examples_input: str) -> Optional[List[Dict[str, Any]]]:\n    \"\"\"\n    Parse few-shot examples input (JSON format).\n    \n    Expected format:\n    [\n        {\"text\": \"Example text 1\", \"labels\": [\"label1\", \"label2\"]},\n        {\"text\": \"Example text 2\", \"labels\": [\"label3\"]}\n    ]\n    \"\"\"\n    if not examples_input or not examples_input.strip():\n        return None\n    \n    try:\n        examples = json.loads(examples_input.strip())\n        if not isinstance(examples, list):\n            raise ValueError(\"Examples must be a JSON array\")\n        \n        for i, ex in enumerate(examples):\n            if not isinstance(ex, dict):\n                raise ValueError(f\"Example {i+1} must be a JSON object\")\n            if 'text' not in ex:\n                raise 
ValueError(f\"Example {i+1} missing 'text' field\")\n            if 'labels' not in ex and 'true_labels' not in ex:\n                raise ValueError(f\"Example {i+1} missing 'labels' field\")\n        \n        return examples\n    except json.JSONDecodeError as e:\n        raise ValueError(f\"Invalid JSON format for examples: {e}\")\n\n\ndef format_output(\n    results: Union[List[Dict], Dict], \n    hierarchical: bool = False,\n    output_format: str = \"visual\"\n) -> Union[Dict[str, float], str]:\n    \"\"\"Format classification output for Gradio display.\"\"\"\n    \n    if output_format == \"json\":\n        return format_as_json(results, hierarchical)\n    \n    if hierarchical and isinstance(results, dict):\n        # Format hierarchical output as readable string\n        return format_hierarchical_dict(results)\n    \n    if isinstance(results, list):\n        return {result['label']: float(result['score']) for result in results}\n    \n    return results\n\n\ndef format_as_json(results: Union[List[Dict], Dict], hierarchical: bool = False) -> str:\n    \"\"\"Format results as pretty-printed JSON string.\"\"\"\n    if hierarchical and isinstance(results, dict):\n        # Already in hierarchical dict format\n        return json.dumps(results, indent=2, ensure_ascii=False)\n    \n    if isinstance(results, list):\n        # Convert list of predictions to structured format\n        output = {\n            \"predictions\": [\n                {\"label\": r[\"label\"], \"score\": round(r[\"score\"], 4)}\n                for r in results\n            ],\n            \"scores\": {r[\"label\"]: round(r[\"score\"], 4) for r in results}\n        }\n        return json.dumps(output, indent=2, ensure_ascii=False)\n    \n    return json.dumps(results, indent=2, ensure_ascii=False)\n\n\ndef format_hierarchical_dict(d: Dict, indent: int = 0) -> str:\n    \"\"\"Format hierarchical dict for display with visual score bars.\"\"\"\n    lines = []\n    prefix = \"  \" * 
indent\n    \n    for key, value in d.items():\n        if isinstance(value, dict):\n            lines.append(f\"{prefix}**{key}**:\")\n            lines.append(format_hierarchical_dict(value, indent + 1))\n        else:\n            score_bar = \"█\" * int(value * 20) + \"░\" * (20 - int(value * 20))\n            lines.append(f\"{prefix}{key}: {score_bar} {value:.3f}\")\n    \n    return \"\\n\".join(lines)\n\n\ndef classification(\n    text: str,\n    labels_input: str,\n    threshold: float,\n    multi_label: bool,\n    prompt: str,\n    examples_input: str,\n    hierarchical_output: bool,\n    output_format: str = \"visual\"\n) -> Union[Dict[str, float], str]:\n    \"\"\"\n    Perform classification with all advanced features.\n    \"\"\"\n    try:\n        # Parse labels (flat or hierarchical)\n        labels = parse_labels_input(labels_input)\n        \n        # Parse few-shot examples\n        examples = parse_examples_input(examples_input) if examples_input else None\n        \n        # Set classification type\n        pipeline.pipe.classification_type = 'multi-label' if multi_label else 'single-label'\n        \n        # Prepare prompt\n        task_prompt = prompt.strip() if prompt and prompt.strip() else None\n        \n        # Run classification\n        results = pipeline(\n            text, \n            labels, \n            threshold=threshold,\n            examples=examples,\n            prompt=task_prompt,\n            return_hierarchical=hierarchical_output\n        )[0]  # Single text, get first result\n        \n        # Format output based on selected format\n        if output_format == \"json\":\n            return format_as_json(results, hierarchical_output)\n        elif hierarchical_output:\n            return format_hierarchical_dict(results)\n        else:\n            return {result['label']: float(result['score']) for result in results}\n            \n    except Exception as e:\n        return f\"Error: {str(e)}\"\n\n\n# 
============== Example Configurations ==============\n\nEXAMPLES = [\n    # Example 1: Basic flat labels with prompt\n    [\n        TEXT_PRODUCT_REVIEW,\n        \"product review, electronics, positive feedback, negative feedback, customer service, shipping\",\n        0.5,\n        True,\n        \"Classify this customer review by topic and sentiment:\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n    # Example 2: Hierarchical labels for restaurant review\n    [\n        TEXT_RESTAURANT_REVIEW,\n        '''{\n    \"sentiment\": [\"positive\", \"negative\", \"mixed\"],\n    \"aspects\": [\"food quality\", \"service\", \"ambiance\", \"price\", \"wait time\"],\n    \"recommendation\": [\"would recommend\", \"would not recommend\"]\n}''',\n        0.4,\n        True,\n        \"Analyze this restaurant review:\",\n        \"\",\n        True,\n        \"visual\"\n    ],\n    # Example 3: News article with few-shot examples\n    [\n        TEXT_NEWS_POLITICS,\n        \"politics, business, technology, sports, entertainment, science, health\",\n        0.5,\n        True,\n        \"Classify this news article by category:\",\n        '''[\n    {\"text\": \"The Federal Reserve raised interest rates by 0.25% today, citing persistent inflation concerns.\", \"labels\": [\"politics\", \"business\"]},\n    {\"text\": \"Scientists discover new high-temperature superconductor material that works at room temperature.\", \"labels\": [\"science\", \"technology\"]}\n]''',\n        False,\n        \"visual\"\n    ],\n    # Example 4: Scientific classification with hierarchical output\n    [\n        TEXT_SCIENTIFIC,\n        '''{\n    \"domain\": [\"biology\", \"chemistry\", \"medicine\", \"physics\"],\n    \"research_type\": [\"experimental\", \"theoretical\", \"review\"],\n    \"application\": [\"therapeutic\", \"diagnostic\", \"basic research\"]\n}''',\n        0.3,\n        True,\n        \"Classify this scientific abstract:\",\n        \"\",\n        
True,\n        \"visual\"\n    ],\n    # Example 5: Sports article - single label\n    [\n        TEXT_SPORTS,\n        \"basketball, football, soccer, tennis, baseball, hockey, golf\",\n        0.5,\n        False,\n        \"What sport is this article about?\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n    # Example 6: Movie review with detailed sentiment (JSON output)\n    [\n        TEXT_MOVIE_REVIEW,\n        '''{\n    \"overall_sentiment\": [\"positive\", \"negative\", \"mixed\"],\n    \"aspects_praised\": [\"acting\", \"direction\", \"cinematography\", \"music\", \"story\", \"pacing\"],\n    \"aspects_criticized\": [\"acting\", \"direction\", \"cinematography\", \"music\", \"story\", \"pacing\"],\n    \"recommendation\": [\"must watch\", \"worth watching\", \"skip it\"]\n}''',\n        0.35,\n        True,\n        \"Analyze this movie review in detail:\",\n        \"\",\n        True,\n        \"json\"\n    ],\n    # Example 7: Tech startup news\n    [\n        TEXT_TECH_STARTUP,\n        \"funding announcement, product launch, acquisition, IPO, partnership, hiring, layoffs, legal\",\n        0.4,\n        True,\n        \"What type of tech news is this?\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n    # Example 8: Health article with hierarchical categories\n    [\n        TEXT_HEALTH_WELLNESS,\n        '''{\n    \"topic\": [\"nutrition\", \"exercise\", \"mental health\", \"sleep\", \"medical research\"],\n    \"content_type\": [\"research findings\", \"practical advice\", \"expert opinion\", \"warning\"],\n    \"audience\": [\"general public\", \"healthcare professionals\", \"patients\"]\n}''',\n        0.4,\n        True,\n        \"Categorize this health article:\",\n        \"\",\n        True,\n        \"visual\"\n    ],\n    # Example 9: Travel content (JSON output)\n    [\n        TEXT_TRAVEL,\n        \"destination guide, hotel review, restaurant review, adventure travel, budget travel, luxury travel, travel 
tips\",\n        0.4,\n        True,\n        \"What type of travel content is this?\",\n        \"\",\n        False,\n        \"json\"\n    ],\n    # Example 10: Recipe classification\n    [\n        TEXT_COOKING_RECIPE,\n        '''{\n    \"cuisine\": [\"Thai\", \"Italian\", \"Mexican\", \"Indian\", \"Chinese\", \"Japanese\", \"French\", \"American\"],\n    \"difficulty\": [\"easy\", \"medium\", \"hard\"],\n    \"meal_type\": [\"breakfast\", \"lunch\", \"dinner\", \"dessert\", \"snack\"],\n    \"dietary\": [\"vegetarian friendly\", \"vegan friendly\", \"gluten free\", \"dairy free\", \"contains meat\"]\n}''',\n        0.35,\n        True,\n        \"Classify this recipe:\",\n        \"\",\n        True,\n        \"visual\"\n    ],\n    # Example 11: Financial content with examples\n    [\n        TEXT_FINANCIAL_ADVICE,\n        \"investment advice, market analysis, personal finance, retirement planning, tax advice, economic news\",\n        0.4,\n        True,\n        \"Categorize this financial content:\",\n        '''[\n    {\"text\": \"Here are 5 ways to maximize your 401k contributions before year end.\", \"labels\": [\"personal finance\", \"retirement planning\", \"tax advice\"]},\n    {\"text\": \"The S&P 500 rose 2% today following strong jobs report.\", \"labels\": [\"market analysis\", \"economic news\"]}\n]''',\n        False,\n        \"visual\"\n    ],\n    # Example 12: Environmental news (JSON output)\n    [\n        TEXT_ENVIRONMENTAL,\n        '''{\n    \"topic\": [\"climate change\", \"biodiversity\", \"pollution\", \"conservation\", \"renewable energy\"],\n    \"tone\": [\"alarming\", \"hopeful\", \"neutral\", \"urgent\"],\n    \"focus\": [\"problem description\", \"solutions\", \"policy\", \"research findings\"]\n}''',\n        0.35,\n        True,\n        \"Analyze this environmental article:\",\n        \"\",\n        True,\n        \"json\"\n    ],\n    # Example 13: Education debate\n    [\n        TEXT_EDUCATION,\n        \"education 
policy, standardized testing, curriculum, teacher issues, student welfare, technology in education, higher education\",\n        0.4,\n        True,\n        \"What education topics does this article cover?\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n    # Example 14: Fashion news with hierarchy\n    [\n        TEXT_FASHION,\n        '''{\n    \"content_type\": [\"trend report\", \"designer profile\", \"collection review\", \"industry news\", \"sustainability\"],\n    \"season\": [\"spring/summer\", \"fall/winter\"],\n    \"market_segment\": [\"luxury\", \"fast fashion\", \"sustainable fashion\", \"streetwear\"]\n}''',\n        0.4,\n        True,\n        \"Classify this fashion article:\",\n        \"\",\n        True,\n        \"visual\"\n    ],\n    # Example 15: Legal case (JSON output)\n    [\n        TEXT_LEGAL_CASE,\n        \"constitutional law, criminal law, civil rights, corporate law, intellectual property, free speech, privacy\",\n        0.4,\n        True,\n        \"What areas of law does this case involve?\",\n        \"\",\n        False,\n        \"json\"\n    ],\n    # Example 16: Gaming review with detailed analysis\n    [\n        TEXT_GAMING,\n        '''{\n    \"genre\": [\"action\", \"RPG\", \"adventure\", \"puzzle\", \"strategy\", \"simulation\", \"sports\"],\n    \"platform_feel\": [\"indie\", \"AAA\", \"mid-tier\"],\n    \"strengths\": [\"gameplay\", \"story\", \"graphics\", \"music\", \"replayability\"],\n    \"weaknesses\": [\"bugs\", \"difficulty\", \"length\", \"graphics\", \"story\"],\n    \"recommendation\": [\"must play\", \"worth playing\", \"wait for sale\", \"skip\"]\n}''',\n        0.35,\n        True,\n        \"Analyze this game review:\",\n        \"\",\n        True,\n        \"visual\"\n    ],\n    # Example 17: Real estate market analysis\n    [\n        TEXT_REAL_ESTATE,\n        \"market analysis, buying advice, selling advice, investment, rental market, mortgage rates, housing policy\",\n        
0.4,\n        True,\n        \"What real estate topics are covered?\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n    # Example 18: Mental health with few-shot (JSON output)\n    [\n        TEXT_MENTAL_HEALTH,\n        '''{\n    \"topic\": [\"burnout\", \"anxiety\", \"depression\", \"stress management\", \"work-life balance\"],\n    \"content_type\": [\"educational\", \"self-help advice\", \"research summary\", \"personal story\"],\n    \"actionability\": [\"provides concrete steps\", \"general awareness\", \"seeks professional help\"]\n}''',\n        0.35,\n        True,\n        \"Categorize this mental health content:\",\n        '''[\n    {\"text\": \"Feeling overwhelmed? Try the 5-4-3-2-1 grounding technique: notice 5 things you see, 4 you hear...\", \"labels\": [\"topic.anxiety\", \"topic.stress management\", \"content_type.self-help advice\", \"actionability.provides concrete steps\"]},\n    {\"text\": \"A new study links social media use exceeding 3 hours daily with increased rates of depression in teens.\", \"labels\": [\"topic.depression\", \"content_type.research summary\", \"actionability.general awareness\"]}\n]''',\n        True,\n        \"json\"\n    ],\n    # Example 19: Astronomy discovery\n    [\n        TEXT_ASTRONOMY,\n        \"exoplanets, astrobiology, cosmology, solar system, space exploration, telescopes, astrophysics\",\n        0.4,\n        True,\n        \"What astronomy topics are discussed?\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n    # Example 20: Tech companies - single label\n    [\n        TEXT_TECH_COMPANIES,\n        \"company profile, product announcement, financial report, industry analysis, biography, opinion piece\",\n        0.5,\n        False,\n        \"What is the primary type of this article?\",\n        \"\",\n        False,\n        \"visual\"\n    ],\n]\n\n\n# ============== Gradio Interface ==============\n\nwith gr.Blocks(\n    title=\"GLiClass Advanced Demo\",\n    
theme=gr.themes.Soft(\n        primary_hue=\"blue\",\n        secondary_hue=\"slate\",\n    )\n) as demo:\n    \n    gr.Markdown(\"\"\"\n    # 🏷️ GLiClass Advanced Zero-Shot Classification\n    \n    Enhanced demo featuring **prompts**, **hierarchical labels**, **few-shot examples**, and **structured outputs**.\n    \"\"\")\n    \n    with gr.Accordion(\"📖 How to Use This Demo\", open=False):\n        gr.Markdown(\"\"\"\n        ## Features Overview\n        \n        ### 1. Task Description Prompts\n        Add a natural language description of the classification task to guide the model.\n        \n        **Example:** `\"Classify this customer review by sentiment and topic:\"`\n        \n        ---\n        \n        ### 2. Hierarchical Labels (JSON Format)\n        Structure your labels in categories for organized classification:\n        \n        ```json\n        {\n            \"sentiment\": [\"positive\", \"negative\", \"neutral\"],\n            \"topic\": [\"product\", \"service\", \"shipping\"],\n            \"urgency\": [\"high\", \"medium\", \"low\"]\n        }\n        ```\n        \n        Or use simple comma-separated labels: `positive, negative, neutral`\n        \n        ---\n        \n        ### 3. Few-Shot Examples\n        Provide examples to guide the model's understanding:\n        \n        ```json\n        [\n            {\"text\": \"Great product, love it!\", \"labels\": [\"positive\", \"product\"]},\n            {\"text\": \"Shipping was delayed by 2 weeks\", \"labels\": [\"negative\", \"shipping\"]}\n        ]\n        ```\n        \n        ---\n        \n        ### 4. 
Hierarchical Output\n        When enabled with hierarchical labels, returns structured scores matching your input format.\n        \"\"\")\n    \n    with gr.Accordion(\"💻 Code Example\", open=False):\n        gr.Code(\n            '''from gliclass import GLiClassModel, ZeroShotClassificationPipeline\nfrom transformers import AutoTokenizer\n\nmodel = GLiClassModel.from_pretrained(\"knowledgator/gliclass-small-v1\")\ntokenizer = AutoTokenizer.from_pretrained(\"knowledgator/gliclass-small-v1\")\n\npipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')\n\n# Basic usage\ntext = \"The product quality is amazing but delivery was slow\"\nlabels = [\"positive\", \"negative\", \"product\", \"shipping\"]\nresults = pipeline(text, labels, threshold=0.5)[0]\n\n# With hierarchical labels\nhierarchical_labels = {\n    \"sentiment\": [\"positive\", \"negative\", \"neutral\"],\n    \"topic\": [\"product\", \"service\", \"shipping\"]\n}\n\nresults = pipeline(\n    text, \n    hierarchical_labels,\n    prompt=\"Classify this review:\",\n    return_hierarchical=True\n)[0]\n\n# With few-shot examples\nexamples = [\n    {\"text\": \"Love this item!\", \"labels\": [\"sentiment.positive\", \"topic.product\"]},\n    {\"text\": \"Terrible customer support\", \"labels\": [\"sentiment.negative\", \"topic.service\"]}\n]\n\nresults = pipeline(\n    text,\n    hierarchical_labels, \n    examples=examples,\n    prompt=\"Classify customer feedback:\"\n)[0]\n''',\n            language=\"python\",\n        )\n    \n    with gr.Row():\n        with gr.Column(scale=2):\n            input_text = gr.Textbox(\n                value=EXAMPLES[0][0],\n                label=\"📝 Text Input\",\n                placeholder=\"Enter the text you want to classify...\",\n                lines=8\n            )\n            \n            prompt_input = gr.Textbox(\n                value=EXAMPLES[0][4],\n                label=\"💡 Task Description Prompt 
(Optional)\",\n                placeholder=\"E.g., 'Classify this customer review by sentiment and topic:'\",\n                lines=1\n            )\n        \n        with gr.Column(scale=1):\n            labels_input = gr.Textbox(\n                value=EXAMPLES[0][1],\n                label=\"🏷️ Labels (comma-separated or JSON)\",\n                placeholder='positive, negative\\n\\nOR\\n\\n{\"category\": [\"label1\", \"label2\"]}',\n                lines=6\n            )\n            \n            with gr.Row():\n                threshold = gr.Slider(\n                    0, 1,\n                    value=0.5,\n                    step=0.01,\n                    label=\"Threshold\",\n                    info=\"Confidence threshold for predictions\"\n                )\n            \n            with gr.Row():\n                multi_label = gr.Checkbox(\n                    value=True,\n                    label=\"Multi-label\",\n                    info=\"Allow multiple labels per text\"\n                )\n                hierarchical_output = gr.Checkbox(\n                    value=False,\n                    label=\"Hierarchical Output\",\n                    info=\"Return structured output matching label hierarchy\"\n                )\n            \n            with gr.Row():\n                output_format = gr.Radio(\n                    choices=[\"visual\", \"json\"],\n                    value=\"visual\",\n                    label=\"Output Format\",\n                    info=\"Visual: charts/bars | JSON: raw data\"\n                )\n    \n    with gr.Accordion(\"🎯 Few-Shot Examples (Optional)\", open=False):\n        examples_input = gr.Textbox(\n            value=\"\",\n            label=\"Examples (JSON format)\",\n            placeholder='''[\n    {\"text\": \"Example text 1\", \"labels\": [\"label1\", \"label2\"]},\n    {\"text\": \"Example text 2\", \"labels\": [\"label3\"]}\n]''',\n            lines=5\n        )\n        gr.Markdown(\"\"\"\n     
   *Provide labeled examples to guide the model. Each example needs a `text` field and a `labels` array.*\n        \"\"\")\n    \n    submit_btn = gr.Button(\"🚀 Classify\", variant=\"primary\", size=\"lg\")\n    \n    output = gr.Label(label=\"📊 Classification Results\")\n    output_text = gr.Textbox(\n        label=\"📊 Hierarchical Results\", \n        visible=False, \n        lines=10\n    )\n    output_json = gr.Code(\n        label=\"📊 JSON Output\",\n        language=\"json\",\n        visible=False,\n        lines=15\n    )\n    \n    # Dynamic output visibility based on format and hierarchical toggle\n    def update_output_visibility(hierarchical: bool, fmt: str):\n        if fmt == \"json\":\n            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)\n        elif hierarchical:\n            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)\n        else:\n            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)\n    \n    hierarchical_output.change(\n        fn=update_output_visibility,\n        inputs=[hierarchical_output, output_format],\n        outputs=[output, output_text, output_json]\n    )\n    \n    output_format.change(\n        fn=update_output_visibility,\n        inputs=[hierarchical_output, output_format],\n        outputs=[output, output_text, output_json]\n    )\n    \n    # Classification function wrapper for different outputs\n    def classify_wrapper(text, labels, threshold, multi_label, prompt, examples, hierarchical, fmt):\n        result = classification(text, labels, threshold, multi_label, prompt, examples, hierarchical, fmt)\n        \n        if fmt == \"json\":\n            return None, None, result\n        elif hierarchical or isinstance(result, str):\n            return None, result, None\n        else:\n            return result, None, None\n    \n    # Event handlers\n    submit_btn.click(\n        fn=classify_wrapper,\n 
       inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],\n        outputs=[output, output_text, output_json]\n    )\n    \n    input_text.submit(\n        fn=classify_wrapper,\n        inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],\n        outputs=[output, output_text, output_json]\n    )\n    \n    gr.Markdown(\"### 📚 Example Configurations\")\n    \n    gr.Examples(\n        examples=EXAMPLES,\n        inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],\n        outputs=[output, output_text, output_json],\n        fn=classify_wrapper,\n        cache_examples=False,\n        examples_per_page=5\n    )\n    \n    gr.Markdown(\"\"\"\n    ---\n    \n    ### 🔧 Tips for Best Results\n    \n    | Feature | Best Practice |\n    |---------|---------------|\n    | **Prompts** | Be specific about the task, e.g., \"Classify by sentiment:\" vs \"Analyze:\" |\n    | **Labels** | Use descriptive labels; \"customer service issue\" > \"service\" |\n    | **Hierarchical** | Group related labels under categories for organized results |\n    | **Examples** | 2-3 diverse examples often improve accuracy significantly |\n    | **Threshold** | Start at 0.5, lower for more predictions, raise for higher precision |\n    \"\"\")\n\n\nif __name__ == \"__main__\":\n    demo.queue()\n    demo.launch(debug=True, share=True)"
  },
  {
    "path": "gliclass/__init__.py",
    "content": "from .model import (\n    GLiClassModel,\n    GLiClassBiEncoder,\n    GLiClassUniEncoder,\n    GLiClassEncoderDecoderCLS,\n)\nfrom .config import GLiClassModelConfig\nfrom .pipeline import (\n    ZeroShotClassificationPipeline,\n    BiEncoderZeroShotClassificationPipeline,\n    ZeroShotClassificationWithChunkingPipeline,\n)\n\n__version__ = \"0.1.19\"\n\n# The ``serve`` submodule is optional: if importing it raises ImportError,\n# the package still imports and ``serve`` is set to None.\ntry:\n    from . import serve\nexcept ImportError:\n    serve = None\n"
  },
  {
    "path": "gliclass/config.py",
    "content": "from transformers import AutoConfig\nfrom transformers.utils import logging\nfrom transformers.models.auto import CONFIG_MAPPING\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom .utils import is_module_available\n\nIS_TURBOT5 = is_module_available(\"turbot5\")\n\nif IS_TURBOT5:\n    from turbot5.model.config import T5Config\nelse:\n    from transformers import T5Config\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass GLiClassModelConfig(PretrainedConfig):\n    \"\"\"Configuration for GLiClass models.\n\n    Holds the backbone ``encoder_config`` (and, when ``label_model_name`` is set,\n    a separate ``label_model_config``) together with GLiClass-specific options\n    such as the architecture type, scorer, pooling and loss settings.\n\n    ``encoder_config`` and ``label_model_config`` may be config objects, dicts\n    (``model_type`` defaults to ``'deberta-v2'``) or None (a default DeBERTa-v2\n    config). ``hidden_size`` and ``vocab_size`` default to the encoder's values.\n    Special token indices left at -1 are placed right after the vocabulary:\n    class = vocab_size, text = vocab_size + 1, example = vocab_size + 2.\n    ``problem_type`` is forwarded to ``PretrainedConfig``, which only accepts\n    ``'regression'``, ``'single_label_classification'``,\n    ``'multi_label_classification'`` or None.\n    \"\"\"\n\n    model_type = \"GLiClass\"\n    is_composition = True\n\n    def __init__(\n        self,\n        encoder_config=None,\n        encoder_model=None,\n        label_model_config=None,\n        label_model_name=None,\n        class_token_index=-1,\n        text_token_index=-1,\n        example_token_index=-1,\n        ignore_index=-100,\n        hidden_size=None,\n        projector_hidden_act=\"gelu\",\n        vocab_size=None,\n        problem_type=\"single_label_classification\",\n        max_num_classes=25,\n        use_lstm=False,\n        initializer_range=0.03,\n        scorer_type=\"simple\",\n        scorer_num_heads=16,\n        scorer_mlp_hidden_size=1024,\n        scorer_attn_dropout=0.1,\n        pooling_strategy=\"first\",\n        class_token_pooling=\"first\",\n        focal_loss_alpha=0.5,\n        focal_loss_gamma=2,\n        focal_loss_reduction=None,\n        logit_scale_init_value=2.6592,\n        normalize_features=False,\n        extract_text_features=False,\n        max_labels_alloc: str = \"dynamic\",\n        contrastive_loss_coef=0,\n        architecture_type=\"uni-encoder\",\n        prompt_first=False,\n        squeeze_layers=False,\n        layer_wise=False,\n        encoder_layer_id=-1,\n        embed_class_token=True,\n        dropout=0.1,\n        use_segment_embeddings=False,\n        **kwargs,\n    ):\n        if isinstance(encoder_config, dict):\n            # Work on a copy so the caller's dict is not mutated.\n            encoder_config = dict(encoder_config)\n            encoder_config[\"model_type\"] = encoder_config.get(\"model_type\", \"deberta-v2\")\n            if encoder_config[\"model_type\"] == \"t5\":\n                encoder_config = T5Config(**encoder_config)\n            elif encoder_config[\"model_type\"] in CONFIG_MAPPING:\n                encoder_config = CONFIG_MAPPING[encoder_config[\"model_type\"]](**encoder_config)\n            else:\n                # Model type not registered with transformers: load the config by name when one is known.\n                _name = encoder_model or kwargs.get(\"encoder_model_name\")\n                if _name:\n                    encoder_config = AutoConfig.from_pretrained(_name, trust_remote_code=True)\n                else:\n                    encoder_config = PretrainedConfig(**encoder_config)\n        elif encoder_config is None:\n            encoder_config = CONFIG_MAPPING[\"deberta-v2\"]()\n\n        self.encoder_config = encoder_config\n        self.encoder_model_name = encoder_model\n\n        if label_model_name is not None:\n            if isinstance(label_model_config, dict):\n                label_model_config = dict(label_model_config)\n                label_model_config[\"model_type\"] = label_model_config.get(\"model_type\", \"deberta-v2\")\n                label_model_config = CONFIG_MAPPING[label_model_config[\"model_type\"]](**label_model_config)\n            elif label_model_config is None:\n                label_model_config = CONFIG_MAPPING[\"deberta-v2\"]()\n\n            self.label_model_config = label_model_config\n        else:\n            self.label_model_config = None\n        self.label_model_name = label_model_name\n\n        if hidden_size is None:\n            self.hidden_size = self.encoder_config.hidden_size\n        else:\n            self.hidden_size = hidden_size\n\n        if vocab_size is None:\n            self.vocab_size = self.encoder_config.vocab_size\n        else:\n            self.vocab_size = vocab_size\n\n        if class_token_index == -1:\n            self.class_token_index = self.vocab_size\n        else:\n            self.class_token_index = class_token_index\n\n        if text_token_index == -1:\n            self.text_token_index = self.vocab_size + 1\n        else:\n            self.text_token_index = text_token_index\n\n        if example_token_index == -1:\n            self.example_token_index = self.vocab_size + 2\n        else:\n            self.example_token_index = example_token_index\n\n        self.ignore_index = ignore_index\n        self.projector_hidden_act = projector_hidden_act\n        self.max_num_classes = max_num_classes\n        self.initializer_range = initializer_range\n        self.scorer_type = scorer_type\n        self.scorer_num_heads = scorer_num_heads\n        self.scorer_mlp_hidden_size = scorer_mlp_hidden_size\n        self.scorer_attn_dropout = scorer_attn_dropout\n        self.pooling_strategy = pooling_strategy\n        self.class_token_pooling = class_token_pooling\n        self.use_lstm = use_lstm\n        self.focal_loss_alpha = focal_loss_alpha\n        self.focal_loss_gamma = focal_loss_gamma\n        self.focal_loss_reduction = focal_loss_reduction\n        self.contrastive_loss_coef = contrastive_loss_coef\n        self.logit_scale_init_value = logit_scale_init_value\n        self.normalize_features = normalize_features\n        self.extract_text_features = extract_text_features\n        self.max_labels_alloc = max_labels_alloc\n        self.architecture_type = architecture_type\n        self.prompt_first = prompt_first\n        self.squeeze_layers = squeeze_layers\n        self.layer_wise = layer_wise\n        self.encoder_layer_id = encoder_layer_id\n        self.embed_class_token = embed_class_token\n        self.dropout = dropout\n        self.use_segment_embeddings = use_segment_embeddings\n        # PretrainedConfig.__init__ re-initializes ``problem_type`` and ``pad_token_id``\n        # from its kwargs (defaulting to None), so values assigned to them before the\n        # call would be overwritten. Forward them instead; an explicitly passed or\n        # saved ``pad_token_id`` takes precedence over the encoder's.\n        kwargs.setdefault(\"pad_token_id\", self.encoder_config.pad_token_id)\n        super().__init__(problem_type=problem_type, **kwargs)\n"
  },
  {
    "path": "gliclass/data_processing.py",
    "content": "import copy\nimport random\nfrom dataclasses import dataclass\n\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.nn.utils.rnn import pad_sequence\n\n\n@dataclass\nclass AugmentationConfig:\n    \"\"\"Configuration for data augmentation.\"\"\"\n\n    enabled: bool = True\n\n    # Probability for each augmentation type\n    random_label_removal_prob: float = 0.15\n    random_label_addition_prob: float = 0.10\n    random_text_addition_prob: float = 0.05\n    random_add_description_prob: float = 0.25\n    random_add_synonyms_prob: float = 0.1\n    random_add_examples_prob: float = 0.25\n    max_num_examples: int = 5\n\n\nclass DataAugmenter:\n    \"\"\"Randomly perturb training items (labels, text and few-shot examples).\n\n    Each augmentation in ``augment`` fires independently with the probability set\n    in ``AugmentationConfig``. ``label2description`` optionally maps a label to a\n    dict with ``'descriptions'`` and/or ``'synonyms'`` lists.\n    \"\"\"\n\n    def __init__(self, config, examples, labels, label2description=None):\n        self.config = config\n        self.examples = examples\n        self.labels = sorted(labels)\n        self.max_examples = self.config.max_num_examples\n        self.label2description = label2description or {}\n\n    def remove_labels(self, true_labels, all_labels):\n        \"\"\"Keep a random non-empty subset of ``all_labels`` and drop true labels outside it.\n\n        ``true_labels`` may be a list or a label->score dict; its type is preserved.\n        \"\"\"\n        if len(all_labels) <= 1:\n            return true_labels, all_labels\n        k = random.randint(1, len(all_labels))\n        all_labels = random.sample(all_labels, k=k)\n        if isinstance(true_labels, dict):\n            true_labels = {lbl: score for lbl, score in true_labels.items() if lbl in all_labels}\n        else:\n            true_labels = [lbl for lbl in true_labels if lbl in all_labels]\n        return true_labels, all_labels\n\n    def add_random_labels(self, all_labels):\n        \"\"\"Extend ``all_labels`` in place with random dataset labels it does not already contain.\"\"\"\n        if not self.labels:\n            return all_labels\n        num_add = len(all_labels) + 1\n        k = random.randint(1, min(num_add, len(self.labels)))\n        add_labels = random.sample(self.labels, k=k)\n        # Skip labels that are already candidates to avoid duplicate label slots in the prompt.\n        existing = set(all_labels)\n        all_labels.extend(lbl for lbl in add_labels if lbl not in existing)\n        return all_labels\n\n    def add_random_text(self, text, all_labels):\n        \"\"\"Prepend or append the text of a random dataset example.\n\n        The text is returned unchanged when the sampled example shares any label\n        with ``all_labels``.\n        \"\"\"\n        if not self.examples:\n            return text\n        example = random.sample(self.examples, k=1)[0]\n        curr_labels = example[\"all_labels\"]\n        joint_labels = set(all_labels) & set(curr_labels)\n        if len(joint_labels):\n            return text\n        if random.randint(0, 1):\n            return example[\"text\"] + \" \" + text\n        return text + \" \" + example[\"text\"]\n\n    def _sample_synonym_map(self, labels):\n        \"\"\"Map each distinct label to itself or, with probability 0.5, to one of its synonyms.\"\"\"\n        mapping = {}\n        for label in labels:\n            if label in mapping:\n                continue\n            synonyms = []\n            if label in self.label2description:\n                synonyms = self.label2description[label].get(\"synonyms\", [])\n            if synonyms and random.random() < 0.5:\n                mapping[label] = random.choice(synonyms)\n            else:\n                mapping[label] = label\n        return mapping\n\n    @staticmethod\n    def _rename_labels(labels, mapping):\n        \"\"\"Apply ``mapping`` to a list of labels or to the keys of a label->score dict.\"\"\"\n        if isinstance(labels, dict):\n            return {mapping.get(label, label): score for label, score in labels.items()}\n        return [mapping.get(label, label) for label in labels]\n\n    def add_random_synonyms(self, all_labels):\n        \"\"\"Replace some labels with their synonyms if available.\n\n        Only ``all_labels`` is rewritten; ``augment`` calls ``_sample_synonym_map``\n        directly so it can rename the true labels with the same mapping.\n        \"\"\"\n        if not self.label2description:\n            return all_labels\n        mapping = self._sample_synonym_map(all_labels)\n        return self._rename_labels(all_labels, mapping)\n\n    def add_random_descriptions(self, item):\n        \"\"\"Add descriptions of up to three random labels before or after ``item['text']``.\"\"\"\n        if not self.label2description or not item[\"all_labels\"]:\n            return item\n\n        max_labels = min(3, len(item[\"all_labels\"]))\n        labels_to_describe = random.sample(item[\"all_labels\"], k=random.randint(1, max_labels))\n\n        descriptions = []\n        for label in labels_to_describe:\n            if label in self.label2description:\n                label_info = self.label2description[label]\n                desc_list = label_info.get(\"descriptions\", [])\n                if desc_list:\n                    descriptions.append(f\"{label}: {random.choice(desc_list)}\")\n\n        if descriptions:\n            desc_text = \" \".join(descriptions)\n            if random.random() < 0.5:\n                item[\"text\"] = desc_text + \" \" + item[\"text\"]\n            else:\n                item[\"text\"] = item[\"text\"] + \" \" + desc_text\n\n        return item\n\n    def add_random_examples(self, item):\n        \"\"\"Attach one or two few-shot examples to ``item['examples']``.\n\n        Candidates are the item's own examples when present; otherwise dataset\n        examples whose true labels overlap the item's ``all_labels``.\n        \"\"\"\n        if not item[\"all_labels\"]:\n            return item\n\n        # Copy so an existing examples list on the item is never mutated in place.\n        candidate_examples = list(item.get(\"examples\") or [])\n\n        if not candidate_examples:\n            item_label_set = set(item[\"all_labels\"])\n            for example in self.examples:\n                # Only consider examples with at least one overlapping label\n                if item_label_set & set(example[\"true_labels\"]):\n                    # dict.fromkeys de-duplicates while keeping the original label order.\n                    example_labels = list(dict.fromkeys(example[\"true_labels\"]))\n                    candidate_examples.append({\"text\": example[\"text\"], \"labels\": example_labels})\n\n        # Shuffle, keep at most ``max_examples`` candidates, then pick one or two of them.\n        random.shuffle(candidate_examples)\n        top_candidates = candidate_examples[: self.max_examples]\n        if not top_candidates:\n            return item\n\n        num_examples = random.randint(1, min(2, len(top_candidates)))\n        item[\"examples\"] = random.sample(top_candidates, k=num_examples)\n        return item\n\n    def augment(self, item):\n        \"\"\"Return a randomly augmented deep copy of ``item`` (``item`` itself when disabled).\"\"\"\n        if not self.config.enabled:\n            return item\n\n        text = copy.deepcopy(item[\"text\"])\n        true_labels = copy.deepcopy(item[\"true_labels\"])\n        all_labels = copy.deepcopy(item[\"all_labels\"])\n\n        # Create augmented item\n        aug_item = {\"text\": text, \"true_labels\": true_labels, \"all_labels\": all_labels}\n\n        # Copy any additional fields\n        for key in item:\n            if key not in aug_item:\n                aug_item[key] = copy.deepcopy(item[key])\n\n        if random.random() < self.config.random_label_removal_prob:\n            aug_item[\"true_labels\"], aug_item[\"all_labels\"] = self.remove_labels(\n                aug_item[\"true_labels\"], aug_item[\"all_labels\"]\n            )\n\n        if random.random() < self.config.random_label_addition_prob:\n            aug_item[\"all_labels\"] = self.add_random_labels(aug_item[\"all_labels\"])\n\n        if random.random() < self.config.random_text_addition_prob:\n            aug_item[\"text\"] = self.add_random_text(aug_item[\"text\"], aug_item[\"all_labels\"])\n\n        if random.random() < self.config.random_add_synonyms_prob and self.label2description:\n            # Rename true labels with the same mapping; otherwise a label swapped for its\n            # synonym in all_labels would no longer match its true label.\n            mapping = self._sample_synonym_map(aug_item[\"all_labels\"])\n            aug_item[\"all_labels\"] = self._rename_labels(aug_item[\"all_labels\"], mapping)\n            aug_item[\"true_labels\"] = self._rename_labels(aug_item[\"true_labels\"], mapping)\n\n        if random.random() < self.config.random_add_description_prob:\n            aug_item = self.add_random_descriptions(aug_item)\n\n        if random.random() < self.config.random_add_examples_prob:\n            aug_item = self.add_random_examples(aug_item)\n\n        return aug_item\n\n\nclass GLiClassDataset(Dataset):\n    \"\"\"Dataset that augments and tokenizes GLiClass training examples.\n\n    Each example is a dict with ``'text'``, ``'all_labels'`` (candidate labels) and\n    ``'true_labels'`` (a list of positive labels, or a label->score dict for\n    multi-label targets), plus optional ``'prompt'`` and ``'examples'`` fields.\n    The encoding produced by ``__getitem__`` depends on ``architecture_type``.\n    \"\"\"\n\n    def __init__(\n        self,\n        examples,\n        tokenizer,\n        augment_config,\n        label2description=None,\n        max_length=512,\n        problem_type=\"multi_label_classification\",\n        architecture_type=\"uni-encoder\",\n        add_description=True,\n        prompt_first=False,\n        get_negatives=False,\n        max_labels=50,\n        labels_tokenizer=None,\n        shuffle_labels=True,\n    ):\n        # ``None`` default instead of a shared mutable ``{}``.\n        if label2description is None:\n            label2description = {}\n        self.tokenizer = tokenizer\n        self.labels_tokenizer = labels_tokenizer\n        self.label2description = label2description\n        self.augment_config = augment_config\n        self.max_length = max_length\n        self._data = examples\n        self.add_description = add_description\n        self.problem_type = problem_type\n        self.architecture_type = architecture_type\n        self.prompt_first = prompt_first\n        self.dataset_labels = self.collect_dataset_labels()\n        self.get_negatives = get_negatives\n        self.max_labels = max_labels\n        self.shuffle_labels = shuffle_labels\n\n        self.sep_token = \"<<SEP>>\"\n        self.label_token = \"<<LABEL>>\"\n        self.example_token = \"<<EXAMPLE>>\"\n        self.augmenter = DataAugmenter(augment_config, examples, self.dataset_labels, label2description)\n        print(\"Total labels: \", len(self.dataset_labels))\n\n    def get_diversity(self):\n        \"\"\"Return each example's ``_diversity['overall_diversity']`` score, defaulting to 0.5.\"\"\"\n        # Examples are stored in ``self._data``.\n        return [item.get(\"_diversity\", {}).get(\"overall_diversity\", 0.5) for item in self._data]\n\n    def collect_dataset_labels(self):\n        \"\"\"Return the set of every label that appears in some example's ``all_labels``.\"\"\"\n        dataset_labels = set()\n        for example in self._data:\n            dataset_labels.update(set(example[\"all_labels\"]))\n        return dataset_labels\n\n    def prepare_labels(self, example, label2idx, problem_type):\n        \"\"\"Build the target tensor for ``example``.\n\n        Single-label: a 0-d tensor holding the index of the first true label.\n        Multi-label: one value per entry of ``all_labels`` (1.0/0.0, or the score\n        looked up in a label->score ``true_labels`` dict).\n        \"\"\"\n        if problem_type == \"single_label_classification\":\n            labels = label2idx[example[\"true_labels\"][0]]\n        elif problem_type == \"multi_label_classification\":\n            if isinstance(example[\"true_labels\"], dict):\n                labels = [example[\"true_labels\"].get(label, 0.0) for label in example[\"all_labels\"]]\n            else:\n                labels = [1.0 if label in example[\"true_labels\"] else 0.0 for label in example[\"all_labels\"]]\n        else:\n            raise NotImplementedError(f\"{problem_type} is not implemented.\")\n        return torch.tensor(labels)\n\n    def prepare_prompt(self, item, label_token_first=True):\n        \"\"\"Return prompt pieces: one tagged entry per label, the separator token, then ``item['prompt']`` (or '').\"\"\"\n        prompt_texts = []\n        for label in item[\"all_labels\"]:\n            if label_token_first:\n                label_tag = f\"{self.label_token}{label!s}\"\n            else:\n                label_tag = f\"{label!s}{self.label_token}\"\n            prompt_texts.append(label_tag)\n        prompt_texts.append(self.sep_token)\n        prompt = item.get(\"prompt\", \"\")\n        prompt_texts.append(prompt)\n        return prompt_texts\n\n    def format_examples(self, item):\n        \"\"\"Render a random non-empty subset of ``item['examples']`` as one string ('' if there are none).\"\"\"\n        examples = item.get(\"examples\", [])\n        if not examples:\n            return \"\"\n        examples = random.sample(examples, k=random.randint(1, len(examples)))\n        parts = []\n        for example in examples:\n            parts.append(self.example_token)\n            parts.append(example.get(\"text\", \"\"))\n            parts.append(\" \\nLabels:\\n \")\n            parts.append(\", \".join(example.get(\"labels\", example.get(\"true_labels\", []))))\n        parts.append(self.sep_token)\n        return \"\".join(parts)\n\n    def tokenize(self, texts):\n        \"\"\"Tokenize ``texts`` with the main tokenizer, truncated to ``max_length``.\"\"\"\n        tokenized_inputs = self.tokenizer(texts, truncation=True, max_length=self.max_length, padding=\"longest\")\n        return tokenized_inputs\n\n    def tokenize_labels(self, labels):\n        \"\"\"Tokenize label strings with ``labels_tokenizer``, truncated to ``max_length``.\"\"\"\n        tokenized_inputs = self.labels_tokenizer(labels, truncation=True, max_length=self.max_length, padding=\"longest\")\n        return tokenized_inputs\n\n    def tokenize_and_prepare_labels_for_uniencoder(self, example):\n        \"\"\"Encode the text, the label prompt and the examples as a single sequence.\"\"\"\n        if self.shuffle_labels:\n            random.shuffle(example[\"all_labels\"])\n        input_text = self.prepare_prompt(example)\n        examples_text = self.format_examples(example)\n        if self.prompt_first:\n            input_text = \"\".join(input_text) + str(example[\"text\"]) + examples_text\n        else:\n            input_text = str(example[\"text\"]) + \"\".join(input_text) + examples_text\n        label2idx = {label: idx for idx, label in enumerate(example[\"all_labels\"])}\n\n        tokenized_inputs = self.tokenize(input_text)\n        tokenized_inputs[\"labels\"] = self.prepare_labels(example, label2idx, self.problem_type)\n        tokenized_inputs[\"labels_text\"] = example[\"all_labels\"]\n        tokenized_inputs[\"input_texts\"] = example[\"text\"]\n        return tokenized_inputs\n\n    def tokenize_and_prepare_labels_for_encoder_decoder(self, example):\n        \"\"\"Encode text plus examples and the label prompt as two separate sequences.\"\"\"\n        if self.shuffle_labels:\n            random.shuffle(example[\"all_labels\"])\n        class_texts = self.prepare_prompt(example, label_token_first=True)\n        class_texts = \"\".join(class_texts)\n        examples_text = self.format_examples(example)\n\n        label2idx = {label: idx for idx, label in enumerate(example[\"all_labels\"])}\n\n        input_text = str(example[\"text\"]) + examples_text\n        tokenized_inputs = self.tokenize(input_text)\n        tokenized_classes = self.tokenize(class_texts)\n        tokenized_inputs[\"class_input_ids\"] = tokenized_classes[\"input_ids\"]\n        tokenized_inputs[\"class_attention_mask\"] = tokenized_classes[\"attention_mask\"]\n        tokenized_inputs[\"labels\"] = self.prepare_labels(example, label2idx, self.problem_type)\n        return tokenized_inputs\n\n    def tokenize_and_prepare_labels_for_biencoder(self, example):\n        \"\"\"Encode the text and each label separately.\n\n        For ``bi-encoder-fused`` one ``<<LABEL>>`` placeholder per label plus ``<<SEP>>``\n        is also inserted before or after the text.\n        \"\"\"\n        if self.shuffle_labels:\n            random.shuffle(example[\"all_labels\"])\n\n        def prepare_prompt(labels):\n            prompt_texts = []\n            for _label in labels:\n                label_tag = \"<<LABEL>>\"\n                prompt_texts.append(label_tag)\n            prompt_texts.append(\"<<SEP>>\")\n            return \"\".join(prompt_texts)\n\n        input_text = example[\"text\"]\n        class_texts = example[\"all_labels\"]\n\n        if self.architecture_type == \"bi-encoder-fused\":\n            prompt = prepare_prompt(class_texts)\n            if self.prompt_first:\n                input_text = f\"{prompt} {input_text}\"\n            else:\n                input_text = f\"{input_text} {prompt}\"\n\n        tokenized_inputs = self.tokenize(input_text)\n        tokenized_classes = self.tokenize_labels(class_texts)\n\n        tokenized_inputs[\"class_input_ids\"] = torch.tensor(tokenized_classes[\"input_ids\"])\n        tokenized_inputs[\"class_attention_mask\"] = torch.tensor(tokenized_classes[\"attention_mask\"])\n\n        label2idx = {label: idx for idx, label in enumerate(example[\"all_labels\"])}\n\n        tokenized_inputs[\"labels_mask\"] = torch.ones(len(class_texts))\n        tokenized_inputs[\"labels\"] = self.prepare_labels(example, label2idx, self.problem_type)\n        return tokenized_inputs\n\n    def __len__(self):\n        return len(self._data)\n\n    def __getitem__(self, idx):\n        \"\"\"Augment example ``idx`` and tokenize it for the configured architecture.\"\"\"\n        example = self._data[idx]\n\n        example = self.augmenter.augment(example)\n\n        if self.architecture_type == \"uni-encoder\":\n            model_inputs = self.tokenize_and_prepare_labels_for_uniencoder(example)\n        elif self.architecture_type in {\"encoder-decoder\", \"encoder-decoder-cls\"}:\n            model_inputs = self.tokenize_and_prepare_labels_for_encoder_decoder(example)\n        elif self.architecture_type in {\"bi-encoder\", \"bi-encoder-fused\"}:\n            model_inputs = self.tokenize_and_prepare_labels_for_biencoder(example)\n        else:\n            raise NotImplementedError(\"This architecture type is not implemented.\")\n        return model_inputs\n\n\ndef pad_2d_tensor(key_data):\n    \"\"\"\n    Pad a list of 2D tensors to have the same size along both dimensions.\n\n    :param key_data: List of 2D tensors to pad.\n    :return: Tensor of padded tensors stacked along a new batch dimension.\n    \"\"\"\n    if not key_data:\n        raise ValueError(\"The input list 'key_data' should not be empty.\")\n\n    # Determine the maximum size along both dimensions\n    max_rows = max(tensor.shape[0] for tensor in key_data)\n    max_cols = max(tensor.shape[1] for tensor in key_data)\n\n    tensors = []\n\n    for tensor in key_data:\n        rows, cols = tensor.shape\n        row_padding = max_rows - rows\n        col_padding = max_cols - cols\n        # Pad the tensor along both dimensions\n        padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode=\"constant\", value=0)\n        tensors.append(padded_tensor)\n\n    # Stack the tensors into a single tensor along a new batch dimension\n    padded_tensors = torch.stack(tensors)\n\n    return padded_tensors\n\n\nclass DataCollatorWithPadding:\n    \"\"\"Collate dataset items (dicts with identical keys) into a batch dict.\n\n    0-d tensors are stacked, 1-d/2-d tensors and numeric lists are zero-padded,\n    strings and lists of strings are passed through as Python lists. A\n    ``max_num_classes`` entry is added based on ``config.max_labels_alloc``.\n    \"\"\"\n\n    def __init__(self, device=\"cuda:0\", config=None):\n        self.device = device\n        self._max_labels_alloc = getattr(config, \"max_labels_alloc\", \"dynamic\") if config is not None else \"dynamic\"\n\n    def _resolve_max_num_classes(self, batch):\n        \"\"\"Return the label-slot count for ``batch``, or None.\n\n        ``'dynamic'`` uses the largest label count found in the batch, an int is\n        returned as-is, and anything else (e.g. ``'fixed'``) returns None.\n        \"\"\"\n        if self._max_labels_alloc == \"dynamic\":\n            first = batch[0]\n            if \"labels_text\" in first:\n                return max(len(item[\"labels_text\"]) for item in batch)\n            if \"labels_mask\" in first:\n                return max(item[\"labels_mask\"].shape[0] for item in batch)\n            first_labels = first.get(\"labels\")\n            if isinstance(first_labels, torch.Tensor) and first_labels.dim() >= 1:\n                return max(item[\"labels\"].shape[0] for item in batch)\n            return None\n        if isinstance(self._max_labels_alloc, int):\n            return self._max_labels_alloc\n        return None  # 'fixed': model uses config.max_num_classes\n\n    def __call__(self, batch):\n        \"\"\"Collate ``batch`` (a list of item dicts) into one dict of batched values.\"\"\"\n        keys = batch[0].keys()\n        padded_batch = {key: [] for key in keys}\n\n        for key in keys:\n            key_data = [item[key] for item in batch]\n            if isinstance(key_data[0], torch.Tensor):\n                if key_data[0].dim() == 0:\n                    # Scalar tensors (e.g. single-label class indices) are stacked into shape (batch,).\n                    padded_batch[key] = torch.stack(key_data)\n                elif key_data[0].dim() == 1:\n                    padded_batch[key] = pad_sequence(key_data, batch_first=True)\n                elif key_data[0].dim() == 2:\n                    padded_batch[key] = pad_2d_tensor(key_data)\n                else:\n                    raise TypeError(f\"Unsupported tensor rank {key_data[0].dim()} for key {key!r}\")\n            elif isinstance(key_data[0], list):\n                data_el = \"string\"\n                if len(key_data[0]):\n                    data_el = key_data[0][0]\n                if isinstance(data_el, str):\n                    padded_batch[key] = key_data\n                else:\n                    max_length = max(len(seq) for seq in key_data)\n                    padded_batch[key] = torch.tensor([seq + [0] * (max_length - len(seq)) for seq in key_data])\n            elif type(key_data[0]) in {int, float}:\n                padded_batch[key] = torch.tensor(key_data)\n            elif isinstance(key_data[0], str):\n                padded_batch[key] = key_data\n            else:\n                raise TypeError(f\"Unsupported data type: {type(key_data[0])}\")\n\n        padded_batch[\"max_num_classes\"] = self._resolve_max_num_classes(batch)\n        return padded_batch\n"
  },
  {
    "path": "gliclass/layers.py",
    "content": "# Copyright 2020 Microsoft and the Hugging Face Inc. team and Knowledgator.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence\nfrom transformers.activations import ACT2FN\n\nfrom .config import GLiClassModelConfig\n\n\nclass LstmSeq2SeqEncoder(nn.Module):\n    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, bidirectional=False):\n        super().__init__()\n        self.lstm = nn.LSTM(\n            input_size=input_size,\n            hidden_size=hidden_size,\n            num_layers=num_layers,\n            dropout=dropout,\n            bidirectional=bidirectional,\n            batch_first=True,\n        )\n\n    def forward(self, x, mask, hidden=None):\n        # Packing the input sequence\n        lengths = mask.sum(dim=1).cpu()\n        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)\n\n        # Passing packed sequence through LSTM\n        packed_output, hidden = self.lstm(packed_x, hidden)\n\n        # Unpacking the output sequence\n        output, _ = pad_packed_sequence(packed_output, batch_first=True)\n\n        return output\n\n\nclass FeaturesProjector(nn.Module):\n    def __init__(self, config: GLiClassModelConfig):\n        super().__init__()\n\n        self.linear_1 = nn.Linear(config.encoder_config.hidden_size, config.hidden_size, 
bias=True)\n        self.act = ACT2FN[config.projector_hidden_act]\n        self.dropout = nn.Dropout(config.dropout)\n        self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)\n\n    def forward(self, features):\n        hidden_states = self.linear_1(features)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass BiEncoderProjector(nn.Module):\n    def __init__(self, config: GLiClassModelConfig):\n        super().__init__()\n\n        self.linear_1 = nn.Linear(config.label_model_config.hidden_size, config.hidden_size, bias=True)\n        self.act = ACT2FN[config.projector_hidden_act]\n        self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)\n\n    def forward(self, features):\n        hidden_states = self.linear_1(features)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\n# Copied from transformers.models.deberta.modeling_deberta.DropoutContext\nclass DropoutContext:\n    def __init__(self):\n        self.dropout = 0\n        self.mask = None\n        self.scale = 1\n        self.reuse_mask = True\n\n\n# Copied from transformers.models.deberta.modeling_deberta.get_mask\ndef get_mask(input, local_context):\n    if not isinstance(local_context, DropoutContext):\n        dropout = local_context\n        mask = None\n    else:\n        dropout = local_context.dropout\n        dropout *= local_context.scale\n        mask = local_context.mask if local_context.reuse_mask else None\n\n    if dropout > 0 and mask is None:\n        mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)\n\n    if isinstance(local_context, DropoutContext) and local_context.mask is None:\n        local_context.mask = mask\n\n    return mask, dropout\n\n\n# 
Copied from transformers.models.deberta.modeling_deberta.XDropout\nclass XDropout(torch.autograd.Function):\n    \"\"\"Optimized dropout function to save computation and memory by using mask operation instead of multiplication.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input, local_ctx):\n        mask, dropout = get_mask(input, local_ctx)\n        ctx.scale = 1.0 / (1 - dropout)\n        if dropout > 0:\n            ctx.save_for_backward(mask)\n            return input.masked_fill(mask, 0) * ctx.scale\n        else:\n            return input\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        if ctx.scale > 1:\n            (mask,) = ctx.saved_tensors\n            return grad_output.masked_fill(mask, 0) * ctx.scale, None\n        else:\n            return grad_output, None\n\n    @staticmethod\n    def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: float | DropoutContext) -> torch._C.Value:\n        from torch.onnx import symbolic_opset12\n\n        dropout_p = local_ctx\n        if isinstance(local_ctx, DropoutContext):\n            dropout_p = local_ctx.dropout\n        # StableDropout only calls this function when training.\n        train = True\n        # TODO: We should check if the opset_version being used to export\n        # is > 12 here, but there's no good way to do that. 
As-is, if the\n        # opset_version < 12, export will fail with a CheckerError.\n        # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:\n        # if opset_version < 12:\n        #   return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)\n        return symbolic_opset12.dropout(g, input, dropout_p, train)\n\n\n# Copied from transformers.models.deberta.modeling_deberta.StableDropout\nclass StableDropout(nn.Module):\n    \"\"\"\n    Optimized dropout module for stabilizing the training.\n\n    Args:\n        drop_prob (float): the dropout probabilities\n    \"\"\"\n\n    def __init__(self, drop_prob):\n        super().__init__()\n        self.drop_prob = drop_prob\n        self.count = 0\n        self.context_stack = None\n\n    def forward(self, x):\n        \"\"\"\n        Call the module.\n\n        Args:\n            x (`torch.tensor`): The input tensor to apply dropout\n        \"\"\"\n        if self.training and self.drop_prob > 0:\n            return XDropout.apply(x, self.get_context())\n        return x\n\n    def clear_context(self):\n        self.count = 0\n        self.context_stack = None\n\n    def init_context(self, reuse_mask=True, scale=1):\n        if self.context_stack is None:\n            self.context_stack = []\n        self.count = 0\n        for c in self.context_stack:\n            c.reuse_mask = reuse_mask\n            c.scale = scale\n\n    def get_context(self):\n        if self.context_stack is not None:\n            if self.count >= len(self.context_stack):\n                self.context_stack.append(DropoutContext())\n            ctx = self.context_stack[self.count]\n            ctx.dropout = self.drop_prob\n            self.count += 1\n            return ctx\n        else:\n            return self.drop_prob\n\n\nclass SelfAttentionBlock(nn.Module):\n    def __init__(self, d_model, num_heads, dropout=0.1):\n        super().__init__()\n        self.self_attn = 
nn.MultiheadAttention(d_model, num_heads, dropout=dropout)\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x, mask=None):\n        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)\n        return self.norm(x + self.dropout(attn_output))\n\n\nclass CrossAttentionBlock(nn.Module):\n    def __init__(self, d_model, num_heads, dropout=0.1):\n        super().__init__()\n        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)\n        self.norm = nn.LayerNorm(d_model)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, query, key, value, mask=None):\n        attn_output, _ = self.cross_attn(query, key, value, attn_mask=mask)\n        return self.norm(query + self.dropout(attn_output))\n\n\nclass Fuser(nn.Module):\n    def __init__(self, d_model, num_heads, num_layers, dropout=0.1):\n        super().__init__()\n        self.d_model = d_model\n        self.layers = nn.ModuleList(\n            [\n                nn.ModuleList(\n                    [SelfAttentionBlock(d_model, num_heads, dropout), CrossAttentionBlock(d_model, num_heads, dropout)]\n                )\n                for _ in range(num_layers)\n            ]\n        )\n        self.fc = nn.Linear(d_model, d_model)\n\n    def forward(self, query, key, query_mask=None, key_mask=None):\n        if query_mask is not None and key_mask is not None:\n            self_attn_mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2)\n            cross_attn_mask = query_mask.unsqueeze(-1) * key_mask.unsqueeze(1)\n        else:\n            self_attn_mask = None\n            cross_attn_mask = None\n\n        value = self.fc(key)\n\n        for self_attn, cross_attn in self.layers:\n            query = self_attn(query, mask=self_attn_mask)\n            query = cross_attn(query, key, value, mask=cross_attn_mask)\n\n        return query\n\n\nclass LayerwiseAttention(nn.Module):\n    def __init__(self, 
num_layers, hidden_size, output_size=None):\n        super().__init__()\n        self.num_layers = num_layers\n        self.hidden_size = hidden_size\n        self.output_size = output_size if output_size is not None else hidden_size\n\n        # Squeeze operation\n        self.squeeze = nn.Linear(hidden_size, 1)\n\n        # Excitation operation\n        self.W1 = nn.Linear(num_layers, num_layers // 2)\n        self.W2 = nn.Linear(num_layers // 2, num_layers)\n\n        # Final projection\n        self.output_projection = nn.Linear(self.hidden_size, self.output_size)\n\n    def forward(self, encoder_outputs):\n        # encoder_outputs is a list of tensors, each of shape [B, L, D]\n        _B, _L, _D = encoder_outputs[0].shape\n\n        # Concatenate all layers\n        U = torch.stack(encoder_outputs, dim=1)  # [B, K, L, D]\n\n        # Squeeze operation\n        Z = self.squeeze(U).squeeze(-1)  # [B, K, L]\n        Z = Z.mean(dim=2)  # [B, K]\n\n        # Excitation operation\n        s = self.W2(F.relu(self.W1(Z)))  # [B, K]\n        s = torch.sigmoid(s)  # [B, K]\n\n        # Apply attention weights\n        U_weighted = U * s.unsqueeze(-1).unsqueeze(-1)  # [B, K, L, D]\n\n        # Sum across layers\n        U_sum = U_weighted.sum(dim=1)  # [B, L, D]\n\n        # Final projection\n        output = self.output_projection(U_sum)  # [B, L, output_size]\n\n        return output\n"
  },
  {
    "path": "gliclass/loss_functions.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\ndef sequence_contrastive_loss(embeddings, mask):\n    # embeddings shape: (B, L, D)\n    # mask shape: (B, L)\n    B, L, _D = embeddings.shape\n\n    # Normalize embeddings\n    embeddings = F.normalize(embeddings, p=2, dim=-1)\n\n    # Compute similarity matrix\n    sim_matrix = torch.matmul(embeddings, embeddings.transpose(1, 2))  # / self.temperature\n\n    # Create labels for cross entropy (diagonal indices)\n    labels = torch.arange(L, device=embeddings.device).unsqueeze(0).expand(B, -1)\n\n    # Compute loss for each element in the batch\n    loss = F.cross_entropy(sim_matrix.reshape(B * L, L), labels.reshape(-1), reduction=\"none\")\n\n    # Apply mask to loss\n    loss = loss.view(B, L) * mask\n\n    # Compute mean loss over non-padded elements\n    loss = loss.sum() / mask.sum()\n\n    return loss\n\n\ndef focal_loss_with_logits(\n    inputs: torch.Tensor,\n    targets: torch.Tensor,\n    alpha: float = 0.25,\n    gamma: float = 2,\n    reduction: str = \"none\",\n    label_smoothing: float = 0.0,\n    ignore_index: int = -100,  # default value for ignored index\n) -> torch.Tensor:\n    \"\"\"\n    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.\n\n    Args:\n        inputs (Tensor): A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary\n                classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n        alpha (float): Weighting factor in range (0,1) to balance\n                positive vs negative examples or -1 for ignore. Default: ``0.25``.\n        gamma (float): Exponent of the modulating factor (1 - p_t) to\n                balance easy vs hard examples. 
Default: ``2``.\n        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``\n                ``'none'``: No reduction will be applied to the output.\n                ``'mean'``: The output will be averaged.\n                ``'sum'``: The output will be summed. Default: ``'none'``.\n        label_smoothing (float): Specifies the amount of smoothing when computing the loss,\n                                                                where 0.0 means no smoothing.\n        ignore_index (int): Specifies a target value that is ignored and does not contribute\n                            to the input gradient. Default: ``-100``.\n\n    Returns:\n        Loss tensor with the reduction option applied.\n    \"\"\"\n    # Create a mask to ignore specified index\n    valid_mask = targets != ignore_index\n\n    # Apply label smoothing if needed\n    if label_smoothing != 0:\n        with torch.no_grad():\n            targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing\n\n    # Apply sigmoid activation to inputs\n    p = torch.sigmoid(inputs)\n\n    # Compute the binary cross-entropy loss without reduction\n    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n\n    # Apply the valid mask to the loss\n    loss = loss * valid_mask\n\n    # Apply focal loss modulation if gamma is greater than 0\n    if gamma > 0:\n        p_t = p * targets + (1 - p) * (1 - targets)\n        loss = loss * ((1 - p_t) ** gamma)\n\n    # Apply alpha weighting if alpha is specified\n    if alpha >= 0:\n        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)\n        loss = alpha_t * loss\n\n    # Apply reduction method\n    if reduction == \"none\":\n        return loss\n    elif reduction == \"mean\":\n        return loss.sum() / valid_mask.sum()  # Normalize by the number of valid (non-ignored) elements\n    elif reduction == \"sum\":\n        return loss.sum()\n    else:\n        raise ValueError(\n            f\"Invalid value for 
argument 'reduction': '{reduction}'. Supported reduction modes: 'none', 'mean', 'sum'\"\n        )\n"
  },
  {
    "path": "gliclass/model.py",
    "content": "import os\nimport warnings\nfrom typing import Tuple\nfrom pathlib import Path\nfrom dataclasses import dataclass\n\nimport torch\nimport transformers\nfrom torch import nn\nfrom packaging import version\nfrom transformers import AutoModel, AutoConfig, PreTrainedModel\nfrom transformers.utils import logging\nfrom transformers.modeling_outputs import SequenceClassifierOutput\n\n# Import initialization module (transformers 5.0+) or fallback to torch.nn.init\ntry:\n    from transformers import initialization as init\nexcept ImportError:\n    # transformers < 5.0 doesn't have this module, use torch.nn.init instead\n    from torch.nn import init\nfrom .utils import MissedPackageException, is_module_available\nfrom .config import GLiClassModelConfig\nfrom .layers import FeaturesProjector, BiEncoderProjector, LayerwiseAttention, LstmSeq2SeqEncoder\nfrom .scorers import SCORER2OBJECT\nfrom .poolings import POOLING2OBJECT\nfrom .loss_functions import focal_loss_with_logits, sequence_contrastive_loss\n\nIS_LLM2VEC = is_module_available(\"llm2vec\")\nIS_PEFT = is_module_available(\"peft\")\nIS_TURBOT5 = is_module_available(\"turbot5\")\nIS_FLASHDEBERTA = is_module_available(\"flashdeberta\")\n\nlogger = logging.get_logger(__name__)\n\nif IS_LLM2VEC:\n    from llm2vec.models import GemmaBiModel, LlamaBiModel, Qwen2BiModel, MistralBiModel\n\n    DECODER_MODEL_MAPPING = {\n        \"MistralConfig\": MistralBiModel,\n        \"LlamaConfig\": LlamaBiModel,\n        \"GemmaConfig\": GemmaBiModel,\n        \"Qwen2Config\": Qwen2BiModel,\n    }\nelse:\n    DECODER_MODEL_MAPPING = {}\n\nif IS_TURBOT5:\n    from turbot5.model.modeling import T5EncoderModel as FlashT5EncoderModel\nfrom transformers import T5EncoderModel, UMT5EncoderModel\n\nif IS_FLASHDEBERTA:\n    from flashdeberta import FlashDebertaV2Model\nfrom transformers import DebertaV2Model\n\nif IS_PEFT:\n    from peft import LoraConfig, get_peft_model\n\n\n@dataclass\nclass 
GLiClassOutput(SequenceClassifierOutput):\n    text_embeddings: torch.Tensor | None = None\n    class_embeddings: torch.Tensor | None = None\n\n\nclass GLiClassPreTrainedModel(PreTrainedModel):\n    config_class = GLiClassModelConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _supports_sdpa = False\n    _keys_to_ignore_on_load_unexpected = [\"position_embeddings\"]\n\n    def _initialize_weights(self, module, is_remote_code: bool = False):\n        \"\"\"\n        Initialize weights if not already initialized.\n\n        This method is called by transformers 5.0+ during post_init().\n        It uses the _is_hf_initialized flag to prevent reinitializing weights\n        that were already loaded from a checkpoint.\n\n        For transformers 4.x, this method is not called, maintaining backward compatibility.\n        \"\"\"\n        if getattr(module, \"_is_hf_initialized\", False):\n            return\n\n        self._init_weights(module)\n        module._is_hf_initialized = True\n\n    def _init_weights(self, module):\n        std = (\n            self.config.initializer_range\n            if hasattr(self.config, \"initializer_range\")\n            else self.config.encoder_config.initializer_range\n        )\n\n        if hasattr(module, \"class_embedding\"):\n            init.normal_(module.class_embedding, mean=0.0, std=std)\n\n        if hasattr(module, \"segment_embeddings\"):\n            init.normal_(module.segment_embeddings.weight, mean=0.0, std=std)\n\n        if isinstance(module, (nn.Linear, nn.Conv2d)):\n            init.normal_(module.weight, mean=0.0, std=std)\n            if module.bias is not None:\n                init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            init.normal_(module.weight, mean=0.0, std=std)\n            if module.padding_idx is not None:\n                init.zeros_(module.weight[module.padding_idx])\n        elif isinstance(module, nn.LSTM):\n           
 for name, param in module.named_parameters():\n                if \"weight_ih\" in name or \"weight_hh\" in name:\n                    init.normal_(param, mean=0.0, std=std)\n                elif \"bias\" in name:\n                    init.zeros_(param)\n\n\nclass GLiClassBaseModel(nn.Module):  # ):\n    def __init__(self, config: GLiClassModelConfig, device=\"cpu\", **kwargs):\n        super().__init__()\n        self.config = config\n        self.text_projector = FeaturesProjector(config)\n        self.classes_projector = FeaturesProjector(config)\n\n        if config.pooling_strategy not in POOLING2OBJECT:\n            raise NotImplementedError(f\"{config.pooling_strategy} is not implemented pooling type.\")\n        else:\n            self.pooler = POOLING2OBJECT[config.pooling_strategy]()\n\n        if config.pooling_strategy not in POOLING2OBJECT:\n            raise NotImplementedError(\n                f\"{config.scorer_type} is not implemented. Choose one of this: 'dot', 'weighted-dot'\"\n            )\n        else:\n            self.scorer = SCORER2OBJECT[config.scorer_type](\n                config.hidden_size,\n                num_heads=config.scorer_num_heads,\n                scorer_mlp_hidden_size=config.scorer_mlp_hidden_size,\n                attn_dropout=config.scorer_attn_dropout,\n            )\n\n        if config.use_lstm:\n            self.lstm = LstmSeq2SeqEncoder(config.hidden_size, config.hidden_size // 2, bidirectional=True)\n\n        if config.squeeze_layers:\n            self.layer_wise_attention = LayerwiseAttention(\n                config.encoder_config.num_hidden_layers, config.encoder_config.hidden_size\n            )\n\n        drop_out = getattr(config, \"dropout\", 0.0)\n        # self.dropout = StableDropout(drop_out)\n        self.dropout = nn.Dropout(drop_out)\n\n        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))\n\n        self.epsilon = 1e-8\n        self.vocab_size = 
config.vocab_size\n        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1\n        self.num_labels = -1\n\n        self.device = torch.device(device)\n\n    def _extract_class_features(self, token_embeds, input_ids, attention_mask, max_num_classes=None):\n        batch_size, _sequence_length, embed_dim = token_embeds.shape\n\n        class_token_mask = input_ids == self.config.class_token_index\n        num_class_tokens = torch.sum(class_token_mask, dim=-1, keepdim=True)\n\n        # max_num_classes from caller (CPU int) avoids GPU→CPU sync via .item()\n        max_embed_dim = max_num_classes if max_num_classes is not None else self.config.max_num_classes\n\n        # Get class token pooling method from config (default to \"first\" for backward compatibility)\n        class_token_pooling = getattr(self.config, \"class_token_pooling\", \"first\")\n\n        if class_token_pooling == \"average\":\n            # Average all tokens belonging to each class label\n            classes_embedding, classes_embedding_mask = self._extract_class_features_averaged(\n                token_embeds,\n                input_ids,\n                attention_mask,\n                class_token_mask,\n                num_class_tokens,\n                max_embed_dim,\n                batch_size,\n                embed_dim,\n            )\n        else:\n            # Original behavior: use only the class token (or token after it)\n            classes_embedding, classes_embedding_mask = self._extract_class_features_first(\n                token_embeds,\n                input_ids,\n                attention_mask,\n                class_token_mask,\n                num_class_tokens,\n                max_embed_dim,\n                batch_size,\n                embed_dim,\n            )\n\n        # Text features extraction\n        if self.config.extract_text_features:\n            text_token_mask = input_ids == self.config.text_token_index\n        
    text_token_indices = text_token_mask.int().argmax(dim=-1)  # (batch,)\n            max_text_length = input_ids.shape[-1]  # static, no GPU→CPU sync\n\n            # (batch, max_text_length): source position in token_embeds for each target slot\n            aranged_target_idx = (\n                torch.arange(max_text_length, device=token_embeds.device).unsqueeze(0).expand(batch_size, -1)\n            )\n            valid_mask = aranged_target_idx < (input_ids.shape[-1] - text_token_indices).unsqueeze(1)\n\n            source_indices = (text_token_indices.unsqueeze(1) + aranged_target_idx).clamp(max=input_ids.shape[-1] - 1)\n            batch_arange = torch.arange(batch_size, device=token_embeds.device).unsqueeze(1)\n\n            # Gather then zero-out invalid positions — no nonzero/scatter needed\n            text_tokens_embeddings = token_embeds[batch_arange, source_indices] * valid_mask.unsqueeze(-1).to(\n                token_embeds.dtype\n            )\n            text_tokens_mask = attention_mask[batch_arange, source_indices] * valid_mask\n        else:\n            text_tokens_embeddings = token_embeds\n            text_tokens_mask = attention_mask\n        return classes_embedding, classes_embedding_mask, text_tokens_embeddings, text_tokens_mask\n\n    def _extract_class_features_first(\n        self,\n        token_embeds,\n        input_ids,\n        attention_mask,\n        class_token_mask,\n        num_class_tokens,\n        max_embed_dim,\n        batch_size,\n        embed_dim,\n    ):\n        \"\"\"Extract only the class token embedding (or token after it). 
Fully vectorized.\"\"\"\n        class_cum = class_token_mask.long().cumsum(dim=-1)  # (batch, seq)\n        k_range = torch.arange(max_embed_dim, device=token_embeds.device).view(1, -1, 1)\n\n        # select_mask[b, k, s] = True at the position of the k-th class token\n        select_mask = class_token_mask.unsqueeze(1) & ((class_cum.unsqueeze(1) - 1) == k_range)\n\n        if not self.config.embed_class_token:\n            # Shift right by 1: select the token immediately after each class token\n            shifted = torch.zeros_like(select_mask)\n            shifted[:, :, 1:] = select_mask[:, :, :-1]\n            select_mask = shifted\n\n        classes_embedding = torch.einsum(\"bks,bsd->bkd\", select_mask.to(token_embeds.dtype), token_embeds)\n\n        arange_k = torch.arange(max_embed_dim, device=token_embeds.device).unsqueeze(0)\n        classes_embedding_mask = (arange_k < num_class_tokens).to(attention_mask.dtype)\n\n        return classes_embedding, classes_embedding_mask\n\n    def _extract_class_features_averaged(\n        self,\n        token_embeds,\n        input_ids,\n        attention_mask,\n        class_token_mask,\n        num_class_tokens,\n        max_embed_dim,\n        batch_size,\n        embed_dim,\n    ):\n        \"\"\"Average all tokens belonging to each class label. 
Fully vectorized.\"\"\"\n        # class_cum[b, s] = cumulative count of class tokens up to position s\n        class_cum = class_token_mask.long().cumsum(dim=-1)  # (batch, seq)\n\n        if self.config.extract_text_features:\n            text_token_mask = input_ids == self.config.text_token_index\n        else:\n            text_token_mask = torch.zeros_like(class_token_mask)\n        # text_cum[b, s] >= 1 at and after the text token → use as exclusion boundary\n        text_cum = text_token_mask.long().cumsum(dim=-1)  # (batch, seq)\n\n        # span_mask[b, k, s] = True if token s belongs to the span of class k\n        k_range = torch.arange(max_embed_dim, device=token_embeds.device).view(1, -1, 1)\n        span_mask = (\n            (class_cum.unsqueeze(1) == (k_range + 1))  # in the span of class k\n            & (text_cum.unsqueeze(1) == 0)  # before the text boundary\n            & attention_mask.unsqueeze(1).bool()  # real token (not padding)\n        )\n        if not self.config.embed_class_token:\n            span_mask = span_mask & ~class_token_mask.unsqueeze(1)\n\n        span_float = span_mask.to(token_embeds.dtype)  # (batch, max_embed_dim, seq)\n        class_counts = span_float.sum(dim=-1, keepdim=True).clamp(min=1)\n        classes_embedding = torch.einsum(\"bks,bsd->bkd\", span_float, token_embeds) / class_counts\n\n        arange_k = torch.arange(max_embed_dim, device=token_embeds.device).unsqueeze(0)\n        classes_embedding_mask = (arange_k < num_class_tokens).to(attention_mask.dtype)\n\n        return classes_embedding, classes_embedding_mask\n\n    def get_loss(self, logits, labels, classes_embedding=None, classes_embedding_mask=None):\n        loss = None\n        if labels is not None:\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    # regression task\n                    loss_fn = nn.MSELoss()\n                    logits = logits.view(-1).to(labels.dtype)\n               
     loss = loss_fn(logits, labels.view(-1))\n                elif labels.dim() == 1 or labels.size(-1) == 1:\n                    label_index = (labels >= 0).nonzero()\n                    labels = labels.long()\n                    if label_index.size(0) > 0:\n                        labeled_logits = torch.gather(\n                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))\n                        )\n                        labels = torch.gather(labels, 0, label_index.view(-1))\n                        loss_fct = nn.CrossEntropyLoss()\n                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))\n                    else:\n                        loss = torch.tensor(0).to(logits)\n                else:\n                    log_softmax = nn.LogSoftmax(-1)\n                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()\n            elif self.config.problem_type == \"regression\":\n                loss_fct = nn.MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = nn.CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                all_losses = focal_loss_with_logits(\n                    logits,\n                    labels,\n                    self.config.focal_loss_alpha,\n                    self.config.focal_loss_gamma,\n                    self.config.focal_loss_reduction,\n                )\n                if classes_embedding_mask is not None:\n                    all_losses = all_losses * classes_embedding_mask.float()\n                loss = all_losses.mean()\n\n            if 
self.config.contrastive_loss_coef > 0 and classes_embedding is not None:\n                contrastive_loss = sequence_contrastive_loss(classes_embedding, classes_embedding_mask)\n                loss = loss + contrastive_loss * self.config.contrastive_loss_coef\n        return loss\n\n\nclass GLiClassUniEncoder(GLiClassBaseModel):\n    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):\n        super().__init__(config)\n        if config.encoder_config is None:\n            if config.encoder_model_name is None:\n                raise ValueError(\"You need to specify encoder model name to use it as a backbone.\")\n            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)\n\n        config_name = config.encoder_config.__class__.__name__\n\n        model_kwargs = {}\n        if config_name in DECODER_MODEL_MAPPING:\n            if not IS_LLM2VEC:\n                raise MissedPackageException(\n                    f\"The llm2vec package must be installed to use this decoder model: {config_name}\"\n                )\n            else:\n                print(\"Loading decoder model using LLM2Vec...\")\n                ModelClass = DECODER_MODEL_MAPPING[config_name]\n            decoder = True\n        elif config_name in {\"T5Config\", \"MT5Config\", \"UMT5Config\"}:\n            decoder = False\n            turbot5_type = os.environ.get(\"TURBOT5_ATTN_TYPE\", \"\")\n            if turbot5_type and IS_TURBOT5:\n                ModelClass = FlashT5EncoderModel\n                model_kwargs = {\"attention_type\": turbot5_type}\n            elif config_name == \"UMT5Config\":\n                ModelClass = UMT5EncoderModel\n            else:\n                ModelClass = T5EncoderModel\n        elif config_name in {\"DebertaV2Config\"}:\n            decoder = False\n            if os.environ.get(\"USE_FLASHDEBERTA\", \"\") and IS_FLASHDEBERTA:\n                print(\"Using FlashDeberta backend.\")\n                
ModelClass = FlashDebertaV2Model\n            else:\n                ModelClass = DebertaV2Model\n\n        else:\n            decoder = False\n            ModelClass = AutoModel\n\n        if from_pretrained:\n            self.encoder_model = ModelClass.from_pretrained(config.encoder_model_name, **model_kwargs)\n        elif decoder:\n            self.encoder_model = ModelClass(config.encoder_config)\n        elif config_name in {\"T5Config\", \"MT5Config\", \"UMT5Config\", \"DebertaV2Config\"}:\n            self.encoder_model = ModelClass._from_config(config.encoder_config)\n        else:\n            self.encoder_model = ModelClass.from_config(config.encoder_config)\n\n        if config.vocab_size is not None and hasattr(self.encoder_model, \"resize_token_embeddings\"):\n            current_vocab = self.encoder_model.config.vocab_size\n            if current_vocab != config.vocab_size:\n                self.encoder_model.resize_token_embeddings(config.vocab_size)\n\n        adapter_config_file = Path(config.encoder_model_name) / \"adapter_config.json\"\n\n        if adapter_config_file.exists():\n            if not IS_PEFT:\n                warnings.warn(\n                    \"Adapter configs were detected, if you want to apply them you need to install peft package.\",\n                    stacklevel=2,\n                )\n            else:\n                adapter_config = LoraConfig.from_pretrained(config.encoder_model_name)\n                self.encoder_model = get_peft_model(self.encoder_model, adapter_config)\n\n        if config.use_segment_embeddings:\n            self.segment_embeddings = nn.Embedding(3, config.encoder_config.hidden_size)\n            nn.init.normal_(self.segment_embeddings.weight, mean=0.0, std=config.initializer_range)\n\n    def _create_segment_ids(self, input_ids):\n        batch_size, _seq_length = input_ids.shape\n        segment_ids = torch.zeros_like(input_ids)  # Default: segment 0 (labels)\n\n        # Find example token 
positions\n        example_token_mask = input_ids == self.config.example_token_index\n        example_token_indices = example_token_mask.int().argmin(dim=-1)\n        has_example = example_token_mask.any(dim=-1)\n\n        text_token_mask = input_ids == self.config.text_token_index\n        text_token_indices = text_token_mask.int().argmax(dim=-1)\n\n        for batch_idx in range(batch_size):\n            text_start = text_token_indices[batch_idx].item()\n\n            # If examples exist, assign segment 1 to example section\n            if has_example[batch_idx]:\n                example_start = example_token_indices[batch_idx].item()\n                segment_ids[batch_idx, text_start:example_start] = 1\n                segment_ids[batch_idx, example_start:] = 2\n            else:\n                segment_ids[batch_idx, text_start:] = 1\n\n        return segment_ids\n\n    def process_encoder_output(self, input_ids, attention_mask, encoder_layer, labels=None, max_num_classes=None):\n        classes_embedding, classes_embedding_mask, text_token_embeddings, text_mask = self._extract_class_features(\n            encoder_layer, input_ids, attention_mask, max_num_classes\n        )\n        if self.config.use_lstm:\n            text_token_embeddings = self.lstm(text_token_embeddings, text_mask)\n\n        pooled_output = self.pooler(text_token_embeddings)\n        pooled_output = self.text_projector(pooled_output)\n        pooled_output = self.dropout(pooled_output)\n        if self.config.normalize_features:\n            pooled_output = pooled_output / (pooled_output.norm(p=2, dim=-1, keepdim=True) + self.epsilon)\n\n        classes_embedding = self.classes_projector(classes_embedding)\n        if self.config.normalize_features:\n            classes_embedding = classes_embedding / (classes_embedding.norm(p=2, dim=-1, keepdim=True) + self.epsilon)\n\n        logits = self.scorer(pooled_output, classes_embedding, text_mask=text_mask)\n\n        if 
self.config.normalize_features:\n            logits = logits * self.logit_scale.to(classes_embedding.device)\n\n        loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)\n        return (logits, loss, pooled_output, classes_embedding)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        output_attentions: bool | None = None,\n        output_hidden_states: bool | None = None,\n        output_text_embeddings: bool | None = None,\n        output_class_embeddings: bool | None = None,\n        return_dict: bool | None = None,\n        max_num_classes: int | None = None,\n        **kwargs,\n    ) -> Tuple | GLiClassOutput:\n        r\"\"\"\n        Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if self.config.squeeze_layers or self.config.layer_wise:\n            output_hidden_states = True\n            return_dict = True\n\n        if self.config.use_segment_embeddings:\n            embedding_layer = self.encoder_model.get_input_embeddings()\n            token_embeds = embedding_layer(input_ids)\n\n            segment_ids = self._create_segment_ids(input_ids)\n            segment_embeds = self.segment_embeddings(segment_ids)\n\n            inputs_embeds = token_embeds + segment_embeds\n\n            outputs = self.encoder_model(\n                inputs_embeds=inputs_embeds,\n                attention_mask=attention_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                **kwargs,\n            )\n        else:\n            outputs = self.encoder_model(\n                input_ids,\n                attention_mask=attention_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                **kwargs,\n            )\n\n        if self.config.layer_wise and labels is not None:\n            hidden_states = outputs.hidden_states\n            loss = 0\n            for encoder_layer in hidden_states:\n                logits, layer_loss, pooled_output, classes_embedding = self.process_encoder_output(\n                    input_ids, attention_mask, encoder_layer, labels, max_num_classes\n                )\n                loss += layer_loss\n        else:\n            if self.config.encoder_layer_id == -1:\n                if 
self.config.squeeze_layers:\n                    encoder_layer = self.layer_wise_attention(outputs.hidden_states)\n                else:\n                    encoder_layer = outputs[0]\n            else:\n                encoder_layer = outputs.hidden_states[self.config.encoder_layer_id]\n            logits, loss, pooled_output, classes_embedding = self.process_encoder_output(\n                input_ids, attention_mask, encoder_layer, labels, max_num_classes\n            )\n\n        if not return_dict:\n            output = (logits, *outputs[1:])\n            return ((loss, *output)) if loss is not None else output\n\n        return GLiClassOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            text_embeddings=pooled_output if output_text_embeddings else None,\n            class_embeddings=classes_embedding if output_class_embeddings else None,\n        )\n\n\nclass GLiClassEncoderDecoder(GLiClassBaseModel):\n    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):\n        super().__init__(config)\n        if config.encoder_config is None:\n            if config.encoder_model_name is None:\n                raise ValueError(\"You need to specify encoder model name to use it as a backbone.\")\n            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)\n\n        if not config.encoder_config.is_encoder_decoder:\n            raise ValueError(\"You need to choose encoder-decoder model as a backbone.\")\n\n        if from_pretrained:\n            self.encoder_decoder_model = AutoModel.from_pretrained(config.encoder_model_name)\n        else:\n            self.encoder_decoder_model = AutoModel.from_config(config.encoder_config)\n\n    @staticmethod\n    def _make_bidirectional_4d_mask(attention_mask_2d, dtype):\n        \"\"\"Convert a 2D padding mask into a 4D bidirectional attention mask.\n\n        When a 4D 
mask is passed to the decoder, the model uses it as-is\n        without applying its default causal pattern, enabling bidirectional\n        self-attention in the decoder.\n\n        Args:\n            attention_mask_2d: (batch_size, seq_length) with 1 for real tokens, 0 for padding.\n            dtype: The dtype of the model (needed for the min-value fill).\n\n        Returns:\n            4D mask of shape (batch_size, 1, seq_length, seq_length).\n            Values are 0.0 for attended positions and a large negative value for masked positions.\n        \"\"\"\n        batch_size, seq_length = attention_mask_2d.shape\n        # (batch_size, 1, 1, seq_length) - masks out padding columns\n        padding_mask = (1.0 - attention_mask_2d.to(dtype))[:, None, None, :] * torch.finfo(dtype).min\n        return padding_mask.expand(batch_size, 1, seq_length, seq_length)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        class_input_ids: torch.Tensor | None = None,\n        class_attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        output_attentions: bool | None = None,\n        output_hidden_states: bool | None = None,\n        output_text_embeddings: bool | None = None,\n        output_class_embeddings: bool | None = None,\n        return_dict: bool | None = True,\n        **kwargs,\n    ) -> Tuple | SequenceClassifierOutput:\n        r\"\"\"\n        Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # Build a 4D bidirectional mask for the decoder so it attends to\n        # all non-padding positions instead of using causal masking.\n        decoder_4d_mask = None\n        if class_attention_mask is not None:\n            model_dtype = next(self.encoder_decoder_model.parameters()).dtype\n            decoder_4d_mask = self._make_bidirectional_4d_mask(class_attention_mask, model_dtype)\n\n        outputs = self.encoder_decoder_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            decoder_input_ids=class_input_ids,\n            decoder_attention_mask=decoder_4d_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            **kwargs,\n        )\n        text_token_embeddings = outputs.encoder_last_hidden_state\n        decoder_token_embeddings = outputs.last_hidden_state\n        classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(\n            decoder_token_embeddings, class_input_ids, class_attention_mask\n        )\n\n        if self.config.use_lstm:\n            text_token_embeddings = self.lstm(text_token_embeddings, attention_mask)\n\n        pooled_output = self.pooler(text_token_embeddings)\n        pooled_output = self.text_projector(pooled_output)\n        pooled_output = self.dropout(pooled_output)\n        if self.config.normalize_features:\n            pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)\n\n        classes_embedding = self.classes_projector(classes_embedding)\n        if self.config.normalize_features:\n      
      classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)\n\n        logits = self.scorer(pooled_output, classes_embedding)\n\n        if self.config.normalize_features:\n            logits = logits * self.logit_scale.to(classes_embedding.device)\n\n        loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)\n\n        if not return_dict:\n            output = (logits, *outputs[1:])\n            return ((loss, *output)) if loss is not None else output\n\n        return GLiClassOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.decoder_hidden_states,\n            attentions=outputs.decoder_attentions,\n            text_embeddings=pooled_output if output_text_embeddings else None,\n            class_embeddings=classes_embedding if output_class_embeddings else None,\n        )\n\n\nclass GLiClassEncoderDecoderCLS(GLiClassBaseModel):\n    \"\"\"Encoder-decoder architecture where labels go to the encoder and text goes to the decoder.\n\n    Class features are extracted from encoder output using _extract_class_features().\n    Text features are extracted from the last non-padding token of the decoder output.\n    \"\"\"\n\n    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):\n        super().__init__(config)\n        if config.encoder_config is None:\n            if config.encoder_model_name is None:\n                raise ValueError(\"You need to specify encoder model name to use it as a backbone.\")\n            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)\n\n        if not config.encoder_config.is_encoder_decoder:\n            raise ValueError(\"You need to choose encoder-decoder model as a backbone.\")\n\n        if from_pretrained:\n            self.encoder_decoder_model = AutoModel.from_pretrained(config.encoder_model_name)\n        else:\n            self.encoder_decoder_model = 
AutoModel.from_config(config.encoder_config)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        class_input_ids: torch.Tensor | None = None,\n        class_attention_mask: torch.Tensor | None = None,\n        inputs_embeds: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        output_attentions: bool | None = None,\n        output_hidden_states: bool | None = None,\n        output_text_embeddings: bool | None = None,\n        output_class_embeddings: bool | None = None,\n        return_dict: bool | None = True,\n        **kwargs,\n    ) -> Tuple | SequenceClassifierOutput:\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # Labels → encoder, Text → decoder\n        outputs = self.encoder_decoder_model(\n            input_ids=class_input_ids,\n            attention_mask=class_attention_mask,\n            decoder_input_ids=input_ids,\n            decoder_attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            **kwargs,\n        )\n\n        # Class features from encoder output\n        encoder_token_embeddings = outputs.encoder_last_hidden_state\n        classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(\n            encoder_token_embeddings, class_input_ids, class_attention_mask\n        )\n\n        # Text features from decoder's last non-padding token\n        decoder_output = outputs.last_hidden_state\n        batch_size = decoder_output.shape[0]\n        last_non_pad_idx = attention_mask.sum(dim=1) - 1\n        pooled_output = decoder_output[torch.arange(batch_size, device=decoder_output.device), last_non_pad_idx]\n\n        pooled_output = self.text_projector(pooled_output)\n        pooled_output = self.dropout(pooled_output)\n  
      if self.config.normalize_features:\n            pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)\n\n        classes_embedding = self.classes_projector(classes_embedding)\n        if self.config.normalize_features:\n            classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)\n\n        logits = self.scorer(pooled_output, classes_embedding)\n\n        if self.config.normalize_features:\n            logits = logits * self.logit_scale.to(classes_embedding.device)\n\n        loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)\n\n        if not return_dict:\n            output = (logits, *outputs[1:])\n            return ((loss, *output)) if loss is not None else output\n\n        return GLiClassOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.decoder_hidden_states,\n            attentions=outputs.decoder_attentions,\n            text_embeddings=pooled_output if output_text_embeddings else None,\n            class_embeddings=classes_embedding if output_class_embeddings else None,\n        )\n\n\nclass GLiClassBiEncoder(GLiClassBaseModel):\n    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):\n        super().__init__(config)\n        if config.encoder_config is None:\n            if config.encoder_model_name is None:\n                raise ValueError(\"You need to specify encoder model name to use it as a backbone.\")\n            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)\n\n        if config.label_model_config is None:\n            if config.label_model_name is None:\n                raise ValueError(\"You need to specify label model name to use it as a backbone.\")\n            config.label_model_config = AutoConfig.from_pretrained(config.label_model_name)\n\n        def initialize_encoder(configs, model_name, from_pretrained):\n            if 
from_pretrained:\n                return AutoModel.from_pretrained(model_name)\n            else:\n                return AutoModel.from_config(configs)\n\n        self.encoder_model = initialize_encoder(config.encoder_config, config.encoder_model_name, from_pretrained)\n        self.label_encoder = initialize_encoder(config.label_model_config, config.label_model_name, from_pretrained)\n        self.biencoder_projector = BiEncoderProjector(config)\n\n    def pool_outputs(self, encoder_outputs):\n        text_embeddings = self.pooler(encoder_outputs[0])\n        text_embeddings = self.text_projector(text_embeddings)\n        text_embeddings = self.dropout(text_embeddings)\n        if self.config.normalize_features:\n            text_embeddings = nn.functional.normalize(text_embeddings, p=2, dim=-1, eps=self.epsilon)\n        return text_embeddings\n\n    def encode_text(self, input_ids, attention_mask):\n        outputs = self.encoder_model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))\n        text_embeddings = self.pool_outputs(outputs)\n        return text_embeddings\n\n    def encode_classes(self, class_input_ids, class_attention_mask, labels_mask=None):\n        batch_size = class_input_ids.shape[0]\n        num_classes = class_input_ids.shape[1]\n        if labels_mask is not None:\n            batch_indices, indices = torch.where(labels_mask == 1)\n            selected_input_ids = class_input_ids[batch_indices, indices]\n            selected_attention_mask = class_attention_mask[batch_indices, indices]\n\n            outputs = self.label_encoder(selected_input_ids, attention_mask=selected_attention_mask)\n            class_embeddings_filtered = self.pooler(outputs[0])\n\n            class_embeddings = torch.zeros(\n                batch_size,\n                num_classes,\n                class_embeddings_filtered.shape[-1],\n                dtype=class_embeddings_filtered.dtype,\n                device=class_embeddings_filtered.device,\n    
        )\n\n            class_embeddings[batch_indices, indices] = class_embeddings_filtered\n        else:\n            class_input_ids = class_input_ids.view(-1, class_input_ids.shape[-1])\n            class_attention_mask = class_attention_mask.view(-1, class_input_ids.shape[-1])\n            outputs = self.label_encoder(class_input_ids, attention_mask=class_attention_mask)\n            class_embeddings = self.pooler(outputs[0])\n            class_embeddings = class_embeddings.reshape(batch_size, num_classes, -1)\n        class_embeddings = self.biencoder_projector(class_embeddings)\n        class_embeddings = self.classes_projector(class_embeddings)\n        if self.config.normalize_features:\n            class_embeddings = nn.functional.normalize(class_embeddings, p=2, dim=-1, eps=self.epsilon)\n        return class_embeddings\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        class_input_ids: torch.Tensor | None = None,\n        class_attention_mask: torch.Tensor | None = None,\n        labels_mask: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        output_text_embeddings: bool | None = None,\n        output_class_embeddings: bool | None = None,\n        return_dict: bool | None = None,\n        **kwargs,\n    ) -> Tuple | SequenceClassifierOutput:\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        text_embeddings = self.encode_text(input_ids, attention_mask)\n        class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)\n        logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)\n\n        if labels_mask is not None:\n            logits = torch.where(labels_mask == 0, -1e3, logits)\n\n        loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)\n\n        if not 
return_dict:\n            output = (logits,)\n            return ((loss, *output)) if loss is not None else output\n\n        return GLiClassOutput(\n            loss=loss,\n            logits=logits,\n            text_embeddings=text_embeddings if output_text_embeddings else None,\n            class_embeddings=class_embeddings if output_class_embeddings else None,\n        )\n\n\nclass GLiClassBiEncoderFused(GLiClassBiEncoder):\n    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):\n        super().__init__(config, from_pretrained)\n\n    def encode_text(self, input_ids, attention_mask, class_embeddings, labels_mask):\n        embedding_layer = self.encoder_model.get_input_embeddings()\n        inputs_embeds = embedding_layer(input_ids)\n\n        class_token_mask = input_ids == self.config.class_token_index\n        batch_indices, class_token_indices = torch.where(class_token_mask)\n\n        labels_batch_indices, labels_indices = torch.where(labels_mask == 1)\n\n        selected_class_embeddings = class_embeddings[labels_batch_indices, labels_indices]\n\n        inputs_embeds[batch_indices, class_token_indices] = selected_class_embeddings\n        encoder_outputs = self.encoder_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask.squeeze(1))\n\n        post_class_embeddings = torch.zeros_like(class_embeddings)\n        post_class_embeddings[labels_batch_indices, labels_indices] = encoder_outputs[0][\n            batch_indices, class_token_indices\n        ]\n        return encoder_outputs, post_class_embeddings\n\n    def forward(\n        self,\n        input_ids: torch.Tensor | None = None,\n        attention_mask: torch.Tensor | None = None,\n        class_input_ids: torch.Tensor | None = None,\n        class_attention_mask: torch.Tensor | None = None,\n        labels_mask: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        output_text_embeddings: bool | None = None,\n        
output_class_embeddings: bool | None = None,\n        return_dict: bool | None = None,\n        **kwargs,\n    ) -> Tuple | SequenceClassifierOutput:\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        raw_class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)\n\n        encoder_outputs, class_embeddings = self.encode_text(\n            input_ids, attention_mask, raw_class_embeddings, labels_mask\n        )\n\n        text_embeddings = self.pool_outputs(encoder_outputs)\n\n        logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)\n\n        if labels_mask is not None:\n            logits = torch.where(labels_mask == 0, -1e3, logits)\n\n        loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)\n\n        if not return_dict:\n            output = (logits,)\n            return ((loss, *output)) if loss is not None else output\n\n        return GLiClassOutput(\n            loss=loss,\n            logits=logits,\n            text_embeddings=text_embeddings if output_text_embeddings else None,\n            class_embeddings=class_embeddings if output_class_embeddings else None,\n        )\n\n\nclass GLiClassModel(GLiClassPreTrainedModel):\n    def __init__(self, config, from_pretrained=False):\n        super().__init__(config)\n        if config.architecture_type == \"uni-encoder\":\n            self.model = GLiClassUniEncoder(config, from_pretrained)\n        elif config.architecture_type == \"bi-encoder\":\n            self.model = GLiClassBiEncoder(config, from_pretrained)\n        elif config.architecture_type == \"bi-encoder-fused\":\n            self.model = GLiClassBiEncoderFused(config, from_pretrained)\n        elif config.architecture_type == \"encoder-decoder\":\n            self.model = GLiClassEncoderDecoder(config, from_pretrained)\n        elif config.architecture_type == 
\"encoder-decoder-cls\":\n            self.model = GLiClassEncoderDecoderCLS(config, from_pretrained)\n        self.post_init()\n\n    def get_input_embeddings(self):\n        if self.config.architecture_type in {\"uni-encoder\"}:\n            return self.model.encoder_model.get_input_embeddings()\n        elif self.config.architecture_type in {\"encoder-decoder\", \"encoder-decoder-cls\"}:\n            return self.model.encoder_decoder_model.get_input_embeddings()\n        else:\n            raise NotImplementedError(\"Getting input embeddings is not implemented for bi-encoder architecture\")\n\n    def set_input_embeddings(self, value):\n        if self.config.architecture_type in {\"uni-encoder\"}:\n            self.model.encoder_model.set_input_embeddings(value)\n            return None\n        elif self.config.architecture_type in {\"encoder-decoder\", \"encoder-decoder-cls\"}:\n            self.model.encoder_decoder_model.set_input_embeddings(value)\n        elif self.config.architecture_type in {\"bi-encoder\", \"bi-encoder-fused\"}:\n            self.model.encoder_model.set_input_embeddings(value)\n        else:\n            raise NotImplementedError(\"Setting input embeddings is not implemented for bi-encoder architecture\")\n\n    def tie_weights(self, recompute_mapping=True, missing_keys=None):\n        \"\"\"\n        Tie model weights for architectures that share parameters.\n\n        This method handles:\n        - Version compatibility between transformers v4 and v5\n        - Different GLiClass architecture types\n        - Special handling for T5/MT5 models in transformers v5+ where encoder.embed_tokens\n          may be incorrectly initialized instead of being tied to shared.weight\n\n        Args:\n            recompute_mapping: Whether to recompute weight mapping (transformers v5+)\n            missing_keys: Keys that are missing from checkpoint (transformers v5+)\n        \"\"\"\n        # Get encoder model based on architecture type\n        
encoder_model = None\n        if self.config.architecture_type in {\"uni-encoder\"}:\n            encoder_model = self.model.encoder_model\n        elif self.config.architecture_type in {\"encoder-decoder\", \"encoder-decoder-cls\"}:\n            encoder_model = self.model.encoder_decoder_model\n        elif self.config.architecture_type in {\"bi-encoder\", \"bi-encoder-fused\"}:\n            encoder_model = self.model.encoder_model\n        else:\n            raise NotImplementedError(\"Tie weights is not implemented for this architecture type\")\n\n        # Call base tie_weights with version-appropriate parameters\n        if version.parse(transformers.__version__) >= version.parse(\"5.0.0\"):\n            result = encoder_model.tie_weights(recompute_mapping=recompute_mapping, missing_keys=missing_keys)\n        else:\n            result = encoder_model.tie_weights()\n\n        # Fix for T5/MT5/UMT5 models in transformers v5+\n        # In v5, if encoder.embed_tokens.weight is missing from checkpoint, it gets randomly\n        # initialized instead of being tied to shared.weight. 
We explicitly ensure proper tying.\n        if (\n            encoder_model is not None\n            and hasattr(encoder_model, \"shared\")\n            and hasattr(encoder_model, \"encoder\")\n            and hasattr(encoder_model.encoder, \"embed_tokens\")\n        ):\n            shared_weight = encoder_model.shared.weight\n            embed_weight = encoder_model.encoder.embed_tokens.weight\n\n            # Only tie if they're not already the same tensor\n            if shared_weight is not embed_weight:\n                encoder_model.encoder.embed_tokens.weight = shared_weight\n                if version.parse(transformers.__version__) >= version.parse(\"5.0.0\"):\n                    logger.info(\n                        \"Applied transformers v5 compatibility fix: tied encoder.embed_tokens.weight \"\n                        \"to shared.weight for T5-based model\"\n                    )\n\n        return result\n\n    def resize_token_embeddings(self, new_num_tokens: int | None = None, pad_to_multiple_of=None) -> nn.Embedding:\n        if self.config.architecture_type in {\"uni-encoder\"}:\n            model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)\n        elif self.config.architecture_type in {\"encoder-decoder\", \"encoder-decoder-cls\"}:\n            model_embeds = self.model.encoder_decoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)\n        elif self.config.architecture_type in {\"bi-encoder-fused\"}:\n            model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)\n        else:\n            raise NotImplementedError(\"Resizing is not implemented for bi-encoder architecture\")\n        self.config.encoder_config.vocab_size = model_embeds.num_embeddings\n        self.config.vocab_size = model_embeds.num_embeddings\n        self.vocab_size = model_embeds.num_embeddings\n        return model_embeds\n\n    def forward(self, *args, 
**kwargs):\n        outputs = self.model(*args, **kwargs)\n        return outputs\n"
  },
  {
    "path": "gliclass/ops.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n# ─── Attention (padded) ───────────────────────────────────────────────────────\n\n\ndef attn_padded(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    key_padding_mask: torch.Tensor | None = None,\n    dropout_p: float = 0.0,\n) -> torch.Tensor:\n    \"\"\"\n    Padded attention via F.scaled_dot_product_attention.\n    Uses FlashAttention backend automatically on CUDA when available.\n\n    Args:\n        q:                [batch, nq, nheads, head_dim]\n        k:                [batch, nk, nheads, head_dim]\n        v:                [batch, nk, nheads, head_dim]\n        key_padding_mask: [batch, nk] bool, True = real token\n    Returns:\n        [batch, nq, nheads, head_dim]\n    \"\"\"\n    q = q.transpose(1, 2)\n    k = k.transpose(1, 2)\n    v = v.transpose(1, 2)\n\n    attn_mask = None\n    if key_padding_mask is not None:\n        attn_mask = key_padding_mask[:, None, None, :].bool()\n\n    out = F.scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=attn_mask,\n        dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,\n    )\n    return out.transpose(1, 2)  # [batch, nq, nheads, head_dim]\n"
  },
  {
    "path": "gliclass/pipeline.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, List\n\nimport torch\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer\n\nfrom .model import GLiClassModel, GLiClassBiEncoder\nfrom .utils import retrieval_augmented_text\n\n\ndef flatten_hierarchical_labels(\n    labels: List[str] | Dict[str, Any], prefix: str = \"\", separator: str = \".\"\n) -> List[str]:\n    \"\"\"\n    Flatten hierarchical labels into dot notation.\n\n    Supports arbitrary nesting depth. Examples:\n\n    Input: {\"sentiment\": [\"positive\", \"negative\", \"neutral\"], \"topic\": [\"product\", \"service\", \"shipping\"]}\n    Output: [\"sentiment.positive\", \"sentiment.negative\", \"sentiment.neutral\",\n             \"topic.product\", \"topic.service\", \"topic.shipping\"]\n\n    Input: {\n        \"category\": {\n            \"electronics\": [\"phone\", \"laptop\"],\n            \"clothing\": [\"shirt\", \"pants\"]\n        }\n    }\n    Output: [\n        \"category.electronics.phone\",\n        \"category.electronics.laptop\",\n        \"category.clothing.shirt\",\n        \"category.clothing.pants\"\n    ]\n\n    Input: [\"label1\", \"label2\"]  # Already flat\n    Output: [\"label1\", \"label2\"]\n\n    Args:\n        labels: Either a list of string labels or a hierarchical dict\n        prefix: Current prefix for recursion (internal use)\n        separator: Separator to use between hierarchy levels (default: \".\")\n\n    Returns:\n        List of flattened label strings with dot notation\n    \"\"\"\n    if isinstance(labels, list):\n        if prefix:\n            return [f\"{prefix}{separator}{label}\" for label in labels]\n        return labels\n\n    elif isinstance(labels, dict):\n        flattened = []\n        for key, value in labels.items():\n            new_prefix = f\"{prefix}{separator}{key}\" if prefix else key\n            flattened.extend(flatten_hierarchical_labels(value, new_prefix, separator))\n        return 
flattened\n\n    elif isinstance(labels, str):\n        if prefix:\n            return [f\"{prefix}{separator}{labels}\"]\n        return [labels]\n\n    else:\n        raise ValueError(f\"Unsupported label type: {type(labels)}. Expected list, dict, or str.\")\n\n\ndef build_hierarchical_output(\n    predictions: List[Dict[str, float]],\n    original_labels: List[str] | Dict[str, Any],\n    separator: str = \".\",\n    all_scores: Dict[str, float] | None = None,\n) -> Dict[str, float] | Dict[str, Any]:\n    \"\"\"\n    Build hierarchical output structure matching the input labels structure.\n\n    Args:\n        predictions: List of prediction dicts with 'label' and 'score'\n        original_labels: Original hierarchical labels structure\n        separator: Separator used in flattened labels\n        all_scores: Optional dict of all label scores (for complete output)\n\n    Returns:\n        Hierarchical structure with scores matching the input format\n\n    Example:\n        Input predictions: [\n            {'label': 'sentiment.positive', 'score': 0.85},\n            {'label': 'topic.product', 'score': 0.72}\n        ]\n        Input original_labels: {\n            \"sentiment\": [\"positive\", \"negative\", \"neutral\"],\n            \"topic\": [\"product\", \"service\", \"shipping\"]\n        }\n        Output: {\n            \"sentiment\": {\"positive\": 0.85, \"negative\": 0.0, \"neutral\": 0.0},\n            \"topic\": {\"product\": 0.72, \"service\": 0.0, \"shipping\": 0.0}\n        }\n    \"\"\"\n    score_lookup = {pred[\"label\"]: pred[\"score\"] for pred in predictions}\n\n    if all_scores:\n        for k, v in all_scores.items():\n            if k not in score_lookup:\n                score_lookup[k] = v\n\n    def _build_recursive(structure: List[str] | Dict[str, Any], prefix: str = \"\") -> Dict[str, float] | Dict[str, Any]:\n        if isinstance(structure, list):\n            result = {}\n            for label in structure:\n                
full_label = f\"{prefix}{separator}{label}\" if prefix else label\n                result[label] = score_lookup.get(full_label, 0.0)\n            return result\n\n        elif isinstance(structure, dict):\n            result = {}\n            for key, value in structure.items():\n                new_prefix = f\"{prefix}{separator}{key}\" if prefix else key\n                result[key] = _build_recursive(value, new_prefix)\n            return result\n\n        elif isinstance(structure, str):\n            full_label = f\"{prefix}{separator}{structure}\" if prefix else structure\n            return {structure: score_lookup.get(full_label, 0.0)}\n\n        return {}\n\n    if isinstance(original_labels, list):\n        return {label: score_lookup.get(label, 0.0) for label in original_labels}\n\n    return _build_recursive(original_labels)\n\n\ndef format_examples_prompt(\n    examples: List[Dict[str, Any]], example_token: str = \"<<EXAMPLE>>\", sep_token: str = \"<<SEP>>\"\n) -> str:\n    r\"\"\"\n    Format few-shot examples into a prompt string using <<EXAMPLE>> token.\n\n    Format matches training: <<EXAMPLE>>text \\nLabels:\\n label1, label2\n    with a single <<SEP>> after all examples.\n\n    Args:\n        examples: List of example dicts with 'text' and 'labels'/'true_labels' keys\n        example_token: Token to mark examples (default: \"<<EXAMPLE>>\")\n        sep_token: Separator token after all examples (default: \"<<SEP>>\")\n\n    Returns:\n        Formatted examples string\n    \"\"\"\n    if not examples:\n        return \"\"\n\n    formatted_parts = []\n    for example in examples:\n        text = example.get(\"text\", \"\")\n        labels = example.get(\"labels\", example.get(\"true_labels\", []))\n\n        if isinstance(labels, list):\n            labels_str = \", \".join(labels)\n        else:\n            labels_str = str(labels)\n\n        # Match training format: \" \\nLabels:\\n \" instead of \"\\nLabels: \"\n        
formatted_parts.append(f\"{example_token}{text} \\nLabels:\\n {labels_str}\")\n\n    # Add single SEP token after all examples (matching training)\n    formatted_parts.append(sep_token)\n\n    return \"\".join(formatted_parts)\n\n\nclass BaseZeroShotClassificationPipeline(ABC):\n    def __init__(\n        self,\n        model,\n        tokenizer,\n        max_classes=25,\n        max_length=1024,\n        classification_type=\"multi-label\",\n        device=\"cuda:0\",\n        progress_bar=True,\n        label_separator: str = \".\",\n    ):\n        self.model = model\n        if isinstance(tokenizer, str):\n            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)\n        else:\n            self.tokenizer = tokenizer\n        self.max_classes = max_classes\n        self.classification_type = classification_type\n        self.max_length = max_length\n        self.progress_bar = progress_bar\n        self.label_separator = label_separator\n        self._max_labels_alloc = getattr(model.config, \"max_labels_alloc\", \"dynamic\")\n\n        self.example_token = \"<<EXAMPLE>>\"\n        self.label_token = \"<<LABEL>>\"\n        self.sep_token = \"<<SEP>>\"\n\n        if not isinstance(device, torch.device):\n            if torch.cuda.is_available() and \"cuda\" in device:\n                self.device = torch.device(device)\n            else:\n                self.device = torch.device(\"cpu\")\n        else:\n            self.device = device\n\n        if self.model.device != self.device:\n            self.model.to(self.device)\n\n        # Ensure model is in evaluation mode for inference\n        self.model.eval()\n\n    def _normalize_classification_type(self, classification_type: str | None) -> str:\n        if classification_type is None:\n            return self.classification_type\n\n        normalized = classification_type.strip().lower()\n        if normalized in {\"single\", \"single-label\", \"single_label\"}:\n            return 
\"single-label\"\n        if normalized in {\"multi\", \"multi-label\", \"multi_label\"}:\n            return \"multi-label\"\n        raise ValueError(\"Unsupported classification type: choose 'single-label' or 'multi-label'\")\n\n    def _normalize_texts(self, texts: str | List[str]) -> List[str]:\n        if isinstance(texts, str):\n            return [texts]\n        return texts\n\n    def _normalize_thresholds(self, threshold: float | List[float], num_texts: int) -> List[float]:\n        if isinstance(threshold, list):\n            if len(threshold) != num_texts:\n                raise ValueError(\"Length of threshold list must match number of texts.\")\n            return threshold\n        return [threshold] * num_texts\n\n    def _normalize_classification_types(\n        self,\n        classification_type: str | List[str] | None,\n        num_texts: int,\n    ) -> List[str]:\n        if isinstance(classification_type, list):\n            if len(classification_type) != num_texts:\n                raise ValueError(\"Length of classification_type list must match number of texts.\")\n            return [self._normalize_classification_type(item) for item in classification_type]\n\n        normalized = self._normalize_classification_type(classification_type)\n        return [normalized] * num_texts\n\n    def _process_labels(\n        self, labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]]\n    ) -> List[str] | List[List[str]]:\n        \"\"\"Process labels to handle hierarchical structures.\"\"\"\n        if not labels:\n            return labels\n\n        if isinstance(labels, dict):\n            return flatten_hierarchical_labels(labels, separator=self.label_separator)\n\n        if isinstance(labels, list):\n            if len(labels) == 0:\n                return labels\n\n            first_elem = labels[0]\n\n            if isinstance(first_elem, str):\n                return labels\n\n            if isinstance(first_elem, 
dict):\n                return [flatten_hierarchical_labels(lbl, separator=self.label_separator) for lbl in labels]\n\n            if isinstance(first_elem, list):\n                if first_elem and isinstance(first_elem[0], dict):\n                    return [flatten_hierarchical_labels(lbl, separator=self.label_separator) for lbl in labels]\n                return labels\n\n        return labels\n\n    def _format_examples_for_input(self, examples: List[Dict[str, Any]] | None = None) -> str:\n        \"\"\"Format few-shot examples using <<EXAMPLE>> and <<SEP>> tokens.\"\"\"\n        if not examples:\n            return \"\"\n        examples = [example for example in examples if example is not None]\n        if not examples:\n            return \"\"\n        return format_examples_prompt(examples, example_token=self.example_token, sep_token=self.sep_token)\n\n    def _examples_are_per_text(self, examples) -> bool:\n        \"\"\"Detect whether examples are provided per text rather than shared.\"\"\"\n        if not isinstance(examples, list) or len(examples) == 0:\n            return False\n        if all(isinstance(example, dict) for example in examples):\n            return False\n        return all(example is None or isinstance(example, list) for example in examples)\n\n    def _get_text_examples(self, examples, index: int):\n        \"\"\"Get examples for a single text from shared or per-text input.\"\"\"\n        if not examples:\n            return None\n        if self._examples_are_per_text(examples):\n            return examples[index] if index < len(examples) else None\n        return examples\n\n    def _format_prompt(self, prompt: str | List[str] | None = None, index: int = 0) -> str:\n        \"\"\"Format the task description prompt.\"\"\"\n        if prompt is None:\n            return \"\"\n\n        if isinstance(prompt, str):\n            return prompt\n\n        if isinstance(prompt, list):\n            if index < len(prompt):\n                
return prompt[index]\n            return prompt[0] if prompt else \"\"\n\n        return \"\"\n\n    def _resolve_max_num_classes(self, batch_labels, same_labels: bool):\n        if self._max_labels_alloc == \"dynamic\":\n            return len(batch_labels) if same_labels else max(len(labels) for labels in batch_labels)\n        if isinstance(self._max_labels_alloc, int):\n            return self._max_labels_alloc\n        return None  # 'fixed': model uses config.max_num_classes\n\n    @abstractmethod\n    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):\n        pass\n\n    def _get_batch_examples(self, examples, start_idx, batch_size):\n        \"\"\"Get examples for current batch.\"\"\"\n        if not examples:\n            return None\n        if self._examples_are_per_text(examples):\n            return examples[start_idx : start_idx + batch_size]\n        return examples\n\n    def _get_batch_prompt(self, prompt, start_idx, batch_size):\n        \"\"\"Get prompt for current batch.\"\"\"\n        if not prompt:\n            return None\n        if isinstance(prompt, list):\n            return prompt[start_idx : start_idx + batch_size]\n        return prompt\n\n    @torch.no_grad()\n    def get_embeddings(self, texts, labels, batch_size=8, examples=None, prompt=None):\n        if isinstance(texts, str):\n            texts = [texts]\n\n        labels = self._process_labels(labels)\n\n        if isinstance(labels[0], str):\n            same_labels = True\n        else:\n            same_labels = False\n\n        results = []\n\n        iterable = range(0, len(texts), batch_size)\n        if self.progress_bar:\n            iterable = tqdm(iterable)\n\n        for idx in iterable:\n            batch_texts = texts[idx : idx + batch_size]\n            batch_examples = self._get_batch_examples(examples, idx, len(batch_texts))\n            batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))\n\n            
tokenized_inputs = self.prepare_inputs(\n                batch_texts, labels, same_labels, examples=batch_examples, prompt=batch_prompt\n            )\n            max_num_classes = self._resolve_max_num_classes(labels, same_labels)\n            model_output = self.model(\n                **tokenized_inputs,\n                max_num_classes=max_num_classes,\n                output_text_embeddings=True,\n                output_class_embeddings=True,\n            )\n            logits = model_output.logits\n            text_embeddings = model_output.text_embeddings\n            class_embeddings = model_output.class_embeddings\n            batch_size_actual = logits.shape[0]\n\n            for i in range(batch_size_actual):\n                result = {\n                    \"logits\": logits[i].cpu().numpy(),\n                    \"text_embedding\": text_embeddings[i].cpu().numpy(),\n                    \"class_embeddings\": class_embeddings[i].cpu().numpy(),\n                }\n                results.append(result)\n\n        return results\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        texts: str | List[str],\n        labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]],\n        threshold: float | List[float] = 0.5,\n        batch_size: int = 8,\n        classification_type: str | List[str] | None = None,\n        rac_examples: List | None = None,\n        examples: List[Dict[str, Any]] | None = None,\n        prompt: str | List[str] | None = None,\n        return_hierarchical: bool = False,\n    ):\n        \"\"\"\n        Perform zero-shot classification.\n\n        Args:\n            texts: Single text or list of texts to classify\n            labels: Labels in various formats (flat list or hierarchical dict)\n            threshold: Classification threshold for multi-label, either one\n                value for all texts or one value per text\n            batch_size: Batch size for processing\n            
classification_type: Override classification mode globally or per text.\n                If None, uses the pipeline's configured classification_type\n            rac_examples: Retrieval augmented examples (legacy)\n            examples: Few-shot examples with 'text' and 'labels'/'true_labels' keys\n            prompt: Task description - string (same for all) or list (per-text)\n            return_hierarchical: If True, return hierarchical structure with all scores\n\n        Returns:\n            List of classification results or hierarchical dict structure.\n        \"\"\"\n        original_labels = labels\n\n        texts = self._normalize_texts(texts)\n        thresholds = self._normalize_thresholds(threshold, len(texts))\n        classification_types = self._normalize_classification_types(classification_type, len(texts))\n\n        if rac_examples:\n            if len(texts) == 1 and not isinstance(rac_examples[0], list):\n                texts = [retrieval_augmented_text(texts[0], rac_examples)]\n            else:\n                texts = [retrieval_augmented_text(text, ex) for text, ex in zip(texts, rac_examples)]\n\n        processed_labels = self._process_labels(labels)\n\n        if isinstance(processed_labels[0], str):\n            same_labels = True\n        else:\n            same_labels = False\n\n        results = []\n        all_scores_list = []\n\n        iterable = range(0, len(texts), batch_size)\n        if self.progress_bar:\n            iterable = tqdm(iterable)\n\n        for idx in iterable:\n            batch_texts = texts[idx : idx + batch_size]\n            if not same_labels:\n                batch_labels = processed_labels[idx : idx + batch_size]\n            else:\n                batch_labels = processed_labels\n\n            batch_examples = self._get_batch_examples(examples, idx, len(batch_texts))\n            batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))\n\n            tokenized_inputs = self.prepare_inputs(\n 
               batch_texts, batch_labels, same_labels, examples=batch_examples, prompt=batch_prompt\n            )\n            max_num_classes = self._resolve_max_num_classes(batch_labels, same_labels)\n            model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)\n            logits = model_output.logits\n            probs = torch.sigmoid(logits)\n\n            for i in range(len(batch_texts)):\n                global_idx = idx + i\n                item_classification_type = classification_types[global_idx]\n                item_threshold = thresholds[global_idx]\n\n                if same_labels:\n                    curr_labels = batch_labels\n                else:\n                    curr_labels = batch_labels[i]\n\n                if item_classification_type == \"single-label\":\n                    score = torch.softmax(logits[i][: len(curr_labels)], dim=-1)\n\n                    if return_hierarchical:\n                        all_scores = {curr_labels[j]: score[j].item() for j in range(len(curr_labels))}\n                        all_scores_list.append(all_scores)\n\n                    pred_label = curr_labels[torch.argmax(score).item()]\n                    results.append([{\"label\": pred_label, \"score\": score.max().item()}])\n                elif item_classification_type == \"multi-label\":\n                    text_results = []\n\n                    if return_hierarchical:\n                        all_scores = {curr_labels[j]: probs[i][j].item() for j in range(len(curr_labels))}\n                        all_scores_list.append(all_scores)\n\n                    for j, prob in enumerate(probs[i][: len(curr_labels)]):\n                        score = prob.item()\n                        if score >= item_threshold:\n                            text_results.append({\"label\": curr_labels[j], \"score\": score})\n                    results.append(text_results)\n                else:\n                    raise 
ValueError(\"Unsupported classification type: choose 'single-label' or 'multi-label'\")\n\n        if return_hierarchical:\n            hierarchical_results = []\n            for i, (result, all_scores) in enumerate(zip(results, all_scores_list)):\n                if same_labels:\n                    orig_lbl = original_labels\n                else:\n                    orig_lbl = original_labels[i] if i < len(original_labels) else original_labels\n\n                hierarchical_results.append(\n                    build_hierarchical_output(result, orig_lbl, self.label_separator, all_scores)\n                )\n            return hierarchical_results\n\n        return results\n\n\nclass UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):\n    def __init__(\n        self,\n        model,\n        tokenizer,\n        max_classes=25,\n        max_length=1024,\n        classification_type=\"multi-label\",\n        device=\"cuda:0\",\n        progress_bar=True,\n        label_separator: str = \".\",\n    ):\n        super().__init__(\n            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n        )\n\n    def prepare_input(self, text, labels, examples=None, prompt=None):\n        \"\"\"\n        Prepare input matching training format from data_processing.py:\n        Order: Labels → SEP → Prompt → Text → Examples.\n        \"\"\"\n        input_parts = []\n\n        # 1. Add labels\n        for label in labels:\n            label_tag = f\"{self.label_token}{label}\"\n            input_parts.append(label_tag)\n        input_parts.append(self.sep_token)\n\n        # 2. Add task description prompt\n        if prompt:\n            input_parts.append(prompt)\n\n        # 3. 
Format examples to go after text\n        examples_str = \"\"\n        if examples:\n            examples_str = self._format_examples_for_input(examples)\n\n        if self.model.config.prompt_first:\n            return \"\".join(input_parts) + text + examples_str\n        else:\n            return text + \"\".join(input_parts) + examples_str\n\n    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):\n        inputs = []\n\n        if same_labels:\n            for i, text in enumerate(texts):\n                text_examples = self._get_text_examples(examples, i)\n                text_prompt = self._format_prompt(prompt, i)\n                inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))\n        else:\n            for i, (text, labels_) in enumerate(zip(texts, labels)):\n                text_examples = self._get_text_examples(examples, i)\n                text_prompt = self._format_prompt(prompt, i)\n                inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))\n\n        tokenized_inputs = self.tokenizer(\n            inputs, truncation=True, max_length=self.max_length, padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n\n        return tokenized_inputs\n\n\nclass EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):\n    def __init__(\n        self,\n        model,\n        tokenizer,\n        max_classes=25,\n        max_length=1024,\n        classification_type=\"multi-label\",\n        device=\"cuda:0\",\n        progress_bar=True,\n        label_separator: str = \".\",\n    ):\n        super().__init__(\n            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n        )\n\n    def prepare_labels_prompt(self, labels, prompt=None):\n        \"\"\"Match training format: Labels → SEP → Prompt.\"\"\"\n        input_parts = []\n\n        for label in labels:\n          
  # label_tag = f\"{label}{self.label_token}\"\n            label_tag = f\"{self.label_token}{label}\"\n            input_parts.append(label_tag)\n        input_parts.append(self.sep_token)\n\n        if prompt:\n            input_parts.append(prompt)\n\n        return \"\".join(input_parts)\n\n    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):\n        prompts = []\n        processed_texts = []\n\n        if same_labels:\n            for i, text in enumerate(texts):\n                text_examples = self._get_text_examples(examples, i)\n                text_prompt = self._format_prompt(prompt, i)\n                prompts.append(self.prepare_labels_prompt(labels, text_prompt))\n                examples_str = self._format_examples_for_input(text_examples) if text_examples else \"\"\n                processed_texts.append(text + examples_str)\n        else:\n            for i, labels_ in enumerate(labels):\n                text_examples = self._get_text_examples(examples, i)\n                text_prompt = self._format_prompt(prompt, i)\n                prompts.append(self.prepare_labels_prompt(labels_, text_prompt))\n                examples_str = self._format_examples_for_input(text_examples) if text_examples else \"\"\n                processed_texts.append(texts[i] + examples_str)\n\n        tokenized_inputs = self.tokenizer(\n            processed_texts, truncation=True, max_length=self.max_length, padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n\n        tokenized_classes = self.tokenizer(\n            prompts, max_length=self.max_length, truncation=True, padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n\n        tokenized_inputs[\"class_input_ids\"] = tokenized_classes[\"input_ids\"]\n        tokenized_inputs[\"class_attention_mask\"] = tokenized_classes[\"attention_mask\"]\n\n        return tokenized_inputs\n\n\nclass 
BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):\n    def __init__(\n        self,\n        model,\n        tokenizer,\n        max_classes=25,\n        max_length=1024,\n        classification_type=\"multi-label\",\n        device=\"cuda:0\",\n        progress_bar=True,\n        label_separator: str = \".\",\n    ):\n        super().__init__(\n            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n        )\n        self.labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)\n\n    def prepare_input(self, text, labels, examples=None, prompt=None):\n        input_parts = []\n\n        if prompt:\n            input_parts.append(prompt)\n            input_parts.append(\" \")\n\n        for _label in labels:\n            input_parts.append(self.label_token)\n        input_parts.append(self.sep_token)\n\n        examples_str = \"\"\n        if examples:\n            examples_str = self._format_examples_for_input(examples)\n\n        if self.model.config.prompt_first:\n            return \"\".join(input_parts) + text + examples_str\n        else:\n            return text + \"\".join(input_parts) + examples_str\n\n    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):\n        if self.model.config.architecture_type == \"bi-encoder-fused\":\n            inputs = []\n            if same_labels:\n                for i, text in enumerate(texts):\n                    text_examples = self._get_text_examples(examples, i)\n                    text_prompt = self._format_prompt(prompt, i)\n                    inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))\n            else:\n                for i, (text, labels_) in enumerate(zip(texts, labels)):\n                    text_examples = self._get_text_examples(examples, i)\n                    text_prompt = self._format_prompt(prompt, i)\n                    
inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))\n        else:\n            inputs = []\n            for i, text in enumerate(texts):\n                text_prompt = self._format_prompt(prompt, i)\n                if text_prompt:\n                    inputs.append(f\"{text_prompt} {text}\")\n                else:\n                    inputs.append(text)\n\n        if same_labels:\n            tokenized_inputs = self.tokenizer(\n                inputs, truncation=True, max_length=self.max_length, padding=\"longest\", return_tensors=\"pt\"\n            ).to(self.device)\n\n            tokenized_labels = self.labels_tokenizer(\n                labels, truncation=True, max_length=self.max_length, padding=\"longest\", return_tensors=\"pt\"\n            ).to(self.device)\n\n            tokenized_inputs[\"class_input_ids\"] = tokenized_labels[\"input_ids\"].expand(len(texts), -1, -1)\n            tokenized_inputs[\"class_attention_mask\"] = tokenized_labels[\"attention_mask\"].expand(len(texts), -1, -1)\n\n            labels_mask = [[1 for _ in range(len(labels))] for _ in range(len(texts))]\n            tokenized_inputs[\"labels_mask\"] = torch.tensor(labels_mask).to(self.device)\n        else:\n            tokenized_inputs = self.tokenizer(\n                inputs, truncation=True, max_length=self.max_length, padding=\"longest\", return_tensors=\"pt\"\n            ).to(self.device)\n\n            class_input_ids = []\n            class_attention_mask = []\n\n            for labels_set in labels:\n                tokenized_labels = self.labels_tokenizer(\n                    labels_set, truncation=True, max_length=self.max_length, padding=\"max_length\", return_tensors=\"pt\"\n                ).to(self.device)\n                class_input_ids.append(tokenized_labels[\"input_ids\"])\n                class_attention_mask.append(tokenized_labels[\"attention_mask\"])\n\n            tokenized_inputs[\"class_input_ids\"] = 
torch.stack(class_input_ids)\n            tokenized_inputs[\"class_attention_mask\"] = torch.stack(class_attention_mask)\n\n            labels_mask = [[1 for _ in range(len(labels[j]))] for j in range(len(texts))]\n            tokenized_inputs[\"labels_mask\"] = torch.tensor(labels_mask).to(self.device)\n        return tokenized_inputs\n\n\nclass ZeroShotClassificationPipeline:\n    \"\"\"\n    Main pipeline class for zero-shot classification with GLiClass models.\n\n    Supports:\n    - Hierarchical labels with dot notation (e.g., {\"sentiment\": [\"positive\", \"negative\"]})\n    - Few-shot examples with <<EXAMPLE>> token\n    - Task description prompts\n    - Hierarchical output format matching input structure\n\n    Example usage:\n\n    ```python\n    from gliclass import ZeroShotClassificationPipeline\n\n    pipeline = ZeroShotClassificationPipeline(model, tokenizer)\n\n    # === Hierarchical Labels for Review Classification ===\n    hierarchical_labels = {\n        \"sentiment\": [\"positive\", \"negative\", \"neutral\"],\n        \"topic\": [\"product\", \"service\", \"shipping\"],\n    }\n\n    # Basic classification\n    results = pipeline(\"The product quality is amazing but delivery was slow\", hierarchical_labels)\n    # Results: [\n    #     {'label': 'sentiment.positive', 'score': 0.89},\n    #     {'label': 'topic.product', 'score': 0.92},\n    #     {'label': 'topic.shipping', 'score': 0.76}\n    # ]\n\n    # === With Task Description Prompt ===\n    results = pipeline(\n        \"The product quality is amazing but delivery was slow\",\n        hierarchical_labels,\n        prompt=\"Classify this customer review by sentiment and topic:\",\n    )\n\n    # === With Few-Shot Examples (uses <<EXAMPLE>> token) ===\n    examples = [\n        {\"text\": \"Love this item, great quality!\", \"labels\": [\"sentiment.positive\", \"topic.product\"]},\n        {\"text\": \"Customer support was unhelpful and rude\", \"labels\": [\"sentiment.negative\", 
\"topic.service\"]},\n        {\"text\": \"Package arrived damaged after 2 weeks\", \"labels\": [\"sentiment.negative\", \"topic.shipping\"]},\n    ]\n\n    results = pipeline(\n        \"Fast delivery and the item works perfectly!\",\n        hierarchical_labels,\n        examples=examples,\n        prompt=\"Classify customer feedback:\",\n    )\n\n    # === Hierarchical Output (matches input structure) ===\n    results = pipeline(\n        \"The product quality is amazing but delivery was slow\", hierarchical_labels, return_hierarchical=True\n    )\n    # Returns:\n    # {\n    #     \"sentiment\": {\n    #         \"positive\": 0.89,\n    #         \"negative\": 0.05,\n    #         \"neutral\": 0.12\n    #     },\n    #     \"topic\": {\n    #         \"product\": 0.92,\n    #         \"service\": 0.15,\n    #         \"shipping\": 0.76\n    #     }\n    # }\n\n    # === Per-Text Prompts ===\n    results = pipeline(\n        [\"Electronics review text\", \"Clothing review text\"],\n        hierarchical_labels,\n        prompt=[\"Analyze this electronics review:\", \"Analyze this clothing review:\"],\n    )\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        tokenizer,\n        max_classes: int = 25,\n        max_length: int = 1024,\n        classification_type: str = \"multi-label\",\n        device: str = \"cuda:0\",\n        progress_bar: bool = True,\n        label_separator: str = \".\",\n    ):\n        \"\"\"\n        Initialize the classification pipeline.\n\n        Args:\n            model: GLiClass model or path to model\n            tokenizer: Tokenizer or path to tokenizer\n            max_classes: Maximum number of classes to process\n            max_length: Maximum sequence length\n            classification_type: 'single-label' or 'multi-label'\n            device: Device to run inference on\n            progress_bar: Whether to show progress bar\n            label_separator: Separator for hierarchical label notation 
(default: \".\")\n        \"\"\"\n        if isinstance(model, str):\n            model = GLiClassBiEncoder.from_pretrained(model)\n\n        self.label_separator = label_separator\n\n        if model.config.architecture_type == \"uni-encoder\":\n            self.pipe = UniEncoderZeroShotClassificationPipeline(\n                model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n            )\n        elif model.config.architecture_type in {\"encoder-decoder\", \"encoder-decoder-cls\"}:\n            self.pipe = EncoderDecoderZeroShotClassificationPipeline(\n                model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n            )\n        elif model.config.architecture_type in {\"bi-encoder\", \"bi-encoder-fused\"}:\n            self.pipe = BiEncoderZeroShotClassificationPipeline(\n                model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n            )\n        else:\n            raise NotImplementedError(\"This architecture is not implemented\")\n\n    def flatten_labels(self, labels: List[str] | Dict[str, Any]) -> List[str]:\n        \"\"\"\n        Flatten hierarchical labels to dot notation.\n\n        Example:\n            >>> pipeline.flatten_labels(\n            ...     {\"sentiment\": [\"positive\", \"negative\", \"neutral\"], \"topic\": [\"product\", \"service\", \"shipping\"]}\n            ... 
)\n            [\"sentiment.positive\", \"sentiment.negative\", \"sentiment.neutral\",\n             \"topic.product\", \"topic.service\", \"topic.shipping\"]\n        \"\"\"\n        return flatten_hierarchical_labels(labels, separator=self.label_separator)\n\n    def get_embeddings(self, *args, **kwargs):\n        \"\"\"Get embeddings for texts and labels.\"\"\"\n        return self.pipe.get_embeddings(*args, **kwargs)\n\n    def __call__(\n        self,\n        texts: str | List[str],\n        labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]],\n        threshold: float | List[float] = 0.5,\n        batch_size: int = 8,\n        classification_type: str | List[str] | None = None,\n        rac_examples: List | None = None,\n        examples: List[Dict[str, Any]] | None = None,\n        prompt: str | List[str] | None = None,\n        return_hierarchical: bool = False,\n    ):\n        \"\"\"\n        Perform zero-shot classification.\n\n        Args:\n            texts: Single text or list of texts to classify\n            labels: Labels - flat list or hierarchical dict\n                Examples:\n                - [\"positive\", \"negative\"] - flat labels\n                - {\"sentiment\": [\"positive\", \"negative\"], \"topic\": [\"product\", \"service\"]}\n            threshold: Classification threshold for multi-label, either one\n                value for all texts or one value per text\n            batch_size: Batch size for processing\n            classification_type: Override classification mode globally or per text.\n                If None, uses the pipeline's configured classification_type\n            rac_examples: Retrieval augmented examples (legacy)\n            examples: Few-shot examples, each with 'text' and 'labels' keys\n            prompt: Task description - string or list of strings (per-text)\n            return_hierarchical: If True, return structure matching input labels\n\n        Returns:\n            List of 
predictions (flat) or hierarchical dicts with all scores\n        \"\"\"\n        return self.pipe(\n            texts,\n            labels,\n            threshold=threshold,\n            batch_size=batch_size,\n            classification_type=classification_type,\n            rac_examples=rac_examples,\n            examples=examples,\n            prompt=prompt,\n            return_hierarchical=return_hierarchical,\n        )\n\n\nclass ZeroShotClassificationWithChunkingPipeline(BaseZeroShotClassificationPipeline):\n    \"\"\"Pipeline with long text chunking support.\"\"\"\n\n    def __init__(\n        self,\n        model,\n        tokenizer,\n        max_classes: int = 25,\n        max_length: int = 1024,\n        classification_type: str = \"multi-label\",\n        device: str = \"cuda:0\",\n        progress_bar: bool = True,\n        text_chunk_size: int = 8192,\n        text_chunk_overlap: int = 256,\n        labels_chunk_size: int = 8,\n        label_separator: str = \".\",\n    ):\n        if isinstance(model, str):\n            model = GLiClassModel.from_pretrained(model)\n        super().__init__(\n            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator\n        )\n\n        self.text_chunk_size = text_chunk_size\n        self.text_chunk_overlap = text_chunk_overlap\n        self.labels_chunk_size = labels_chunk_size\n\n    def chunk_text(self, text, chunk_size=None, overlap=None):\n        \"\"\"Split text into overlapping chunks.\"\"\"\n        if chunk_size is None:\n            chunk_size = self.text_chunk_size\n        if overlap is None:\n            overlap = self.text_chunk_overlap\n\n        if len(text) <= chunk_size:\n            return [text]\n\n        chunks = []\n        start = 0\n        while start < len(text):\n            end = start + chunk_size\n            chunk = text[start:end]\n            chunks.append(chunk)\n\n            if end >= len(text):\n                break\n\n   
         start = end - overlap\n\n        return chunks\n\n    def prepare_input(self, text, labels, examples=None, prompt=None):\n        \"\"\"\n        Prepare input matching training format from data_processing.py:\n        Order: Labels → SEP → Prompt → Text → Examples.\n        \"\"\"\n        input_parts = []\n\n        # 1. Add labels\n        for label in labels:\n            label_tag = f\"{self.label_token}{label}\"\n            input_parts.append(label_tag)\n        input_parts.append(self.sep_token)\n\n        # 2. Add task description prompt\n        if prompt:\n            input_parts.append(prompt)\n\n        # 3. Format examples to go after text\n        examples_str = \"\"\n        if examples:\n            examples_str = self._format_examples_for_input(examples)\n\n        if self.model.config.prompt_first:\n            return \"\".join(input_parts) + text + examples_str\n        else:\n            return text + \"\".join(input_parts) + examples_str\n\n    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):\n        inputs = []\n\n        if same_labels:\n            for i, text in enumerate(texts):\n                text_examples = self._get_text_examples(examples, i)\n                text_prompt = self._format_prompt(prompt, i)\n                inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))\n        else:\n            for i, (text, labels_) in enumerate(zip(texts, labels)):\n                text_examples = self._get_text_examples(examples, i)\n                text_prompt = self._format_prompt(prompt, i)\n                inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))\n\n        tokenized_inputs = self.tokenizer(\n            inputs, truncation=True, max_length=self.max_length, padding=\"longest\", return_tensors=\"pt\"\n        ).to(self.device)\n        return tokenized_inputs\n\n    def aggregate_chunk_scores(self, chunk_scores: List[Dict[str, 
float]], labels: List[str]) -> Dict[str, float]:\n        \"\"\"Aggregate scores across text chunks using max pooling.\"\"\"\n        aggregated = dict.fromkeys(labels, 0.0)\n\n        for scores in chunk_scores:\n            for label, score in scores.items():\n                aggregated[label] = max(aggregated[label], score)\n\n        return aggregated\n\n    @torch.no_grad()\n    def process_single_text(self, text, labels, threshold=0.5, examples=None, prompt=None):\n        \"\"\"Process a single long text through chunks.\"\"\"\n        text_chunks = self.chunk_text(text)\n\n        all_chunk_scores = []\n\n        for text_chunk in text_chunks:\n            chunk_logits = []\n            all_labels = []\n\n            for labels_idx in range(0, len(labels), self.labels_chunk_size):\n                curr_labels = labels[labels_idx : labels_idx + self.labels_chunk_size]\n                if labels_idx == 0:\n                    all_labels = []\n                all_labels.extend(curr_labels)\n\n                tokenized_inputs = self.prepare_inputs(\n                    [text_chunk], curr_labels, same_labels=True, examples=examples, prompt=prompt\n                )\n                max_num_classes = self._resolve_max_num_classes(curr_labels, same_labels=True)\n                model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)\n                logits = model_output.logits\n\n                chunk_logits.extend(logits[0][: len(curr_labels)].tolist())\n\n            text_logits = torch.tensor(chunk_logits)\n\n            if self.classification_type == \"single-label\":\n                scores = torch.softmax(text_logits, dim=-1)\n            else:\n                scores = torch.sigmoid(text_logits)\n\n            chunk_score_dict = {label: scores[i].item() for i, label in enumerate(all_labels)}\n            all_chunk_scores.append(chunk_score_dict)\n\n        aggregated_scores = self.aggregate_chunk_scores(all_chunk_scores, labels)\n\n   
     if self.classification_type == \"single-label\":\n            total = sum(aggregated_scores.values())\n            if total > 0:\n                aggregated_scores = {k: v / total for k, v in aggregated_scores.items()}\n\n            best_label = max(aggregated_scores, key=aggregated_scores.get)\n            return [{\"label\": best_label, \"score\": aggregated_scores[best_label]}], aggregated_scores\n\n        else:\n            text_results = []\n            for label, score in aggregated_scores.items():\n                if score >= threshold:\n                    text_results.append({\"label\": label, \"score\": score})\n            text_results.sort(key=lambda x: x[\"score\"], reverse=True)\n            return text_results, aggregated_scores\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        texts,\n        labels,\n        threshold=0.5,\n        batch_size=8,\n        labels_chunk_size=None,\n        text_chunk_size=None,\n        text_chunk_overlap=None,\n        rac_examples=None,\n        examples=None,\n        prompt=None,\n        return_hierarchical: bool = False,\n    ):\n        \"\"\"Classification with chunking for long texts.\"\"\"\n        original_labels = labels\n\n        if labels_chunk_size is not None:\n            self.labels_chunk_size = labels_chunk_size\n        if text_chunk_size is not None:\n            self.text_chunk_size = text_chunk_size\n        if text_chunk_overlap is not None:\n            self.text_chunk_overlap = text_chunk_overlap\n\n        if isinstance(texts, str):\n            if rac_examples:\n                texts = retrieval_augmented_text(texts, rac_examples)\n            texts = [texts]\n        elif rac_examples:\n            texts = [retrieval_augmented_text(text, ex) for text, ex in zip(texts, rac_examples)]\n\n        labels = self._process_labels(labels)\n\n        short_texts, short_indices = [], []\n        long_texts, long_indices = [], []\n\n        for i, text in enumerate(texts):\n  
          if len(text) <= self.text_chunk_size:\n                short_texts.append(text)\n                short_indices.append(i)\n            else:\n                long_texts.append(text)\n                long_indices.append(i)\n\n        results = [None] * len(texts)\n        all_scores_list = [None] * len(texts)\n\n        if short_texts:\n            iterable = range(0, len(short_texts), batch_size)\n            if self.progress_bar:\n                iterable = tqdm(iterable, desc=\"Processing short texts\")\n\n            for idx in iterable:\n                batch_texts = short_texts[idx : idx + batch_size]\n                batch_indices = short_indices[idx : idx + batch_size]\n\n                all_logits = [[] for _ in range(len(batch_texts))]\n                all_labels = []\n\n                for labels_idx in range(0, len(labels), self.labels_chunk_size):\n                    curr_labels = labels[labels_idx : labels_idx + self.labels_chunk_size]\n                    if labels_idx == 0:\n                        all_labels = []\n                    all_labels.extend(curr_labels)\n\n                    batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))\n                    tokenized_inputs = self.prepare_inputs(\n                        batch_texts, curr_labels, same_labels=True, examples=examples, prompt=batch_prompt\n                    )\n                    max_num_classes = self._resolve_max_num_classes(curr_labels, same_labels=True)\n                    model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)\n                    logits = model_output.logits\n\n                    for i in range(len(batch_texts)):\n                        all_logits[i].extend(logits[i][: len(curr_labels)].tolist())\n\n                for i, orig_idx in enumerate(batch_indices):\n                    text_logits = torch.tensor(all_logits[i])\n\n                    if self.classification_type == \"single-label\":\n                
        score = torch.softmax(text_logits, dim=-1)\n                        pred_idx = torch.argmax(score).item()\n                        pred_label = all_labels[pred_idx]\n                        results[orig_idx] = [{\"label\": pred_label, \"score\": score[pred_idx].item()}]\n                        all_scores_list[orig_idx] = {all_labels[j]: score[j].item() for j in range(len(all_labels))}\n\n                    elif self.classification_type == \"multi-label\":\n                        probs = torch.sigmoid(text_logits)\n                        text_results = []\n                        all_scores_list[orig_idx] = {all_labels[j]: probs[j].item() for j in range(len(all_labels))}\n                        for j, prob in enumerate(probs):\n                            score_val = prob.item()\n                            if score_val >= threshold:\n                                text_results.append({\"label\": all_labels[j], \"score\": score_val})\n                        text_results.sort(key=lambda x: x[\"score\"], reverse=True)\n                        results[orig_idx] = text_results\n\n        if long_texts:\n            iterable = range(len(long_texts))\n            if self.progress_bar:\n                iterable = tqdm(iterable, desc=\"Processing long texts\")\n\n            for i in iterable:\n                text = long_texts[i]\n                orig_idx = long_indices[i]\n                text_prompt = self._format_prompt(prompt, orig_idx)\n                text_results, all_scores = self.process_single_text(\n                    text, labels, threshold, examples=examples, prompt=text_prompt\n                )\n                results[orig_idx] = text_results\n                all_scores_list[orig_idx] = all_scores\n\n        if return_hierarchical:\n            hierarchical_results = []\n            for result, all_scores in zip(results, all_scores_list):\n                hierarchical_results.append(\n                    build_hierarchical_output(result, 
original_labels, self.label_separator, all_scores)\n                )\n            return hierarchical_results\n\n        return results\n\n\n# Utility functions\n\n\ndef parse_hierarchical_prediction(prediction: str, separator: str = \".\") -> Dict[str, Any]:\n    \"\"\"Parse dot-notation prediction into hierarchy levels.\"\"\"\n    parts = prediction.split(separator)\n    result = {\"full\": prediction}\n    for i, part in enumerate(parts):\n        result[f\"level_{i}\"] = part\n    return result\n\n\ndef group_predictions_by_hierarchy(\n    predictions: List[Dict[str, Any]], separator: str = \".\"\n) -> Dict[str, List[Dict[str, Any]]]:\n    \"\"\"Group predictions by top-level category.\"\"\"\n    grouped = {}\n    for pred in predictions:\n        label = pred[\"label\"]\n        parts = label.split(separator)\n        top_level = parts[0] if parts else label\n\n        if top_level not in grouped:\n            grouped[top_level] = []\n        grouped[top_level].append(pred)\n\n    for key in grouped:\n        grouped[key].sort(key=lambda x: x[\"score\"], reverse=True)\n\n    return grouped\n\n\ndef get_best_per_category(predictions: List[Dict[str, Any]], separator: str = \".\") -> Dict[str, Dict[str, Any]]:\n    \"\"\"Get best prediction per top-level category.\"\"\"\n    grouped = group_predictions_by_hierarchy(predictions, separator)\n    return {category: preds[0] for category, preds in grouped.items() if preds}\n"
  },
  {
    "path": "gliclass/poolings.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass GlobalMaxPooling1D(nn.Module):\n    \"\"\"Applies Global Max Pooling on the timesteps dimension.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        return x.amax(dim=1)\n\n\nclass FirstTokenPooling1D(nn.Module):\n    \"\"\"Takes the first token's embedding.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        return x[:, 0, :]\n\n\nclass LastTokenPooling1D(nn.Module):\n    \"\"\"Takes the last token's embedding.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        return x[:, -1, :]\n\n\nclass GlobalAvgPooling1D(nn.Module):\n    \"\"\"Applies Global Average Pooling on the timesteps dimension.\"\"\"\n\n    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None):\n        if attention_mask is not None:\n            attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(dtype=x.dtype)\n            x = x * attention_mask\n            return x.sum(1) / attention_mask.sum(1)\n        else:\n            return x.mean(dim=1)\n\n\nclass GlobalSumPooling1D(nn.Module):\n    \"\"\"Applies Global Sum Pooling on the timesteps dimension.\"\"\"\n\n    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None):\n        if attention_mask is not None:\n            x = x * attention_mask\n        return x.sum(dim=1)\n\n\nclass GlobalRMSPooling1D(nn.Module):\n    \"\"\"Applies Global RMS Pooling on the timesteps dimension.\"\"\"\n\n    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None):\n        if attention_mask is not None:\n            attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(dtype=x.dtype)\n            x = x * attention_mask\n            return (x.pow(2).sum(dim=1) / attention_mask.sum(1)).sqrt()\n        else:\n            return x.pow(2).mean(dim=1).sqrt()\n\n\nclass GlobalAbsMaxPooling1D(nn.Module):\n    \"\"\"Applies Global Max Pooling of absolute values on the timesteps dimension.\"\"\"\n\n    def 
forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None):\n        if attention_mask is not None:\n            attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(dtype=x.dtype)\n            x = x * attention_mask\n        return x.abs().amax(dim=1)\n\n\nclass GlobalAbsAvgPooling1D(nn.Module):\n    \"\"\"Applies Global Average Pooling of absolute values on the timesteps dimension.\"\"\"\n\n    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None):\n        if attention_mask is not None:\n            attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(dtype=x.dtype)\n            x = (x * attention_mask).abs()\n            return x.sum(dim=1) / attention_mask.sum(1)\n        else:\n            return x.abs().mean(dim=1)\n\n\nclass PassPooling1D(nn.Module):\n    \"\"\"Passes the input through without pooling.\"\"\"\n\n    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None):\n        return x\n\n\nPOOLING2OBJECT = {\n    \"max\": GlobalMaxPooling1D,\n    \"first\": FirstTokenPooling1D,\n    \"last\": LastTokenPooling1D,\n    \"avg\": GlobalAvgPooling1D,\n    \"sum\": GlobalSumPooling1D,\n    \"rms\": GlobalRMSPooling1D,\n    \"abs_max\": GlobalAbsMaxPooling1D,\n    \"abs_avg\": GlobalAbsAvgPooling1D,\n    \"pass\": PassPooling1D,\n}\n"
  },
  {
    "path": "gliclass/scorers.py",
    "content": "import torch\nfrom torch import nn\n\nfrom .ops import attn_padded\n\n\nclass ScorerWeightedDot(nn.Module):\n    def __init__(self, hidden_size, dropout=0.1, **kwargs):\n        super().__init__()\n\n        self.proj_text = nn.Linear(hidden_size, hidden_size * 2)\n        self.proj_label = nn.Linear(hidden_size, hidden_size * 2)\n\n        self.out_mlp = nn.Sequential(\n            nn.Linear(hidden_size * 3, hidden_size * 4),\n            nn.Dropout(dropout),\n            nn.ReLU(),\n            nn.Linear(hidden_size * 4, 1),  # start, end, score\n        )\n\n    def forward(self, text_rep, label_rep, **kwargs):\n        batch_size, hidden_size = text_rep.shape\n        num_classes = label_rep.shape[1]\n\n        # (batch_size, 1, 3, hidden_size)\n        text_rep = self.proj_text(text_rep).view(batch_size, 1, 1, 2, hidden_size)\n        label_rep = self.proj_label(label_rep).view(batch_size, 1, num_classes, 2, hidden_size)\n\n        # (2, batch_size, 1, num_classes, hidden_size)\n        text_rep = text_rep.expand(-1, -1, num_classes, -1, -1).permute(3, 0, 1, 2, 4)\n        label_rep = label_rep.expand(-1, 1, -1, -1, -1).permute(3, 0, 1, 2, 4)\n\n        # (batch_size, 1, num_classes, hidden_size * 3)\n        cat = torch.cat([text_rep[0], label_rep[0], text_rep[1] * label_rep[1]], dim=-1)\n\n        # (batch_size, num_classes)\n        scores = self.out_mlp(cat).view(batch_size, num_classes)\n\n        return scores\n\n\nclass ScorerDot(nn.Module):\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n        pass\n\n    def forward(self, text_rep, label_rep, **kwargs):\n        # dot product with einsum\n        scores = torch.einsum(\"BD,BCD->BC\", text_rep, label_rep)\n        return scores\n\n\nclass MLPScorer(nn.Module):\n    def __init__(self, hidden_size, mlp_hidden_size=256, **kwargs):\n        super().__init__()\n\n        # Calculate the input size for the MLP\n        total_input_size = hidden_size * 2\n\n        # 
Define the MLP\n        self.mlp = nn.Sequential(\n            nn.Linear(total_input_size, mlp_hidden_size),\n            nn.ReLU(),\n            nn.Linear(mlp_hidden_size, mlp_hidden_size // 2),\n            nn.ReLU(),\n            nn.Linear(mlp_hidden_size // 2, 1),\n        )\n\n    def forward(self, text_rep, label_rep, **kwargs):\n        # Concatenate text and label representations\n        batch_size, num_labels, dim = label_rep.shape\n        text_rep = text_rep.unsqueeze(1).expand(batch_size, num_labels, dim)\n        combined_rep = torch.cat([text_rep, label_rep], dim=-1)\n\n        # Pass through MLP\n        scores = self.mlp(combined_rep).squeeze(-1)\n\n        return scores\n\n\nclass HopfieldScorer(nn.Module):\n    def __init__(self, hidden_size, mlp_hidden_size=256, beta=4, num_iteration=1, **kwargs):\n        super().__init__()\n\n        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)\n        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)\n\n        # Define the MLP\n        self.mlp = nn.Sequential(\n            nn.Linear(hidden_size, mlp_hidden_size),\n            nn.ReLU(),\n            nn.Linear(mlp_hidden_size, mlp_hidden_size // 2),\n            nn.ReLU(),\n            nn.Linear(mlp_hidden_size // 2, 1),\n        )\n\n        self.beta = beta\n        self.num_iteration = num_iteration\n\n    def forward(self, text_rep, label_rep, **kwargs):\n        \"\"\"\n        text_rep: [batch_size, hidden_size]\n        label_rep: [batch_size, num_labels, hidden_size].\n        \"\"\"\n        for _i in range(self.num_iteration):\n            # Expand text_rep to match label_rep's batch shape\n            text_rep_expanded = text_rep.unsqueeze(1)  # [batch_size, 1, dim]\n\n            # Compute Q, K, V\n            query = self.q_proj(label_rep)  # [batch_size, num_labels, dim]\n            key = self.k_proj(text_rep_expanded)  # [batch_size, 1, dim]\n 
           value = self.v_proj(text_rep_expanded)  # [batch_size, 1, dim]\n\n            attn = torch.bmm(query, key.transpose(1, 2))  # [b, num_labels, 1]\n            attn = attn * self.beta  # optional beta scaling\n            attn = torch.nn.functional.softmax(attn, dim=1)  # softmax over labels\n\n            context = attn * value  # [b, num_labels, dim]\n\n            label_rep = label_rep + context\n\n        scores = self.mlp(label_rep).squeeze(-1)  # [b, num_labels]\n\n        return scores\n\n\nclass CrossAttnScorer(nn.Module):\n    def __init__(self, hidden_size, num_heads=16, attn_dropout=0.1, scorer_mlp_hidden_size=1024, **kwargs):\n        super().__init__()\n        assert hidden_size % num_heads == 0, f\"hidden_size {hidden_size} must be divisible by num_heads {num_heads}\"\n        self.num_heads = num_heads\n        self.head_dim = hidden_size // num_heads\n        self.attn_dropout = attn_dropout\n\n        self.q_norm = nn.LayerNorm(hidden_size)\n        self.kv_norm = nn.LayerNorm(hidden_size)\n\n        self.q = nn.Linear(hidden_size, hidden_size)\n        self.k = nn.Linear(hidden_size, hidden_size)\n        self.v = nn.Linear(hidden_size, hidden_size)\n        self.out = nn.Linear(hidden_size, hidden_size)\n\n        self.norm = nn.LayerNorm(hidden_size)\n\n        self.score_mlp = nn.Sequential(\n            nn.Linear(hidden_size * 2, scorer_mlp_hidden_size),\n            nn.GELU(),\n            nn.Linear(scorer_mlp_hidden_size, scorer_mlp_hidden_size // 2),\n            nn.GELU(),\n            nn.Linear(scorer_mlp_hidden_size // 2, 1),\n        )\n\n    def forward(self, text_rep, label_rep, text_mask=None, **kwargs):\n        batch_size, _, hidden_size = text_rep.shape\n        num_labels = label_rep.shape[1]\n\n        if text_mask is None:\n            text_mask = torch.ones(batch_size, text_rep.shape[1], dtype=torch.bool, device=text_rep.device)\n\n        q = self.q(self.q_norm(label_rep)).view(batch_size, num_labels, 
self.num_heads, self.head_dim)\n        k = self.k(self.kv_norm(text_rep)).view(batch_size, -1, self.num_heads, self.head_dim)\n        v = self.v(text_rep).view(batch_size, -1, self.num_heads, self.head_dim)\n\n        dropout_p = self.attn_dropout if self.training else 0.0\n        context = attn_padded(q, k, v, key_padding_mask=text_mask, dropout_p=dropout_p)\n        context = self.norm(self.out(context.reshape(batch_size, num_labels, hidden_size)))\n\n        return self.score_mlp(torch.cat([context, label_rep], dim=-1)).squeeze(-1)\n\n\nSCORER2OBJECT = {\n    \"weighted-dot\": ScorerWeightedDot,\n    \"simple\": ScorerDot,\n    \"mlp\": MLPScorer,\n    \"hopfield\": HopfieldScorer,\n    \"cross-attn\": CrossAttnScorer,\n}\n"
  },
  {
    "path": "gliclass/serve/__init__.py",
    "content": "\"\"\"GLiClass serving module.\"\"\"\n\nfrom .client import GLiClassClient\nfrom .config import GLiClassServeConfig\nfrom .memory import GLiClassMemoryEstimator\nfrom .server import GLiClassServer, GLiClassFactory, shutdown, serve_gliclass\n\n__all__ = [\n    \"GLiClassClient\",\n    \"GLiClassFactory\",\n    \"GLiClassMemoryEstimator\",\n    \"GLiClassServeConfig\",\n    \"GLiClassServer\",\n    \"serve_gliclass\",\n    \"shutdown\",\n]\n"
  },
  {
    "path": "gliclass/serve/__main__.py",
    "content": "\"\"\"CLI entry point for GLiClass serving.\"\"\"\n\nimport sys\nimport signal\nimport logging\nimport argparse\n\nimport ray\nfrom ray import serve\n\nfrom .config import GLiClassServeConfig\nfrom .server import serve_gliclass\n\nlogging.basicConfig(\n    level=logging.INFO,\n    format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n)\nlogger = logging.getLogger(__name__)\n\n\ndef main():\n    \"\"\"Main entry point for GLiClass serving.\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"GLiClass Ray Serve deployment\",\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n    )\n\n    # Config file\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=None,\n        help=\"Path to YAML config file (CLI args override config values)\",\n    )\n\n    # Model configuration\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=None,\n        help=\"Model name or path\",\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=None,\n        help=\"Device to run on (cuda or cpu)\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=None,\n        choices=[\"float32\", \"float16\", \"bfloat16\"],\n        help=\"Data type for model weights\",\n    )\n    parser.add_argument(\n        \"--max-model-len\",\n        type=int,\n        default=None,\n        help=\"Maximum sequence length\",\n    )\n    parser.add_argument(\n        \"--max-batch-size\",\n        type=int,\n        default=None,\n        help=\"Maximum batch size\",\n    )\n    parser.add_argument(\n        \"--max-labels\",\n        type=int,\n        default=None,\n        help=\"Maximum number of labels (-1 for unlimited)\",\n    )\n    parser.add_argument(\n        \"--max-labels-alloc\",\n        type=str,\n        default=None,\n        help='Label memory allocation: \"dynamic\", \"fixed\", or integer (e.g., 
\"50\")',\n    )\n\n    # Server configuration\n    parser.add_argument(\n        \"--host\",\n        type=str,\n        default=\"0.0.0.0\",\n        help=\"Host to bind to\",\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        default=None,\n        help=\"Port to bind to\",\n    )\n    parser.add_argument(\n        \"--route-prefix\",\n        type=str,\n        default=None,\n        help=\"HTTP route prefix\",\n    )\n    parser.add_argument(\n        \"--num-replicas\",\n        type=int,\n        default=None,\n        help=\"Number of model replicas\",\n    )\n\n    # Performance configuration\n    parser.add_argument(\n        \"--calibrate-on-startup\",\n        action=\"store_true\",\n        default=None,\n        help=\"Run memory calibration on startup\",\n    )\n    parser.add_argument(\n        \"--precompile-on-startup\",\n        action=\"store_true\",\n        default=None,\n        help=\"Precompile model on startup\",\n    )\n    parser.add_argument(\n        \"--use-memory-aware-batching\",\n        action=\"store_true\",\n        default=None,\n        help=\"Use memory-aware dynamic batching\",\n    )\n    parser.add_argument(\n        \"--enable-compilation\",\n        action=\"store_true\",\n        default=None,\n        help=\"Enable torch.compile\",\n    )\n    parser.add_argument(\n        \"--tokenizer-threads\",\n        type=int,\n        default=None,\n        help=\"Number of tokenizer threads\",\n    )\n\n    # Calibration configuration\n    parser.add_argument(\n        \"--calibration-min-batch-size\",\n        type=int,\n        default=None,\n        help=\"Minimum batch size for calibration\",\n    )\n    parser.add_argument(\n        \"--calibration-max-batch-size\",\n        type=int,\n        default=None,\n        help=\"Maximum batch size for calibration\",\n    )\n    parser.add_argument(\n        \"--calibration-min-seq-len\",\n        type=int,\n        default=None,\n        
help=\"Minimum sequence length for calibration\",\n    )\n\n    args = parser.parse_args()\n\n    if args.config:\n        logger.info(f\"Loading config from: {args.config}\")\n        config = GLiClassServeConfig.from_yaml(args.config)\n    else:\n        config = GLiClassServeConfig(model=args.model or \"knowledgator/gliclass-edge-v3.0\")\n\n    # Convert max_labels_alloc to int if it's a digit string\n    max_labels_alloc_value = args.max_labels_alloc\n    if max_labels_alloc_value and max_labels_alloc_value.isdigit():\n        max_labels_alloc_value = int(max_labels_alloc_value)\n\n    cli_overrides = {\n        \"model\": args.model,\n        \"device\": args.device,\n        \"dtype\": args.dtype,\n        \"max_model_len\": args.max_model_len,\n        \"max_batch_size\": args.max_batch_size,\n        \"max_labels\": args.max_labels,\n        \"max_labels_alloc\": max_labels_alloc_value,\n        \"http_port\": args.port,\n        \"route_prefix\": args.route_prefix,\n        \"num_replicas\": args.num_replicas,\n        \"calibrate_on_startup\": args.calibrate_on_startup,\n        \"precompile_on_startup\": args.precompile_on_startup,\n        \"use_memory_aware_batching\": args.use_memory_aware_batching,\n        \"enable_compilation\": args.enable_compilation,\n        \"tokenizer_threads\": args.tokenizer_threads,\n        \"calibration_min_batch_size\": args.calibration_min_batch_size,\n        \"calibration_max_batch_size\": args.calibration_max_batch_size,\n        \"calibration_min_seq_len\": args.calibration_min_seq_len,\n    }\n    config.update(**cli_overrides)\n\n    logger.info(\"=\" * 60)\n    logger.info(\"GLiClass Serve Configuration:\")\n    logger.info(f\"  Model: {config.model}\")\n    logger.info(f\"  Device: {config.device}\")\n    logger.info(f\"  Dtype: {config.dtype}\")\n    logger.info(f\"  Max model length: {config.max_model_len}\")\n    logger.info(f\"  Max batch size: {config.max_batch_size}\")\n    logger.info(f\"  Max labels: 
{config.max_labels}\")\n    logger.info(f\"  Max labels alloc: {config.max_labels_alloc}\")\n    logger.info(f\"  HTTP port: {config.http_port}\")\n    logger.info(f\"  Route prefix: {config.route_prefix}\")\n    logger.info(f\"  Num replicas: {config.num_replicas}\")\n    logger.info(f\"  Calibrate on startup: {config.calibrate_on_startup}\")\n    logger.info(f\"  Precompile on startup: {config.precompile_on_startup}\")\n    logger.info(f\"  Memory-aware batching: {config.use_memory_aware_batching}\")\n    logger.info(\"=\" * 60)\n\n    logger.info(\"Initializing Ray...\")\n    ray.init(ignore_reinit_error=True)\n\n    logger.info(\"Starting Ray Serve...\")\n    serve.start(http_options={\"host\": args.host, \"port\": config.http_port})\n\n    logger.info(f\"Deploying GLiClass with model: {config.model}\")\n    _app = serve_gliclass(config, blocking=False)  # Keep reference to prevent GC\n\n    logger.info(f\"GLiClass server running at http://{args.host}:{config.http_port}{config.route_prefix}\")\n    logger.info(\"Press Ctrl+C to stop the server\")\n\n    def signal_handler(_sig, _frame):\n        logger.info(\"Shutting down server...\")\n        serve.shutdown()\n        ray.shutdown()\n        sys.exit(0)\n\n    signal.signal(signal.SIGINT, signal_handler)\n    signal.signal(signal.SIGTERM, signal_handler)\n\n    import time\n\n    while True:\n        time.sleep(1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "gliclass/serve/client.py",
    "content": "\"\"\"Client for GLiClass serving endpoint.\"\"\"\n\nimport requests\n\n\nclass GLiClassClient:\n    \"\"\"Client for interacting with GLiClass Ray Serve deployment.\"\"\"\n\n    def __init__(self, url: str = \"http://localhost:8000/gliclass\"):\n        \"\"\"Initialize the client.\n\n        Args:\n            url: Base URL of the GLiClass server\n        \"\"\"\n        self.url = url.rstrip(\"/\")\n\n    def __call__(\n        self,\n        texts: str | list[str],\n        labels: list[str] | list[list[str]],\n        threshold: float = 0.5,\n        multi_label: bool = True,\n        examples: list[dict] | None = None,\n        prompt: str | list[str] | None = None,\n    ) -> list[list[dict]]:\n        \"\"\"Classify text(s) - same interface as pipeline.\n\n        Args:\n            texts: Single text or list of texts to classify\n            labels: List of labels (same for all) or list of label lists (per text)\n            threshold: Confidence threshold for predictions\n            multi_label: Whether to enable multi-label classification\n            examples: Optional list of example classifications\n            prompt: Optional task description prompt (string or list)\n\n        Returns:\n            List of results, one per text. 
Each result is a list of {\"label\": ..., \"score\": ...}\n        \"\"\"\n        payload = {\n            \"texts\": texts,\n            \"labels\": labels,\n            \"threshold\": threshold,\n            \"multi_label\": multi_label,\n        }\n        if examples is not None:\n            payload[\"examples\"] = examples\n        if prompt is not None:\n            payload[\"prompt\"] = prompt\n\n        response = requests.post(self.url, json=payload, timeout=30)\n        response.raise_for_status()\n\n        return response.json()\n\n    def classify(\n        self,\n        text: str,\n        labels: list[str],\n        threshold: float = 0.5,\n        multi_label: bool = True,\n        examples: list[dict] | None = None,\n        prompt: str | None = None,\n    ) -> list[dict]:\n        \"\"\"Classify a single text (convenience method).\n\n        Args:\n            text: Input text to classify\n            labels: List of possible labels\n            threshold: Confidence threshold for predictions\n            multi_label: Whether to enable multi-label classification\n            examples: Optional list of example classifications\n            prompt: Optional task description prompt\n\n        Returns:\n            List of predictions: [{\"label\": ..., \"score\": ...}, ...]\n        \"\"\"\n        results = self(text, labels, threshold, multi_label, examples, prompt)\n        return results[0]\n\n    def health_check(self) -> bool:\n        \"\"\"Check if the server is healthy.\n\n        Returns:\n            True if server is healthy, False otherwise\n        \"\"\"\n        try:\n            response = requests.get(f\"{self.url}/-/healthz\", timeout=5)\n            return response.status_code == 200\n        except requests.RequestException:\n            return False\n\n\nif __name__ == \"__main__\":\n    client = GLiClassClient()\n\n    # Single text\n    result = client.classify(\n        text=\"This is a great product! 
I love it.\",\n        labels=[\"positive\", \"negative\", \"neutral\"],\n        threshold=0.3,\n    )\n    print(\"Single prediction:\", result)\n\n    # Batch\n    results = client(\n        texts=[\"Great product!\", \"Terrible experience\"],\n        labels=[\"positive\", \"negative\", \"neutral\"],\n        threshold=0.3,\n    )\n    print(\"Batch predictions:\", results)\n"
  },
  {
    "path": "gliclass/serve/config.py",
    "content": "\"\"\"Configuration for GLiClass Ray Serve deployment.\"\"\"\n\nfrom pathlib import Path\nfrom dataclasses import field, asdict, dataclass\n\nimport yaml\n\n\n@dataclass\nclass GLiClassServeConfig:\n    \"\"\"Configuration for GLiClass Ray Serve deployment.\n\n    This config controls model loading, serving parameters, and dynamic batching behavior.\n    \"\"\"\n\n    model: str\n    device: str = \"cuda\"\n    dtype: str = \"bfloat16\"\n\n    quantization: str | None = None\n\n    max_model_len: int = 2048\n    max_labels: int = -1\n    max_labels_alloc: str | int = \"dynamic\"\n\n    default_threshold: float = 0.5\n\n    num_replicas: int = 1\n    num_gpus_per_replica: float = 1.0\n    num_cpus_per_replica: float = 1.0\n\n    max_batch_size: int = 32\n    batch_wait_timeout_ms: float = 20.0\n    request_timeout_s: float = 30.0\n    max_ongoing_requests: int = 256\n    queue_capacity: int = 4096\n\n    route_prefix: str = \"/gliclass\"\n\n    tokenizer_threads: int = 4\n\n    enable_compilation: bool = True\n    calibrate_on_startup: bool = False\n    precompile_on_startup: bool = True\n    use_memory_aware_batching: bool = False\n\n    precompiled_batch_sizes: list[int] = field(default_factory=lambda: [1, 2, 4, 8, 16, 32])\n\n    target_memory_fraction: float = 0.8\n    memory_overhead_factor: float = 1.3\n\n    calibration_min_seq_len: int = 64\n    calibration_min_batch_size: int = 1\n    calibration_max_batch_size: int = 64\n    calibration_probe_batch_size: int = 2\n\n    warmup_iterations: int = 3\n\n    http_port: int = 8000\n\n    ray_address: str | None = None\n\n    def __post_init__(self):\n        if self.max_batch_size not in self.precompiled_batch_sizes:\n            self.precompiled_batch_sizes = sorted(set(self.precompiled_batch_sizes) | {self.max_batch_size})\n        self.precompiled_batch_sizes = sorted(self.precompiled_batch_sizes)\n\n    def to_env_vars(self) -> dict:\n        \"\"\"Convert config to environment variables for 
model loading.\"\"\"\n        env = {}\n        if self.tokenizer_threads > 0:\n            env[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n        return env\n\n    @classmethod\n    def from_yaml(cls, config_path: str | Path) -> \"GLiClassServeConfig\":\n        \"\"\"Load configuration from YAML file.\n\n        Args:\n            config_path: Path to YAML config file\n\n        Returns:\n            GLiClassServeConfig instance\n        \"\"\"\n        config_path = Path(config_path)\n        with config_path.open(\"r\") as f:\n            config_dict = yaml.safe_load(f)\n        return cls(**config_dict)\n\n    def to_yaml(self, config_path: str | Path) -> None:\n        \"\"\"Save configuration to YAML file.\n\n        Args:\n            config_path: Path to save YAML config\n        \"\"\"\n        config_path = Path(config_path)\n        config_path.parent.mkdir(parents=True, exist_ok=True)\n        with config_path.open(\"w\") as f:\n            yaml.dump(asdict(self), f, default_flow_style=False, sort_keys=False)\n\n    def update(self, **kwargs) -> \"GLiClassServeConfig\":\n        \"\"\"Update config with provided kwargs (for CLI override).\n\n        Args:\n            **kwargs: Fields to update (None values are ignored)\n\n        Returns:\n            Updated config instance\n        \"\"\"\n        for key, value in kwargs.items():\n            if value is not None and hasattr(self, key):\n                setattr(self, key, value)\n        return self\n"
  },
  {
    "path": "gliclass/serve/memory.py",
    "content": "\"\"\"Memory estimation for GLiClass via precomputed calibration table.\n\nStartup calibration runs the model on probe batches at power-of-two sequence\nlengths and records peak GPU memory per sample. At request time ``batch_size_fn``\npicks the largest precompiled batch size that satisfies\n\n    per_sample(seq_len) * N  <=  total_gpu - cuda_context - model_weights\n\nusing a pessimistic (rounded-up) seq_len and a safety factor on per-sample\nmemory.\n\"\"\"\n\nimport logging\nfrom typing import Dict, List, Callable\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\n\ndef _power_of_two_seq_lens(max_seq_len: int, min_seq_len: int = 64) -> List[int]:\n    \"\"\"Return power-of-two sequence lengths from min_seq_len up to max_seq_len.\"\"\"\n    lens: List[int] = []\n    s = max(1, min_seq_len)\n    while s < max_seq_len:\n        lens.append(s)\n        s *= 2\n    lens.append(max_seq_len)\n    return lens\n\n\nclass GLiClassMemoryEstimator:\n    \"\"\"Precomputed memory table for GLiClass inference.\"\"\"\n\n    def __init__(\n        self,\n        safety_factor: float = 1.3,\n        target_memory_fraction: float = 0.9,\n        calibration_probe_batch_size: int = 2,\n    ):\n        self.safety_factor = safety_factor\n        self.target_memory_fraction = target_memory_fraction\n        self.calibration_probe_batch_size = max(2, calibration_probe_batch_size)\n\n        self.total_gpu_memory: int = 0\n        self.cuda_context_bytes: int = 0\n        self.model_memory_bytes: int = 0\n\n        self.per_sample_table: Dict[int, int] = {}\n\n    def measure_cuda_context(self) -> None:\n        \"\"\"Record CUDA context overhead. 
Must be called before the model loads.\"\"\"\n        if not torch.cuda.is_available():\n            return\n        torch.cuda.synchronize()\n        free, total = torch.cuda.mem_get_info()\n        self.total_gpu_memory = total\n        self.cuda_context_bytes = total - free\n        logger.info(\"CUDA context: %.1f MiB\", self.cuda_context_bytes / (1024**2))\n\n    def measure_model_memory(self) -> None:\n        \"\"\"Record model weight memory. Must be called after the model loads.\"\"\"\n        if not torch.cuda.is_available():\n            return\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        free, total = torch.cuda.mem_get_info()\n        self.total_gpu_memory = total\n        used = total - free\n        self.model_memory_bytes = max(0, used - self.cuda_context_bytes)\n        logger.info(\"Model weights: %.1f MiB\", self.model_memory_bytes / (1024**2))\n\n    def available_memory(self) -> int:\n        \"\"\"Budget for a batch: ``total_gpu - cuda_context - model_weights``.\"\"\"\n        if not torch.cuda.is_available():\n            return 0\n        budget = self.total_gpu_memory - self.cuda_context_bytes - self.model_memory_bytes\n        return max(0, int(budget * self.target_memory_fraction))\n\n    def calibrate(\n        self,\n        predict_method: Callable,\n        max_seq_len: int,\n        min_seq_len: int = 64,\n    ) -> None:\n        \"\"\"Populate ``per_sample_table`` across power-of-two seq lengths.\n\n        Uses a small set of dummy labels to probe classification performance.\n        \"\"\"\n        if not torch.cuda.is_available():\n            return\n\n        seq_lens = _power_of_two_seq_lens(max_seq_len, min_seq_len=min_seq_len)\n        dummy_labels = [\"label1\", \"label2\", \"label3\"]\n        probe_b = self.calibration_probe_batch_size\n\n        logger.info(\"Calibrating memory table: seq_lens=%s, probe_batch=%s\", seq_lens, probe_b)\n\n        for seq_len in seq_lens:\n            
dummy_text = \"word \" * max(1, seq_len // 2)\n            peak = self._measure_peak(predict_method, [dummy_text] * probe_b, dummy_labels)\n            per_sample = max(1, peak // probe_b)\n            self.per_sample_table[seq_len] = per_sample\n            logger.info(\"  seq_len=%5d: per_sample=%.1f MiB\", seq_len, per_sample / (1024**2))\n\n    def _measure_peak(\n        self,\n        predict_method: Callable,\n        texts: List[str],\n        labels: List[str],\n    ) -> int:\n        \"\"\"Run a probe batch and return peak allocated bytes above baseline.\"\"\"\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        baseline = torch.cuda.memory_allocated()\n\n        predict_method(texts, labels, threshold=0.5)\n\n        torch.cuda.synchronize()\n        peak = torch.cuda.max_memory_allocated()\n        return max(0, peak - baseline)\n\n    def _lookup_seq_len(self, seq_len: int) -> int:\n        \"\"\"Round ``seq_len`` up to the nearest calibrated entry (pessimistic).\"\"\"\n        if not self.per_sample_table:\n            raise RuntimeError(\"Memory estimator has not been calibrated\")\n        for key in sorted(self.per_sample_table.keys()):\n            if key >= seq_len:\n                return key\n        return max(self.per_sample_table.keys())\n\n    def per_sample_at(self, seq_len: int) -> int:\n        \"\"\"Pessimistic per-sample memory at or above ``seq_len``.\"\"\"\n        probe_seq_len = self._lookup_seq_len(seq_len)\n        return int(self.per_sample_table[probe_seq_len] * self.safety_factor)\n\n    def batch_size_fn(\n        self,\n        seq_len: int,\n        precompiled_sizes: List[int],\n    ) -> int:\n        \"\"\"Largest precompiled batch size satisfying ``per_sample * N <= budget``.\n\n        Budget = ``total_gpu - cuda_context - model_weights`` (times the\n        configured ``target_memory_fraction``).\n        \"\"\"\n        if not precompiled_sizes:\n   
         return 1\n\n        available = self.available_memory()\n        if available <= 0:\n            return min(precompiled_sizes)\n\n        per_sample = self.per_sample_at(seq_len)\n        for size in sorted(precompiled_sizes, reverse=True):\n            if per_sample * size <= available:\n                return size\n        return min(precompiled_sizes)\n"
  },
  {
    "path": "gliclass/serve/server.py",
    "content": "\"\"\"Ray Serve deployment for GLiClass with dynamic batching.\"\"\"\n\nimport os\nimport logging\nfrom typing import Any\n\nimport torch\nfrom ray import serve\nfrom transformers import AutoTokenizer\n\nfrom gliclass.model import GLiClassModel\nfrom gliclass.pipeline import ZeroShotClassificationPipeline\n\nfrom .config import GLiClassServeConfig\nfrom .memory import GLiClassMemoryEstimator\n\nlogger = logging.getLogger(\"ray.serve\")\n\n\nclass GLiClassServer:\n    \"\"\"GLiClass Ray Serve deployment with dynamic batching.\"\"\"\n\n    def __init__(self, config: GLiClassServeConfig):\n        \"\"\"Initialize GLiClass server deployment.\n\n        Args:\n            config: Server configuration with model and serving parameters\n        \"\"\"\n        self.config = config\n\n        env_vars = config.to_env_vars()\n        for key, value in env_vars.items():\n            os.environ[key] = value\n\n        if config.tokenizer_threads > 0:\n            os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n            torch.set_num_threads(config.tokenizer_threads)\n\n        torch.set_float32_matmul_precision(\"high\")\n\n        dtype_map = {\n            \"float32\": torch.float32,\n            \"float16\": torch.float16,\n            \"fp16\": torch.float16,\n            \"bfloat16\": torch.bfloat16,\n            \"bf16\": torch.bfloat16,\n        }\n        self.torch_dtype = dtype_map.get(config.dtype.lower(), torch.bfloat16)\n        self.device = torch.device(config.device)\n\n        self.memory_estimator = GLiClassMemoryEstimator(\n            safety_factor=config.memory_overhead_factor,\n            target_memory_fraction=config.target_memory_fraction,\n            calibration_probe_batch_size=config.calibration_probe_batch_size,\n        )\n\n        if torch.cuda.is_available():\n            self.memory_estimator.measure_cuda_context()\n\n        logger.info(\"Loading model: %s\", config.model)\n\n        self.model = 
GLiClassModel.from_pretrained(config.model)\n        self.model.config.max_labels_alloc = config.max_labels_alloc\n        self.model.to(device=self.device, dtype=self.torch_dtype)\n        self.model.eval()\n\n        self.tokenizer = AutoTokenizer.from_pretrained(config.model)\n        pipeline_kwargs = {\n            \"model\": self.model,\n            \"tokenizer\": self.tokenizer,\n            \"max_classes\": config.max_labels if config.max_labels > 0 else 100,\n            \"max_length\": config.max_model_len,\n            \"device\": self.device,\n            \"progress_bar\": False,\n        }\n        self.pipeline = ZeroShotClassificationPipeline(\n            classification_type=\"multi-label\",\n            **pipeline_kwargs,\n        )\n\n        if torch.cuda.is_available():\n            self.memory_estimator.measure_model_memory()\n\n        if config.enable_compilation:\n            self._precompile()\n\n        if torch.cuda.is_available():\n            self._calibrate_memory()\n\n        logger.info(\"GLiClass server initialized successfully\")\n\n    def _precompile(self) -> None:\n        logger.info(\"Precompiling model for batch sizes: %s\", self.config.precompiled_batch_sizes)\n\n        if hasattr(self.model, \"compile\"):\n            self.model.compile()\n\n        dummy_labels = [\"person\", \"organization\", \"location\"]\n\n        for batch_size in self.config.precompiled_batch_sizes:\n            dummy_texts = [f\"Sample text number {i} for precompilation warmup.\" for i in range(batch_size)]\n\n            for _ in range(self.config.warmup_iterations):\n                self._run_batch_internal(\n                    dummy_texts,\n                    dummy_labels,\n                    threshold=0.5,\n                    multi_label=True,\n                )\n\n            logger.info(\"  Batch size %d: compiled\", batch_size)\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n\n        
logger.info(\"Precompilation complete.\")\n\n    def _calibrate_memory(self) -> None:\n        logger.info(\"Calibrating memory table...\")\n\n        self.memory_estimator.calibrate(\n            self._run_batch_internal,\n            max_seq_len=self.config.max_model_len,\n            min_seq_len=self.config.calibration_min_seq_len,\n        )\n\n        logger.info(\"Memory calibration complete.\")\n\n    def batch_size_fn(self, seq_len: int | None = None) -> int:\n        \"\"\"Largest precompiled batch size that fits at seq_len.\n\n        Args:\n            seq_len: Sequence length (text + label words). If None, uses max_model_len.\n\n        Returns:\n            Optimal batch size from precompiled sizes\n        \"\"\"\n        if not torch.cuda.is_available():\n            return self.config.precompiled_batch_sizes[-1]\n\n        if seq_len is None:\n            seq_len = self.config.max_model_len\n\n        return self.memory_estimator.batch_size_fn(\n            seq_len=seq_len,\n            precompiled_sizes=self.config.precompiled_batch_sizes,\n        )\n\n    def observed_seq_len(\n        self,\n        texts: list[str],\n        labels: list[str] | list[list[str]] | None = None,\n    ) -> int:\n        \"\"\"Total input word count: longest text + all label words.\n\n        Labels are concatenated into input, so they extend effective seq length\n        for every sample in the batch.\n\n        Args:\n            texts: Input texts\n            labels: Label list\n\n        Returns:\n            Estimated sequence length\n        \"\"\"\n        max_text_words = max((len(t.split()) for t in texts if t.strip()), default=0)\n        prompt_words = 0\n        if labels:\n            if isinstance(labels[0], list):\n                prompt_words += max(sum(len(label.split()) for label in label_set) for label_set in labels)\n            else:\n                prompt_words += sum(len(label.split()) for label in labels)\n        total = max_text_words + 
prompt_words\n        return min(max(total, self.config.calibration_min_seq_len), self.config.max_model_len)\n\n    def _filter_labels(self, labels: list[str]) -> list[str]:\n        if self.config.max_labels > 0 and len(labels) > self.config.max_labels:\n            logger.warning(\"Truncating labels from %d to %d\", len(labels), self.config.max_labels)\n            return labels[: self.config.max_labels]\n        return labels\n\n    @torch.inference_mode()\n    def _run_batch_internal(\n        self,\n        texts: list[str],\n        labels: list[str] | list[list[str]],\n        threshold: float | list[float] = 0.5,\n        multi_label: bool | list[bool] = True,\n        examples: list[dict[str, Any]] | list[list[dict[str, Any]] | None] | None = None,\n        prompt: str | list[str] | None = None,\n    ) -> list[list[dict[str, Any]]]:\n        \"\"\"Run batch inference using the shared zero-shot pipeline.\n\n        Args:\n            texts: List of input texts\n            labels: Shared label list or one label list per text\n            threshold: Shared threshold or one threshold per text\n            multi_label: Shared mode or one mode flag per text\n            examples: Shared examples or one example set per text\n            prompt: Shared prompt or one prompt per text\n\n        Returns:\n            List of prediction dicts\n        \"\"\"\n        if isinstance(multi_label, list):\n            classification_type = [\"multi-label\" if item else \"single-label\" for item in multi_label]\n        else:\n            classification_type = \"multi-label\" if multi_label else \"single-label\"\n\n        return self.pipeline(\n            texts,\n            labels,\n            threshold=threshold,\n            batch_size=max(len(texts), 1),\n            classification_type=classification_type,\n            examples=examples,\n            prompt=prompt,\n        )\n\n    def predict(\n        self,\n        texts: str | list[str],\n        labels: 
list[str],\n        threshold: float | None = None,\n        multi_label: bool = True,\n        examples: list[dict[str, Any]] | None = None,\n        prompt: str | list[str] | None = None,\n    ) -> list[list[dict[str, Any]]]:\n        if isinstance(texts, str):\n            texts = [texts]\n\n        if threshold is None:\n            threshold = self.config.default_threshold\n\n        labels = self._filter_labels(labels)\n\n        results = self._run_batch_internal(\n            texts,\n            labels,\n            threshold=threshold,\n            multi_label=multi_label,\n            examples=examples,\n            prompt=prompt,\n        )\n\n        return results\n\n\ndef _build_deployment(config: GLiClassServeConfig):\n    batch_wait_s = max(config.batch_wait_timeout_ms, 0.0) / 1000.0\n    initial_max_batch_size = config.max_batch_size\n\n    @serve.deployment(\n        num_replicas=config.num_replicas,\n        ray_actor_options={\n            \"num_gpus\": config.num_gpus_per_replica,\n            \"num_cpus\": config.num_cpus_per_replica,\n        },\n        max_ongoing_requests=config.max_ongoing_requests,\n    )\n    class GLiClassDeployment:\n        def __init__(self, serve_config: GLiClassServeConfig):\n            self.server = GLiClassServer(serve_config)\n            # Initialize dynamic batch sizing\n            self._infer_batch.set_max_batch_size(self.server.batch_size_fn())\n            logger.info(\n                \"Ray Serve batch size initialized to %d (precompiled: %s)\",\n                self.server.batch_size_fn(),\n                serve_config.precompiled_batch_sizes,\n            )\n\n        @serve.batch(\n            max_batch_size=initial_max_batch_size,\n            batch_wait_timeout_s=batch_wait_s,\n        )\n        async def _infer_batch(\n            self,\n            texts: list[str],\n            labels_list: list[list[str]],\n            thresholds: list[float],\n            multi_label_list: list[bool],\n       
     examples_list: list[list[dict[str, Any]] | None],\n            prompts_list: list[str | None],\n        ) -> list[list[dict[str, Any]]]:\n            \"\"\"Single forward pass over the Ray-accumulated batch.\n\n            Before dispatch, re-sizes Ray's batcher via set_max_batch_size\n            using batch_size_fn on the observed seq length — so the next\n            accumulation picks the largest precompiled size that fits.\n\n            Supports heterogeneous request parameters by passing per-text\n            thresholds, classification types, labels, examples, and prompts\n            through to the shared pipeline.\n            \"\"\"\n            # Dynamically adjust batch size based on observed sequence length\n            next_max_batch = self.server.batch_size_fn(\n                seq_len=self.server.observed_seq_len(\n                    texts,\n                    labels=labels_list,\n                )\n            )\n            self._infer_batch.set_max_batch_size(next_max_batch)\n\n            # Process entire batch at once\n            results = self.server._run_batch_internal(\n                texts,\n                labels_list,\n                threshold=thresholds,\n                multi_label=multi_label_list,\n                examples=examples_list,\n                prompt=prompts_list,\n            )\n\n            return results\n\n        async def predict(\n            self,\n            text: str,\n            labels: list[str],\n            threshold: float | None = None,\n            multi_label: bool = True,\n            examples: list[dict[str, Any]] | None = None,\n            prompt: str | None = None,\n        ) -> list[dict[str, Any]]:\n            \"\"\"Single prediction endpoint - one text per request.\"\"\"\n            if threshold is None:\n                threshold = self.server.config.default_threshold\n\n            # Call batched method - Ray will accumulate these\n            results = await self._infer_batch(\n   
             text,\n                labels,\n                threshold,\n                multi_label,\n                examples,\n                prompt,\n            )\n            return results\n\n        async def __call__(self, request) -> list[dict[str, Any]]:\n            \"\"\"HTTP endpoint - accepts single text per request.\"\"\"\n            payload = await request.json()\n            text = payload.get(\"text\") or payload.get(\"texts\")\n            if isinstance(text, list):\n                # If list provided, take first element for compatibility\n                text = text[0] if text else \"\"\n            return await self.predict(\n                text=text,\n                labels=payload[\"labels\"],\n                threshold=payload.get(\"threshold\"),\n                multi_label=payload.get(\"multi_label\", True),\n                examples=payload.get(\"examples\"),\n                prompt=payload.get(\"prompt\"),\n            )\n\n    return GLiClassDeployment.bind(config)\n\n\ndef serve_gliclass(\n    config: GLiClassServeConfig,\n    blocking: bool = False,\n) -> Any:\n    import ray\n\n    if not ray.is_initialized():\n        ray.init(address=config.ray_address, ignore_reinit_error=True)\n\n    serve.start(detached=True, http_options={\"port\": config.http_port})\n\n    app = _build_deployment(config)\n    handle = serve.run(app, name=\"gliclass\", route_prefix=config.route_prefix)\n\n    logger.info(\"GLiClass server running at http://localhost:%d%s\", config.http_port, config.route_prefix)\n\n    if blocking:\n        import time\n        import signal\n\n        shutdown_event = False\n\n        def handle_signal(_signum, _frame):\n            nonlocal shutdown_event\n            shutdown_event = True\n\n        signal.signal(signal.SIGINT, handle_signal)\n        signal.signal(signal.SIGTERM, handle_signal)\n\n        while not shutdown_event:\n            time.sleep(1)\n\n        serve.shutdown()\n\n    return handle\n\n\ndef 
shutdown() -> None:\n    serve.shutdown()\n\n\nclass GLiClassFactory:\n    \"\"\"Synchronous facade: config → deploy → predict → shutdown in one object.\n\n    Pass list of texts to preserve dynamic batching - Ray Serve accumulates\n    concurrent requests into single forward pass.\n\n    Example:\n        >>> from serve import GLiClassFactory\n        >>> llm = GLiClassFactory(model=\"knowledgator/gliclass-edge-v3.0\")\n        >>> outputs = llm.predict(\n        ...     [\"Great product!\", \"Terrible service\"],\n        ...     labels=[\"positive\", \"negative\", \"neutral\"],\n        ... )\n        >>> llm.shutdown()\n\n        Or as context manager:\n        >>> with GLiClassFactory(model=\"knowledgator/gliclass-edge-v3.0\") as llm:\n        ...     out = llm.predict(\"Great product!\", [\"positive\", \"negative\"])\n    \"\"\"\n\n    def __init__(\n        self,\n        model: str | None = None,\n        *,\n        config: GLiClassServeConfig | None = None,\n        **kwargs,\n    ):\n        \"\"\"Pass either `config` or `model`/kwargs, not both.\"\"\"\n        if config is not None:\n            if model is not None or kwargs:\n                raise ValueError(\"Pass either `config` or `model`/kwargs, not both.\")\n        else:\n            if model is None:\n                raise ValueError(\"Must provide either `model` or `config`.\")\n            config = GLiClassServeConfig(model=model, **kwargs)\n\n        self.config = config\n        self._handle = serve_gliclass(config, blocking=False)\n        self._closed = False\n\n    @property\n    def handle(self):\n        \"\"\"Underlying Ray Serve deployment handle for async/advanced use.\"\"\"\n        return self._handle\n\n    def predict(\n        self,\n        texts: str | list[str],\n        labels: list[str],\n        threshold: float | None = None,\n        multi_label: bool = True,\n        examples: list[dict[str, Any]] | None = None,\n        prompt: str | list[str] | None = None,\n    ) -> 
dict[str, Any] | list[dict[str, Any]]:\n        \"\"\"Blocking prediction. Returns dict for str input, list for list input.\"\"\"\n        single = isinstance(texts, str)\n        items = [texts] if single else list(texts)\n\n        refs = [\n            self._handle.predict.remote(\n                t,\n                labels,\n                threshold,\n                multi_label,\n                examples,\n                prompt,\n            )\n            for t in items\n        ]\n        results = [ref.result() for ref in refs]\n        return results[0] if single else results\n\n    async def predict_async(\n        self,\n        texts: str | list[str],\n        labels: list[str],\n        threshold: float | None = None,\n        multi_label: bool = True,\n        examples: list[dict[str, Any]] | None = None,\n        prompt: str | list[str] | None = None,\n    ) -> dict[str, Any] | list[dict[str, Any]]:\n        \"\"\"Async prediction. Concurrent calls accumulate into one batch.\"\"\"\n        import asyncio\n\n        single = isinstance(texts, str)\n        items = [texts] if single else list(texts)\n\n        refs = [\n            self._handle.predict.remote(\n                t,\n                labels,\n                threshold,\n                multi_label,\n                examples,\n                prompt,\n            )\n            for t in items\n        ]\n        results = list(await asyncio.gather(*refs))\n        return results[0] if single else results\n\n    def shutdown(self) -> None:\n        \"\"\"Tear down Ray Serve deployment and Ray runtime.\n\n        Idempotent. 
Shutting down Ray after Serve avoids leaving driver\n        attached to detached Serve instance.\n        \"\"\"\n        if self._closed:\n            return\n        import ray\n\n        serve.shutdown()\n        if ray.is_initialized():\n            ray.shutdown()\n        self._closed = True\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        self.shutdown()\n        return False\n\n    def __del__(self):\n        try:\n            self.shutdown()\n        except Exception:\n            pass\n"
  },
  {
    "path": "gliclass/training.py",
    "content": "import os\nfrom typing import Any, Dict, List, Tuple, Callable\nfrom dataclasses import field, dataclass\n\nimport numpy as np\nimport torch\nimport transformers\nimport torch.nn.functional as F\nfrom tqdm import tqdm\nfrom torch import nn\nfrom transformers import ZeroShotClassificationPipeline as TransformersClassificationPipeline\nfrom torch.utils.data import Dataset, DataLoader\nfrom transformers.trainer import (\n    get_parameter_names,\n    is_sagemaker_mp_enabled,\n)\n\nfrom .utils import default_f1_reward, is_module_available\nfrom .pipeline import ZeroShotClassificationPipeline\n\nif is_module_available(\"apex\"):\n    from apex import amp\n\n    _has_apex = True\nelse:\n    _has_apex = False\n    amp = None\n\nALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm]\n\n\nclass EWC:\n    \"\"\"Elastic Weight Consolidation for preventing catastrophic forgetting in GLiClass models.\"\"\"\n\n    def __init__(\n        self,\n        model: nn.Module,\n        dataset: Dataset,\n        data_collator: Any | None = None,\n        device: str = \"cpu\",\n        ewc_lambda: float = 100.0,\n        batch_size: int = 8,\n        num_samples: int | None = None,\n        fisher_estimation_method: str = \"empirical\",\n        normalize_fisher: bool = True,\n    ):\n        \"\"\"Initialize EWC.\n\n        Args:\n            model: The GLiClass model to apply EWC to\n            dataset: Dataset from previous task to compute Fisher information\n            data_collator: Data collator for batching (required for GLiClass)\n            device: Device to use for computation\n            ewc_lambda: Importance weight for EWC penalty (higher = more regularization)\n            batch_size: Batch size for Fisher computation\n            num_samples: Number of samples to use for Fisher estimation (None = use all)\n            fisher_estimation_method: Method for Fisher estimation ('empirical' or 'diagonal')\n            normalize_fisher: Whether to normalize Fisher 
information values\n        \"\"\"\n        self.model = model\n        self.device = device\n        self.ewc_lambda = ewc_lambda\n        self.batch_size = batch_size\n        self.num_samples = num_samples\n        self.fisher_estimation_method = fisher_estimation_method\n        self.normalize_fisher = normalize_fisher\n        self.data_collator = data_collator\n\n        # Store old parameters (deep copy to avoid reference issues)\n        self.old_params: Dict[str, torch.Tensor] = {}\n        for name, param in model.named_parameters():\n            if param.requires_grad:\n                self.old_params[name] = param.data.clone().detach()\n\n        # Compute Fisher information matrix\n        self.fisher_info: Dict[str, torch.Tensor] = self._compute_fisher(dataset)\n\n        # Optionally normalize Fisher information\n        if self.normalize_fisher:\n            self._normalize_fisher()\n\n    def _compute_fisher(self, dataset: Dataset) -> Dict[str, torch.Tensor]:\n        \"\"\"Compute diagonal Fisher information matrix.\n\n        The Fisher information measures how sensitive the loss is to changes\n        in each parameter. 
Parameters with high Fisher information are important\n        for the previous task.\n\n        Args:\n            dataset: Dataset to compute Fisher information from\n\n        Returns:\n            Dictionary mapping parameter names to Fisher information tensors\n        \"\"\"\n        # Initialize Fisher information to zeros\n        fisher: Dict[str, torch.Tensor] = {}\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                fisher[name] = torch.zeros_like(param, device=self.device)\n\n        # Set model to evaluation mode for consistent behavior\n        was_training = self.model.training\n        self.model.eval()\n\n        # Create dataloader\n        if self.num_samples is not None and self.num_samples < len(dataset):\n            # Subsample dataset for efficiency\n            indices = torch.randperm(len(dataset))[: self.num_samples].tolist()\n            subset = torch.utils.data.Subset(dataset, indices)\n            loader = DataLoader(subset, batch_size=self.batch_size, shuffle=False, collate_fn=self.data_collator)\n        else:\n            loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, collate_fn=self.data_collator)\n\n        num_batches = len(loader)\n\n        print(f\"Computing Fisher information from {len(loader.dataset)} samples...\")\n\n        # Compute Fisher information\n        for batch in tqdm(loader, desc=\"Computing Fisher\"):\n            self.model.zero_grad()\n\n            # Prepare inputs - handle GLiClass specific fields\n            if isinstance(batch, dict):\n                # Remove non-tensor fields that GLiClass might have\n                inputs = {k: v for k, v in batch.items() if k not in [\"labels_text\", \"input_texts\"]}\n\n                # Move tensors to device\n                for k, v in inputs.items():\n                    if isinstance(v, torch.Tensor):\n                        inputs[k] = v.to(self.device)\n            
else:\n                inputs = batch\n\n            try:\n                # Forward pass\n                outputs = self.model(**inputs)\n\n                if self.fisher_estimation_method == \"empirical\":\n                    # Use the actual loss for empirical Fisher\n                    if hasattr(outputs, \"loss\") and outputs.loss is not None:\n                        loss = outputs.loss\n                    else:\n                        # Compute loss manually if not provided\n                        logits = outputs.logits\n                        labels = inputs.get(\"labels\")\n                        if labels is not None:\n                            # Handle multi-label classification\n                            if self.model.config.problem_type == \"multi_label_classification\":\n                                loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).float())\n                            else:\n                                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))\n                        else:\n                            continue\n                else:\n                    # Diagonal Fisher: sample from model's predictive distribution\n                    logits = outputs.logits\n\n                    if self.model.config.problem_type == \"multi_label_classification\":\n                        probs = torch.sigmoid(logits)\n                        # Sample binary labels\n                        sampled_labels = torch.bernoulli(probs)\n                        loss = F.binary_cross_entropy_with_logits(logits.view(-1), sampled_labels.view(-1))\n                    else:\n                        probs = F.softmax(logits, dim=-1)\n                        log_probs = F.log_softmax(logits, dim=-1)\n                        sampled_labels = torch.multinomial(probs.view(-1, probs.size(-1)), 1).squeeze(-1)\n                        loss = F.nll_loss(log_probs.view(-1, log_probs.size(-1)), 
sampled_labels)\n\n                # Backward pass to compute gradients\n                loss.backward()\n\n                # Accumulate squared gradients (Fisher information)\n                for name, param in self.model.named_parameters():\n                    if param.requires_grad and param.grad is not None:\n                        fisher[name] += param.grad.data**2 / num_batches\n\n            except Exception as e:\n                print(f\"Warning: Error computing Fisher for batch: {e}\")\n                continue\n\n        # Restore model training mode\n        if was_training:\n            self.model.train()\n\n        return fisher\n\n    def _normalize_fisher(self):\n        \"\"\"Normalize Fisher information values to prevent numerical issues.\"\"\"\n        # Compute max Fisher value across all parameters\n        max_fisher = 0.0\n        for _, fisher_val in self.fisher_info.items():\n            max_fisher = max(max_fisher, fisher_val.max().item())\n\n        if max_fisher > 0:\n            # Normalize by max value\n            for name in self.fisher_info:\n                self.fisher_info[name] = self.fisher_info[name] / max_fisher\n\n    def ewc_loss(self, batch_size: int | None = None) -> torch.Tensor:\n        \"\"\"Compute EWC penalty loss.\n\n        The EWC loss penalizes changes to parameters that were important\n        for the previous task (as measured by Fisher information).\n\n        Args:\n            batch_size: Batch size for normalization (optional)\n\n        Returns:\n            EWC penalty loss tensor\n        \"\"\"\n        loss = torch.tensor(0.0, device=self.device)\n\n        for name, param in self.model.named_parameters():\n            if param.requires_grad and name in self.fisher_info:\n                # EWC penalty: F_i * (theta_i - theta_i^*)^2\n                param_diff = param - self.old_params[name].to(param.device)\n                fisher = self.fisher_info[name].to(param.device)\n                loss += 
(fisher * param_diff**2).sum()\n\n        # Optionally normalize by batch size\n        if batch_size is not None:\n            loss = loss / batch_size\n\n        return self.ewc_lambda * loss\n\n    def get_importance_scores(self) -> Dict[str, float]:\n        \"\"\"Get importance scores for each parameter group.\n\n        Returns:\n            Dictionary mapping parameter names to average importance scores\n        \"\"\"\n        scores = {}\n        for name, fisher in self.fisher_info.items():\n            scores[name] = fisher.mean().item()\n        return scores\n\n    def update_lambda(self, new_lambda: float):\n        \"\"\"Update the EWC lambda value.\n\n        Args:\n            new_lambda: New lambda value for EWC penalty\n        \"\"\"\n        self.ewc_lambda = new_lambda\n\n    def consolidate(self, dataset: Dataset, alpha: float = 0.5):\n        \"\"\"Consolidate knowledge by updating Fisher information with new task.\n\n        This allows for online EWC where multiple tasks are consolidated.\n\n        Args:\n            dataset: Dataset from new task\n            alpha: Mixing coefficient (0 = keep old Fisher, 1 = use new Fisher only)\n        \"\"\"\n        # Compute new Fisher information\n        new_fisher = self._compute_fisher(dataset)\n\n        # Mix old and new Fisher information\n        for name in self.fisher_info:\n            if name in new_fisher:\n                self.fisher_info[name] = (1 - alpha) * self.fisher_info[name] + alpha * new_fisher[name]\n\n        # Update old parameters\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                self.old_params[name] = param.data.clone().detach()\n\n        # Re-normalize if needed\n        if self.normalize_fisher:\n            self._normalize_fisher()\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: str | None = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n  
  others_lr: float | None = None\n    others_weight_decay: float | None = 0.0\n\n    use_ewc: bool = field(\n        default=False, metadata={\"help\": \"Whether to use Elastic Weight Consolidation (EWC) for continual learning.\"}\n    )\n    ewc_lambda: float = field(\n        default=100.0, metadata={\"help\": \"Lambda parameter for EWC penalty. Higher values = more regularization.\"}\n    )\n    ewc_fisher_samples: int | None = field(\n        default=None, metadata={\"help\": \"Number of samples to use for Fisher information estimation. None = use all.\"}\n    )\n    ewc_normalize_fisher: bool = field(\n        default=True, metadata={\"help\": \"Whether to normalize Fisher information values.\"}\n    )\n    ewc_gamma: float = field(default=0.95, metadata={\"help\": \"Decay factor for Online EWC.\"})\n\n\nclass Trainer(transformers.Trainer):\n    \"\"\"Extended Trainer with EWC support for continual learning.\"\"\"\n\n    def __init__(self, ewc: EWC | None = None, prev_dataset=None, *args, **kwargs):\n        \"\"\"Initialize Trainer with optional EWC support.\n\n        Args:\n            ewc: Pre-initialized EWC object (optional)\n            prev_dataset: Previous dataset for EWC initialization (optional)\n            *args: Arguments passed to parent Trainer\n            **kwargs: Keyword arguments passed to parent Trainer\n        \"\"\"\n        super().__init__(*args, **kwargs)\n\n        # Ensure use_apex is set for compatibility with different transformers versions\n        if not hasattr(self, \"use_apex\"):\n            self.use_apex = False\n\n        self.ewc = ewc\n        self.prev_dataset = prev_dataset\n        self._ewc_initialized = ewc is not None\n\n    def _maybe_initialize_ewc(self):\n        \"\"\"Initialize EWC if needed and not already initialized.\"\"\"\n        if self._ewc_initialized or not self.args.use_ewc:\n            return\n\n        if self.prev_dataset is None:\n            print(\"Warning: EWC is enabled but no previous 
dataset provided. Skipping EWC initialization.\")\n            return\n\n        print(f\"Initializing EWC with lambda={self.args.ewc_lambda}...\")\n\n        # Get the data collator\n        data_collator = self.data_collator\n\n        # Determine device\n        device = (\n            self.model.device\n            if hasattr(self.model, \"device\")\n            else (next(self.model.parameters()).device if list(self.model.parameters()) else \"cpu\")\n        )\n\n        # Create EWC instance\n        ewc_kwargs = {\n            \"model\": self.model,\n            \"dataset\": self.prev_dataset,\n            \"data_collator\": data_collator,\n            \"device\": str(device),\n            \"ewc_lambda\": self.args.ewc_lambda,\n            \"num_samples\": self.args.ewc_fisher_samples,\n            \"normalize_fisher\": self.args.ewc_normalize_fisher,\n        }\n\n        self.ewc = EWC(**ewc_kwargs)\n        self._ewc_initialized = True\n        print(\"EWC initialization complete.\")\n\n    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n        \"\"\"Compute loss with optional EWC penalty.\n\n        Args:\n            model: The model\n            inputs: Input batch\n            return_outputs: Whether to return model outputs\n            **kwargs: Additional arguments\n\n        Returns:\n            Loss tensor, or tuple of (loss, outputs) if return_outputs=True\n        \"\"\"\n        # Get base loss from parent\n        if return_outputs:\n            loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)\n        else:\n            loss = super().compute_loss(model, inputs, return_outputs=False, **kwargs)\n            outputs = None\n\n        # Add EWC penalty if enabled\n        if self.ewc is not None and self.args.use_ewc:\n            batch_size = inputs.get(\"input_ids\", inputs.get(\"labels\")).shape[0] if isinstance(inputs, dict) else None\n            ewc_loss = 
self.ewc.ewc_loss(batch_size=batch_size)\n            loss = loss + ewc_loss\n\n        if return_outputs:\n            return loss, outputs\n        return loss\n\n    def train(self, *args, **kwargs):\n        \"\"\"Train with EWC initialization.\"\"\"\n        # Initialize EWC before training starts\n        self._maybe_initialize_ewc()\n        return super().train(*args, **kwargs)\n\n    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:\n        \"\"\"\n        Perform a training step on a batch of inputs.\n\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (`nn.Module`):\n                The model to train.\n            inputs (`Dict[str, Union[torch.Tensor, Any]]`):\n                The inputs and targets of the model.\n\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument `labels`. Check your model's documentation for all accepted arguments.\n\n        Return:\n            `torch.Tensor`: The tensor with training loss on this batch.\n        \"\"\"\n        model.train()\n        try:\n            if \"labels_text\" in inputs:\n                inputs.pop(\"labels_text\")\n            if \"input_texts\" in inputs:\n                inputs.pop(\"input_texts\")\n            inputs = self._prepare_inputs(inputs)\n\n            with self.compute_loss_context_manager():\n                loss = self.compute_loss(model, inputs)\n\n            del inputs\n            torch.cuda.empty_cache()\n\n            kwargs = {}\n\n            if self.args.n_gpu > 1:\n                loss = loss.mean()  # mean() to average on multi-gpu parallel training\n\n            if self.use_apex and _has_apex:\n                with amp.scale_loss(loss, self.optimizer) as scaled_loss:\n                    scaled_loss.backward()\n            else:\n                self.accelerator.backward(loss, **kwargs)\n\n            return 
loss.detach() / self.args.gradient_accumulation_steps\n        except Exception as e:\n            print(f\"Skipping iteration due to error: {e}\")\n            model.zero_grad(set_to_none=True)\n            torch.cuda.empty_cache()\n            return torch.tensor(0.0, requires_grad=True).to(model.device)\n\n    def prediction_step(\n        self,\n        model: torch.nn.Module,\n        inputs: Dict[str, torch.Tensor | Any],\n        prediction_loss_only: bool,\n        ignore_keys: List[str] | None = None,\n    ) -> Tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:\n        \"\"\"\n        Perform an evaluation step on model using inputs.\n        Subclass and override to inject custom behavior.\n\n        Args:\n            model (nn.Module):\n                The model to evaluate.\n            inputs (Dict[str, Union[torch.Tensor, Any]]):\n                The inputs and targets of the model.\n                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the\n                argument labels. 
Check your model's documentation for all accepted arguments.\n            prediction_loss_only (bool):\n                Whether or not to return the loss only.\n            ignore_keys (List[str], *optional*):\n                A list of keys in the output of your model (if it is a dictionary) that should be ignored when\n                gathering predictions.\n\n        Return:\n            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,\n            logits and labels (each being optional).\n        \"\"\"\n        try:\n            with torch.no_grad():\n                if \"labels_text\" in inputs:\n                    inputs.pop(\"labels_text\")\n                if \"input_texts\" in inputs:\n                    inputs.pop(\"input_texts\")\n                loss = None\n                with self.compute_loss_context_manager():\n                    try:\n                        outputs = model(**inputs)\n                    except Exception as e:\n                        raise RuntimeError(f\"Error during model forward pass: {e!s}\") from e\n\n                if not hasattr(outputs, \"loss\"):\n                    raise AttributeError(\"Model output does not contain 'loss' attribute\")\n                loss = outputs.loss\n\n                if not hasattr(outputs, \"logits\"):\n                    raise AttributeError(\"Model output does not contain 'logits' attribute\")\n                logits = outputs.logits\n\n                if \"labels\" not in inputs:\n                    raise KeyError(\"'labels' not found in input dictionary\")\n                labels = inputs[\"labels\"]\n\n            if prediction_loss_only:\n                return (loss, None, None)\n            return (loss, logits, labels)\n\n        except Exception as e:\n            print(f\"An error occurred during prediction step: {e!s}\")\n            return (None, None, None)\n\n    def create_optimizer(self):\n        \"\"\"\n        
Setup the optimizer.\n\n        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        if is_sagemaker_mp_enabled():\n            return super().create_optimizer()\n\n        opt_model = self.model\n\n        if self.optimizer is None:\n            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            if self.args.others_lr is not None:\n                encoder_parameters = [name for name, _ in opt_model.named_parameters() if \"encoder\" in name]\n                optimizer_grouped_parameters = [\n                    {\n                        \"params\": [\n                            p\n                            for n, p in opt_model.named_parameters()\n                            if (n in decay_parameters and n not in encoder_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": self.args.others_weight_decay,\n                        \"lr\": self.args.others_lr,\n                    },\n                    {\n                        \"params\": [\n                            p\n                            for n, p in opt_model.named_parameters()\n                            if (n not in decay_parameters and n not in encoder_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": 0.0,\n                        \"lr\": self.args.others_lr,\n                    },\n                    {\n                        \"params\": [\n                            p\n                            for n, p in opt_model.named_parameters()\n                            if (n in decay_parameters and n in encoder_parameters and p.requires_grad)\n                        ],\n                    
    \"weight_decay\": self.args.weight_decay,\n                    },\n                    {\n                        \"params\": [\n                            p\n                            for n, p in opt_model.named_parameters()\n                            if (n not in decay_parameters and n in encoder_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": 0.0,\n                    },\n                ]\n            else:\n                optimizer_grouped_parameters = [\n                    {\n                        \"params\": [\n                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": self.args.weight_decay,\n                    },\n                    {\n                        \"params\": [\n                            p\n                            for n, p in opt_model.named_parameters()\n                            if (n not in decay_parameters and p.requires_grad)\n                        ],\n                        \"weight_decay\": 0.0,\n                    },\n                ]\n\n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n\n            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n\n        return self.optimizer\n\n\n@dataclass\nclass RLTrainerConfig(TrainingArguments):\n    cliprange: float = field(\n        default=0.2,\n        metadata={\"help\": \"Clip range.\"},\n    )\n    num_rl_iters: int = field(\n        default=3,\n        metadata={\"help\": \"Number of RL iterations.\"},\n    )\n    gamma: float = field(\n        default=-1,\n        metadata={\"help\": \"Focal loss gamma.\"},\n    )\n    alpha: float = field(\n        default=-1,\n        metadata={\"help\": \"Focal loss alpha.\"},\n    )\n    labels_smoothing: float = field(default=-1, metadata={\"help\": \"Labels 
smoothing factor.\"})\n    entropy_beta: float = field(default=-1, metadata={\"help\": \"Coeficient of entropy factor.\"})\n    kl_beta: float = field(default=-1, metadata={\"help\": \"Coeficient of KL-divergence factor.\"})\n    get_actions: str = field(\n        default=\"bernoulli\",\n        metadata={\"help\": \"How to get actions of a model, default is `bernoulli`, another option is `threshold`\"},\n    )\n    threshold: float = field(\n        default=0.5,\n        metadata={\"help\": \"Threshold value for predictions.\"},\n    )\n\n\nclass RLTrainer(Trainer):\n    def __init__(\n        self,\n        value_model: torch.nn.Module | None = None,\n        reference_model: ZeroShotClassificationPipeline | TransformersClassificationPipeline | None = None,\n        reward_components: List[Tuple[str, Callable]] | None = None,\n        *args,\n        **kwargs,\n    ):\n        super().__init__(*args, **kwargs)\n        if value_model is not None:\n            self.value_model = value_model.to(self.model.device)\n        self.reference_model = reference_model\n        if reward_components is None:\n            reward_components = [(\"f1\", default_f1_reward)]\n        self.reward_components = reward_components\n        self._init_metrics()\n\n    def _init_metrics(self):\n        self.metrics = {\n            \"total_loss\": [],\n            \"advantages\": [],\n        }\n        # Initialize metrics for each reward component\n        for name, _ in self.reward_components.items():\n            self.metrics[f\"reward_{name}\"] = []\n\n    def compute_rewards(\n        self, probs: torch.Tensor, actions: torch.Tensor, original_targets: torch.Tensor, valid_mask: torch.Tensor\n    ) -> Dict[str, torch.Tensor]:\n        rewards = {}\n        total_reward = 0.0\n        for name, reward_fn in self.reward_components.items():\n            component = reward_fn(probs, actions, original_targets, valid_mask)\n            rewards[name] = component\n            total_reward 
+= component\n        rewards[\"total_reward\"] = total_reward\n        return rewards\n\n    def get_reference_scores(self, input_texts, labels_text):\n        if input_texts is None or labels_text is None:\n            return None\n        all_scores = []\n        with torch.no_grad():\n            if isinstance(self.reference_model, ZeroShotClassificationPipeline):\n                results = self.reference_model(input_texts, labels_text, threshold=0.0)\n                for id, result in enumerate(results):\n                    label2score = {item[\"label\"]: item[\"score\"] for item in result}\n                    label_scores = [label2score[label] for label in labels_text[id]]\n                    all_scores.append(label_scores)\n            elif isinstance(self.reference_model, TransformersClassificationPipeline):\n                for text, labels in zip(input_texts, labels_text):\n                    result = self.reference_model(text, labels)\n                    label2score = dict(zip(result[\"labels\"], result[\"scores\"]))\n                    label_scores = [label2score[label] for label in labels_text[id]]\n                    all_scores.append(label_scores)\n            else:\n                raise NotImplementedError(\"This classification pipelines is not supported as a reference model.\")\n        max_length = max(len(seq) for seq in all_scores)\n        all_scores = torch.FloatTensor([seq + [0] * (max_length - len(seq)) for seq in all_scores]).to(\n            self.model.device\n        )\n        return all_scores\n\n    def compute_loss(\n        self,\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        log_prob_prev: torch.Tensor | None = None,\n        value_outputs: torch.Tensor | None = None,\n        reference_probs: torch.Tensor | None = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        valid_mask = targets != -100\n        original_targets = targets.clone()\n\n        probs = 
torch.sigmoid(inputs)\n\n        if self.args.get_actions == \"bernoulli\":\n            actions = torch.bernoulli(probs).detach()\n        else:\n            actions = (probs > self.args.threshold).float().detach()\n\n        with torch.no_grad():\n            metrics = self.compute_rewards(probs, actions, original_targets, valid_mask)\n\n        reward = metrics[\"total_reward\"]\n\n        if value_outputs is not None:\n            state_values = value_outputs.logits[:, 0].unsqueeze(-1)  # Using first token logits as value prediction\n            value_loss = torch.nn.functional.mse_loss(state_values, reward.detach())\n        else:\n            state_values = reward.mean()\n            value_loss = torch.tensor(0.0).to(inputs.device)\n\n        advantages = (reward - state_values).detach()\n        self.metrics[\"advantages\"].append(advantages.mean().item())\n\n        for name, _ in self.reward_components.items():\n            key = f\"reward_{name}\"\n            self.metrics[key].append(metrics[name].mean().item())\n\n        if self.args.label_smoothing_factor > 0:\n            smoothed_actions = actions * (1 - self.args.label_smoothing_factor) + 0.5 * self.args.label_smoothing_factor\n            log_prob_current = smoothed_actions * torch.log(probs + 1e-8) + (1 - smoothed_actions) * torch.log(\n                1 - probs + 1e-8\n            )\n        else:\n            log_prob_current = actions * torch.log(probs + 1e-8) + (1 - actions) * torch.log(1 - probs + 1e-8)\n\n        if log_prob_prev is None:\n            log_prob_prev = log_prob_current.detach()\n\n        log_probs_diff = log_prob_current - log_prob_prev\n        ratio = torch.exp(log_probs_diff)\n\n        cliprange = self.args.cliprange\n        per_label_loss1 = ratio * advantages\n        per_label_loss2 = torch.clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages\n        loss_elements = -torch.min(per_label_loss1, per_label_loss2)\n\n        loss_elements = loss_elements * 
valid_mask\n        self.metrics[\"total_loss\"].append(loss_elements.mean().item())\n\n        if self.args.gamma > 0:\n            p_t = probs * original_targets + (1 - probs) * (1 - original_targets)\n            loss_elements = loss_elements * (p_t**self.args.gamma)\n\n        if self.args.alpha >= 0:\n            alpha_t = self.args.alpha * original_targets + (1 - self.args.alpha) * (1 - original_targets)\n            loss_elements = alpha_t * loss_elements\n\n        loss = loss_elements.sum() / valid_mask.shape[0] + value_loss\n\n        if reference_probs is not None:\n            ref_per_token_logps = torch.log(reference_probs + 1e-8)\n            per_token_logps = log_prob_current\n            per_label_kl = (\n                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1\n            )\n            per_label_kl = per_label_kl * valid_mask\n            kl_loss = self.args.kl_beta * per_label_kl.mean()\n            loss = loss + kl_loss\n\n        if self.args.entropy_beta:\n            entropy = -(probs * torch.log(probs + 1e-8) + (1 - probs) * torch.log(1 - probs + 1e-8))\n            loss = loss + self.args.entropy_beta * entropy.mean()\n\n        return loss, log_prob_current\n\n    def _inner_training_loop(self, *args, **kwargs):\n        self.create_optimizer()\n        if self.value_model is not None:\n            value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=self.args.learning_rate)\n        args = self.args\n        accelerator = self.accelerator\n        optimizer = self.optimizer\n        model = self.model\n        dataloader = self.get_train_dataloader()\n        device = accelerator.device\n\n        num_local_steps = len(dataloader)\n        num_iters = args.num_train_epochs * num_local_steps\n        pbar = tqdm(total=num_iters, desc=\"Training iterations\")\n        self._init_metrics()\n\n        for epoch in range(args.num_train_epochs):\n            
self._init_metrics()\n            model.train()\n            if self.value_model is not None:\n                self.value_model.train()\n\n            for step, inputs in enumerate(dataloader):\n                global_step = step + epoch * num_local_steps\n\n                inputs = self._prepare_inputs(inputs)\n                labels = inputs.pop(\"labels\").to(device)\n                if \"labels_text\" in inputs:\n                    labels_text = inputs.pop(\"labels_text\")\n                else:\n                    labels_text = None\n                if \"input_texts\" in inputs:\n                    input_texts = inputs.pop(\"input_texts\")\n                else:\n                    input_texts = None\n                prev_logps = None\n                for _iter in range(args.num_rl_iters):\n                    try:\n                        outputs = model(**inputs)\n                        logits = outputs.logits\n                        if self.value_model is not None:\n                            value_outputs = self.value_model(**inputs)\n                        else:\n                            value_outputs = None\n                        if self.reference_model is not None:\n                            reference_probs = self.get_reference_scores(input_texts, labels_text)\n                        else:\n                            reference_probs = None\n                        loss, current_logps = self.compute_loss(\n                            logits,\n                            labels,\n                            log_prob_prev=prev_logps,\n                            value_outputs=value_outputs,\n                            reference_probs=reference_probs,\n                        )\n                    except Exception as e:\n                        print(f\"An error occurred during training step: {e!s}\")\n                        del inputs\n                        model.zero_grad(set_to_none=True)\n                        
torch.cuda.empty_cache()\n                        break\n\n                    accelerator.backward(loss)\n                    if self.args.max_grad_norm is not None:\n                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)\n                        if self.value_model is not None:\n                            torch.nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.max_grad_norm)\n\n                    optimizer.step()\n                    optimizer.zero_grad()\n                    if self.value_model is not None:\n                        value_optimizer.step()\n                        value_optimizer.zero_grad()\n\n                    prev_logps = current_logps.detach()\n\n                if global_step % args.logging_steps == 0:\n                    self.log_metrics()\n\n                if args.save_steps is not None and global_step % args.save_steps == 0:\n                    self._save_checkpoint(model, step=global_step)\n\n                pbar.set_postfix(epoch=epoch, step=step)\n                pbar.update(1)\n\n            if args.evaluation_strategy == \"epoch\":\n                self.evaluate()\n\n    def log_metrics(self):\n        logged_metrics = {\n            \"loss\": np.mean(self.metrics[\"total_loss\"]),\n            \"advantages\": np.mean(self.metrics[\"advantages\"]),\n        }\n        # Add user reward components\n        for name, _ in self.reward_components.items():\n            key = f\"reward_{name}\"\n            logged_metrics[key] = np.mean(self.metrics[key])\n        self.log(logged_metrics)\n        self._init_metrics()\n\n    def _save_checkpoint(self, model, step=None):\n        # Compare with None explicitly: global_step 0 is falsy but is a real training step.\n        checkpoint_dir = f\"checkpoint-{step}\" if step is not None else \"final_model\"\n        output_dir = os.path.join(self.args.output_dir, checkpoint_dir)\n        os.makedirs(output_dir, exist_ok=True)\n        model.save_pretrained(output_dir)\n        if self.tokenizer is not None:\n            
self.tokenizer.save_pretrained(output_dir)\n        print(f\"Checkpoint saved to {output_dir}\")\n"
  },
  {
    "path": "gliclass/utils.py",
    "content": "import torch\n\n\ndef is_module_available(module_name):\n    \"\"\"\n    Checks whether the specified Python module is available.\n\n    Args:\n        module_name (str): The name of the module to check.\n\n    Returns:\n        bool: True if the module is available, False otherwise.\n    \"\"\"\n    try:\n        __import__(module_name)\n        return True\n    except ImportError:\n        return False\n\n\nclass MissedPackageException(Exception):\n    \"\"\"Raised when a required package is not installed.\"\"\"\n\n    pass\n\n\ndef retrieval_augmented_text(text: str, examples: list) -> str:\n    \"\"\"\n    Constructs a new text by appending relevant retrieved examples to the input text.\n\n    Args:\n        text (str): The main input text.\n        examples (list): A list of examples in the format\n                         {\"text\": str, \"true_labels\": List[str], \"all_labels\": List[str]}.\n\n    Returns:\n        str: The modified text with relevant examples appended.\n    \"\"\"\n    if not examples:\n        return text\n\n    retrieved_examples = []\n    all_labels = {label for example in examples for label in example.get(\"true_labels\", [])}\n    relevant_examples = [ex for ex in examples if set(ex.get(\"true_labels\", [])) & all_labels]\n\n    for example in relevant_examples:\n        example_text = example[\"text\"]\n        true_labels = example.get(\"true_labels\", [])\n        all_labels = example.get(\"all_labels\", [])\n\n        false_labels = list(set(all_labels) - set(true_labels))\n\n        true_labels_str = \" \".join([f\"<<TRUE_LABEL>> {label}\" for label in true_labels])\n        false_labels_str = \" \".join([f\"<<FALSE_LABEL>> {label}\" for label in false_labels])\n\n        retrieved_example_str = f\"<<EXAMPLE>> {example_text} {true_labels_str} {false_labels_str} <</EXAMPLE>>\"\n        retrieved_examples.append(retrieved_example_str)\n\n    augmented_text = f\"{text} {' '.join(retrieved_examples)}\" if 
retrieved_examples else text\n\n    return augmented_text\n\n\ndef default_f1_reward(\n    probs: torch.Tensor, actions: torch.Tensor, original_targets: torch.Tensor, valid_mask: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"\n    A variant that extracts list-of-indices sets and then calculates\n    the F1 score in a classical manner. Returns shape (N, 1).\n\n    Args:\n        probs:              (N, T) Tensor of probabilities (not used here but left for interface consistency).\n        actions:            (N, T) Tensor of predicted labels in {0, 1}.\n        original_targets:   (N, T) Tensor of ground-truth labels in {0, 1}.\n        valid_mask:         (N, T) Tensor indicating which positions are valid (1) vs. invalid (0).\n\n    Returns:\n        f1_scores: (N, 1) Tensor containing the F1 score for each row.\n    \"\"\"\n    N = actions.shape[0]\n    f1_scores = []\n\n    for i in range(N):\n        # Filter valid positions\n        valid_preds_i = actions[i] * valid_mask[i]\n        valid_targets_i = original_targets[i] * valid_mask[i]\n\n        # Get the set of indices where we predicted 1\n        predicted_set = set((valid_preds_i == 1).nonzero(as_tuple=True)[0].tolist())\n        # Get the set of indices where the ground truth is 1\n        target_set = set((valid_targets_i == 1).nonzero(as_tuple=True)[0].tolist())\n\n        # Compute intersection\n        intersection = predicted_set.intersection(target_set)\n\n        # Precision\n        if len(predicted_set) > 0:\n            precision = len(intersection) / len(predicted_set)\n        else:\n            precision = 0.0\n\n        # Recall\n        if len(target_set) > 0:\n            recall = len(intersection) / len(target_set)\n        else:\n            recall = 0.0\n\n        # F1 score\n        if (precision + recall) > 0:\n            f1 = 2 * precision * recall / (precision + recall)\n        else:\n            f1 = 0.0\n\n        f1_scores.append(f1)\n\n    # Convert list to tensor shape (N, 
1)\n    f1_scores = torch.tensor(f1_scores, dtype=torch.float).unsqueeze(-1)\n    return f1_scores.detach().to(probs.device)\n"
  },
  {
    "path": "notebooks/finetuning.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"os.environ[\\\"TOKENIZERS_PARALLELISM\\\"] = \\\"true\\\"\\n\",\n    \"\\n\",\n    \"from datasets import load_dataset, Dataset, DatasetDict\\n\",\n    \"\\n\",\n    \"from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score\\n\",\n    \"import numpy as np\\n\",\n    \"import random\\n\",\n    \"\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"from gliclass import GLiClassModel, ZeroShotClassificationPipeline\\n\",\n    \"from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding\\n\",\n    \"from gliclass.training import TrainingArguments, Trainer\\n\",\n    \"\\n\",\n    \"device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_gliclass_predictions(pipeline, test_texts, classes, batch_size=8):\\n\",\n    \"    results = pipeline(test_texts, classes, batch_size=batch_size)#, labels_chunk_size=1)\\n\",\n    \"    predicts = [result[0]['label'] for result in results]\\n\",\n    \"    return predicts\\n\",\n    \"\\n\",\n    \"def evaluate(predicts, true_labels):\\n\",\n    \"    micro = f1_score(true_labels, predicts, average=\\\"micro\\\")\\n\",\n    \"    macro = f1_score(true_labels, predicts, average=\\\"macro\\\")\\n\",\n    \"    weighted = f1_score(true_labels, predicts, average=\\\"weighted\\\")\\n\",\n    \"    return {\\\"micro\\\": micro, \\\"macro\\\": macro, \\\"weighted\\\": weighted}\\n\",\n    \"\\n\",\n    \"def get_train_dataset(dataset, N, label_column='label'):\\n\",\n    \"    ids = []\\n\",\n    \"    label2count = {}\\n\",\n    \"    train_dataset = 
dataset.shuffle(seed=41)\\n\",\n    \"    for id, example in enumerate(train_dataset):\\n\",\n    \"        if example[label_column] not in label2count:\\n\",\n    \"            label2count[example[label_column]]=1\\n\",\n    \"        elif label2count[example[label_column]]>=N:\\n\",\n    \"            continue\\n\",\n    \"        else:\\n\",\n    \"            label2count[example[label_column]]+=1\\n\",\n    \"        ids.append(id)\\n\",\n    \"    return train_dataset.select(ids)\\n\",\n    \"\\n\",\n    \"def prepare_dataset(dataset, classes = None, text_column = 'text', label_column = \\\"label\\\", split=None):\\n\",\n    \"    if 'test' in dataset:\\n\",\n    \"        test_dataset = dataset['test']\\n\",\n    \"    elif isinstance(dataset, Dataset):\\n\",\n    \"        test_dataset = dataset\\n\",\n    \"    else:\\n\",\n    \"        test_dataset = dataset['train']\\n\",\n    \"    \\n\",\n    \"    if classes is None:\\n\",\n    \"        classes = test_dataset.features[label_column].names\\n\",\n    \"        if split is not None:\\n\",\n    \"            classes = [' '.join(class_.split(split)) for class_ in classes]\\n\",\n    \"\\n\",\n    \"    texts = test_dataset[text_column]\\n\",\n    \"\\n\",\n    \"    true_labels = test_dataset[label_column]\\n\",\n    \"\\n\",\n    \"    print(classes)\\n\",\n    \"    if type(test_dataset[label_column][0]) == int:\\n\",\n    \"        true_labels = [classes[label] for label in true_labels]\\n\",\n    \"\\n\",\n    \"    return texts, classes, true_labels\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def prepare_dataset_for_training(train_dataset, classes, text_column='text', label_column='label'):\\n\",\n    \"    id2class = {id: class_ for id, class_ in enumerate(classes)}\\n\",\n    \"    dataset = []\\n\",\n    \"    for example in train_dataset:\\n\",\n    \"        label = example[label_column]\\n\",\n    \"        if type(label)==int:\\n\",\n    \"            label = id2class[label]\\n\",\n    \"        
item = {'text': example[text_column], 'all_labels': classes, 'true_labels': [label]}\\n\",\n    \"        dataset.append(item)\\n\",\n    \"    random.shuffle(dataset)\\n\",\n    \"    return dataset\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"emotions = load_dataset('dair-ai/emotion')\\n\",\n    \"\\n\",\n    \"train_data = get_train_dataset(emotions['train'], N=64)\\n\",\n    \"\\n\",\n    \"test_texts, classes, true_labels = prepare_dataset(emotions)\\n\",\n    \"\\n\",\n    \"train_data = prepare_dataset_for_training(train_data, classes)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ag_news = load_dataset('ag_news')\\n\",\n    \"\\n\",\n    \"train_data = get_train_dataset(ag_news['train'], N=64)\\n\",\n    \"\\n\",\n    \"test_texts, classes, true_labels = prepare_dataset(ag_news)\\n\",\n    \"\\n\",\n    \"train_data = prepare_dataset_for_training(train_data, classes)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"sst5 = load_dataset('SetFit/sst5')\\n\",\n    \"\\n\",\n    \"train_data = get_train_dataset(sst5['train'], N=64)\\n\",\n    \"\\n\",\n    \"classes = ['very negative', 'negative', 'neutral', 'positive', 'very positive']\\n\",\n    \"\\n\",\n    \"test_texts, classes, true_labels = prepare_dataset(sst5, classes=classes)\\n\",\n    \"\\n\",\n    \"train_data = prepare_dataset_for_training(train_data, classes)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"banking = load_dataset('PolyAI/banking77')\\n\",\n    \"\\n\",\n    \"train_data = get_train_dataset(banking['train'], N=32)\\n\",\n    \"\\n\",\n    \"test_texts, classes, true_labels = 
prepare_dataset(banking)\\n\",\n    \"\\n\",\n    \"train_data = prepare_dataset_for_training(train_data, classes)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"massive = load_dataset(\\\"AmazonScience/massive\\\", \\\"en-US\\\")\\n\",\n    \"\\n\",\n    \"train_data = get_train_dataset(massive['train'], N=32, label_column='intent')\\n\",\n    \"\\n\",\n    \"test_texts, classes, true_labels = prepare_dataset(massive, text_column='utt', label_column='intent')\\n\",\n    \"\\n\",\n    \"train_data = prepare_dataset_for_training(train_data, classes,  text_column='utt', label_column='intent')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_name = 'knowledgator/gliclass-base-v1.0'\\n\",\n    \"\\n\",\n    \"model = GLiClassModel.from_pretrained(model_name).to(device)\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"max_length = 1024\\n\",\n    \"problem_type = \\\"multi_label_classification\\\"\\n\",\n    \"architecture_type = model.config.architecture_type\\n\",\n    \"prompt_first = model.config.prompt_first\\n\",\n    \"\\n\",\n    \"train_dataset = GLiClassDataset(train_data, tokenizer, max_length, problem_type, architecture_type, prompt_first)\\n\",\n    \"test_dataset = GLiClassDataset(train_data[:int(len(train_data)*0.1)], tokenizer, max_length, problem_type, architecture_type, prompt_first)\\n\",\n    \"\\n\",\n    \"data_collator = DataCollatorWithPadding(device=device)\\n\",\n    \"\\n\",\n    \"training_args = TrainingArguments(\\n\",\n    \"    output_dir='models/test',\\n\",\n    \"    learning_rate=1e-5,\\n\",\n    \"    weight_decay=0.01,\\n\",\n    \"    others_lr=1e-5,\\n\",\n    \"    
others_weight_decay=0.01,\\n\",\n    \"    lr_scheduler_type='linear',\\n\",\n    \"    warmup_ratio=0.0,\\n\",\n    \"    per_device_train_batch_size=8,\\n\",\n    \"    per_device_eval_batch_size=8,\\n\",\n    \"    num_train_epochs=8,\\n\",\n    \"    evaluation_strategy=\\\"epoch\\\",\\n\",\n    \"    save_steps = 1000,\\n\",\n    \"    save_total_limit=10,\\n\",\n    \"    dataloader_num_workers=8,\\n\",\n    \"    logging_steps=10,\\n\",\n    \"    use_cpu = False,\\n\",\n    \"    report_to=\\\"none\\\",\\n\",\n    \"    fp16=False,\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"trainer = Trainer(\\n\",\n    \"    model=model,\\n\",\n    \"    args=training_args,\\n\",\n    \"    train_dataset=train_dataset,\\n\",\n    \"    eval_dataset=test_dataset,\\n\",\n    \"    tokenizer=tokenizer,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \")\\n\",\n    \"trainer.train()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='single-label', device='cuda:0')\\n\",\n    \"\\n\",\n    \"predicts = get_gliclass_predictions(pipeline, test_texts, classes, batch_size=8)\\n\",\n    \"\\n\",\n    \"results = evaluate(predicts, true_labels)\\n\",\n    \"print(results)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.10\"\n  },\n  \"orig_nbformat\": 4\n },\n \"nbformat\": 4,\n 
\"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.setuptools.packages.find]\ninclude = [\"gliclass\", \"gliclass.*\"]\n\n[tool.setuptools.dynamic]\nversion = {attr = \"gliclass.__version__\"}\n\n[project]\nname = \"gliclass\"\ndescription = \"Generalist and Lightweight Model for Text Classification\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = {text = \"Apache-2.0\"}\nkeywords = [\n    \"text-classification\",\n    \"zero-shot-classification\",\n    \"natural-language-processing\",\n    \"nlp\",\n    \"transformers\",\n    \"machine-learning\"\n]\nauthors = [\n    {name = \"knowledgator.com\"}\n]\n\ndependencies = [\n    \"torch>=2.0.0\",\n    \"transformers>=5.0.0\",\n    \"scikit-learn>=1.0.0\",\n    \"numpy>=2.0.0\",\n]\n\ndynamic = [\"version\"]\n\n[project.optional-dependencies]\nserve = [\n    \"ray[serve]>=2.0.0\",\n    \"requests>=2.31.0\",\n    \"pyyaml>=6.0.0\",\n]\n\n[dependency-groups]\n# version pins are in uv.lock\ndev = [\n    \"pytest\",\n    \"pytest-asyncio\",\n    \"ruff\"\n]\n\n[tool.pytest.ini_options]\npythonpath = [\".\"]\ntestpaths = [\"tests\"]\n\n[tool.ruff]\nline-length = 120\nindent-width = 4\noutput-format = \"grouped\"\ntarget-version = \"py310\"\n\nexclude = [\n    \".git\",\n    \".ruff_cache\",\n    \".venv\",\n    \"__pycache__\",\n    \"build\",\n    \"dist\",\n    \"*.egg-info\",\n]\n\n[tool.ruff.format]\ndocstring-code-format = true\nquote-style = \"double\"\nindent-style = \"space\"\nskip-magic-trailing-comma = false\nline-ending = \"auto\"\n\n[tool.ruff.lint]\nselect = [\n    \"F\",      # Pyflakes\n    \"E\",      # pycodestyle errors\n    \"W\",      # pycodestyle warnings\n    \"I\",      # isort\n    \"D\",      # pydocstyle\n    \"UP\",     # pyupgrade\n    \"B\",      # bugbear\n    \"SIM\",    # simplify\n    \"ARG\",    # unused arguments\n    \"T20\",    # print statements\n    \"C4\",     # comprehensions\n    \"EM\",     # 
errmsg\n    \"PL\",     # Pylint\n    \"RUF\"     # Ruff-specific rules\n]\n\nignore = [\n    \"D100\",   # Missing docstring in public module\n    \"D101\",   # Missing docstring in public class\n    \"D102\",   # Missing docstring in public method\n    \"D103\",   # Missing docstring in public function\n    \"D104\",   # Missing docstring in public package\n    \"D105\",   # Missing docstring in magic method\n    \"D107\",   # Missing docstring in `__init__`\n    \"D200\",   # One-line docstring should fit on one line\n    \"D205\",   # Blank line required between summary and description\n    \"D212\",   # Multi-line docstring summary should start at the first line\n    \"D400\",   # First line should end with a period\n    \"D401\",   # First line should be in imperative mood\n    \"D417\",   # Missing argument descriptions\n    \"RUF012\", # Mutable class attributes should be annotated\n    \"PLR0913\",# Too many arguments\n    \"PLR0912\",# Too many branches\n    \"PLR0915\",# Too many statements\n    \"PLR2004\",# Magic value used in comparison\n    \"PLW2901\",# Loop variable overwritten\n    \"B006\",   # Mutable defaults\n    \"S101\",   # Use of `assert` detected\n    \"SIM105\", # Use contextlib.suppress instead of try-except-pass\n    \"SIM108\", # Use ternary operator\n    \"UP035\",  # Deprecated typing imports\n    \"UP006\",  # Deprecated typing imports\n    \"EM101\",  # Exception message formatting\n    \"EM102\",  # Exception message formatting\n    \"ARG002\", # Unused method arguments\n    \"B905\",   # zip() without explicit strict= parameter\n    \"E402\",   # Module level import not at top (for conditional imports)\n    \"PLC0206\",# Extracting value from dictionary without .items()\n    \"PLC0415\",# Import should be at top-level (for dynamic imports)\n    \"T201\",   # Print statements (used for debugging)\n]\n\nunfixable = [\n    \"F401\",   # Unused imports\n    \"T201\",   # Print statements\n    \"T203\",   # pprint statements\n    
\"F841\",   # Unused variables\n]\n\ndummy-variable-rgx = \"^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$\"\n\n[tool.ruff.lint.per-file-ignores]\n\"__init__.py\" = [\n    \"F401\",  # Unused imports in __init__.py are often intentional\n    \"F403\",  # Star imports in __init__.py\n]\n\"tests/**/*.py\" = [\n    \"D\",      # No docstring requirements in tests\n    \"ARG\",    # Unused arguments are common in test fixtures\n    \"PLR2004\",# Magic values are fine in tests\n    \"S101\",   # Assert is expected in tests\n]\n\n[tool.ruff.lint.isort]\nlength-sort = true\nlength-sort-straight = true\ncombine-as-imports = true\nknown-first-party = [\"gliclass\"]\n\n[tool.ruff.lint.pydocstyle]\nconvention = \"google\"\n\n[tool.ruff.lint.pylint]\nmax-args = 20\nmax-branches = 15\nmax-returns = 8\nmax-statements = 60\n\n[project.urls]\nHomepage = \"https://github.com/knowledgator/gliclass\""
  },
  {
    "path": "serve_configs/serve_config.yaml",
    "content": "# Model configuration\nmodel: knowledgator/gliclass-edge-v3.0\ndevice: cuda\ndtype: float16\n\n# Limits\nmax_model_len: 2048\nmax_labels: -1  # -1 for unlimited\nmax_labels_alloc: dynamic  # Memory allocation strategy: \"dynamic\", \"fixed\", or integer (e.g., 50)\n\n# Thresholds\ndefault_threshold: 0.5\n\n# Ray Serve deployment\nnum_replicas: 1\nnum_gpus_per_replica: 1.0\nnum_cpus_per_replica: 1.0\n\n# Batching configuration\nmax_batch_size: 128\nbatch_wait_timeout_ms: 50.0\nrequest_timeout_s: 30.0\nmax_ongoing_requests: 512\nqueue_capacity: 8192\n\n# HTTP configuration\nroute_prefix: /gliclass\nhttp_port: 8000\n\n# Threading\ntokenizer_threads: 4\n\n# Optimization features\nenable_compilation: true\ncalibrate_on_startup: false\nprecompile_on_startup: false\nuse_memory_aware_batching: false\n\n# Precompiled batch sizes (for compilation warmup)\nprecompiled_batch_sizes:\n  - 1\n  - 2\n  - 4\n  - 8\n  - 16\n  - 32\n  - 64\n  - 128\n\n# Memory management\ntarget_memory_fraction: 0.8\nmemory_overhead_factor: 1.3\n\n# Calibration settings\ncalibration_min_seq_len: 64\ncalibration_min_batch_size: 1\ncalibration_max_batch_size: 64\ncalibration_probe_batch_size: 2\n\n# Warmup\nwarmup_iterations: 3\n\n# Ray cluster (optional)\nray_address: null  # null for local, or ray://<head_node>:10001 for cluster\n"
  },
  {
    "path": "test_gliclass.py",
    "content": "from gliclass import GLiClassModel, ZeroShotClassificationPipeline\nfrom transformers import AutoTokenizer\nfrom datasets import load_dataset, Dataset\nfrom datasets import ClassLabel\nfrom sklearn.metrics import f1_score\nimport numpy as np\nfrom transformers import AutoTokenizer\nimport torch\nimport argparse\n\ndevice = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n\nclass TestModel:\n\n    def __init__(self, model, token):\n        self.model_name = model\n        self.model = None\n        self.tokenizer = None\n        self.token=token\n        self.datasets = [\"SetFit/CR\", \"SetFit/sst2\", \"SetFit/sst5\", 'stanfordnlp/imdb',\n                         \"SetFit/20_newsgroups\", \"SetFit/enron_spam\", \"AmazonScience/massive\",\n                         'PolyAI/banking77', 'takala/financial_phrasebank','ag_news', 'dair-ai/emotion',\n                         \"MoritzLaurer/cap_sotu\", 'cornell-movie-review-data/rotten_tomatoes']\n        self.pipeline = None\n\n        self.macro_scores = []\n    def load_model(self):\n        self.model = GLiClassModel.from_pretrained(self.model_name, token=self.token).to(dtype=torch.float16)\n        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.token, add_prefix_space=True)\n        self.pipeline = ZeroShotClassificationPipeline(self.model, self.tokenizer, classification_type='single-label',\n                                                       device='cuda:0')\n\n    def prepare_dataset(self, dataset, classes=None, text_column='text', label_column=\"label_text\", split=None):\n\n        if 'test' in dataset:\n            test_dataset = dataset['test']\n        elif isinstance(dataset, Dataset):\n            test_dataset = dataset\n        else:\n            test_dataset = dataset['train']\n        if classes is None:\n            classes = test_dataset[label_column]\n            classes = list(set(classes))\n            if split is not None:\n  
              classes = [' '.join(class_.split(split)) for class_ in classes]\n        texts = test_dataset[text_column]\n        true_labels = test_dataset[label_column]\n        print(true_labels[:5])\n        print(classes)\n        if type(test_dataset[label_column][0]) == int:\n            true_labels = [classes[label] for label in true_labels]\n        return texts, classes, true_labels\n\n    def prepare_nomapping(self, dataset, classes=None, text_column='text', label_column='label_text', split=None):\n        if 'test' in dataset:\n            test_dataset = dataset['test']\n        elif isinstance(dataset, Dataset):\n            test_dataset = dataset\n        else:\n            test_dataset = dataset['train']\n        if classes is None:\n            if isinstance(test_dataset.features[label_column], ClassLabel):\n                classes = test_dataset.features[label_column].names\n            else:\n                classes = test_dataset[label_column]\n                classes = list(set(classes))\n                if split is not None:\n                    classes = [' '.join(class_.split(split)) for class_ in classes]\n        texts = test_dataset[text_column]\n        true_labels = test_dataset[label_column]\n        # if isinstance(test_dataset.features[label_column], ClassLabel):\n        #     true_labels = [test_dataset.features[label_column].int2str(label) for label in true_labels]\n        if type(true_labels[0]) == int:\n            true_labels = [classes[label] for label in true_labels]\n\n        return texts, classes, true_labels\n\n    def get_gliclass_predictions(self, test_texts, classes, batch_size=8):\n        results = self.pipeline(test_texts, classes, batch_size=batch_size)\n        predicts = [result[0]['label'] for result in results]\n        return predicts\n\n    def evaluate(self, predicts, true_labels):\n        micro = f1_score(true_labels, predicts, average=\"micro\")\n        macro = f1_score(true_labels, predicts, 
average=\"macro\")\n        weighted = f1_score(true_labels, predicts, average=\"weighted\")\n        return {\"micro\": micro, \"macro\": macro, \"weighted\": weighted}\n\n    def process(self):\n        self.load_model()\n        for dataset in self.datasets:\n            classes = None\n            print(dataset)\n            if dataset == 'SetFit/sst5':\n                classes = ['very negative', 'negative', 'neutral', 'positive', 'very positive']\n                ds = load_dataset(dataset, trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, classes=classes)\n            elif dataset == 'PolyAI/banking77':\n                ds = load_dataset(dataset, trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')\n            elif dataset == 'takala/financial_phrasebank':\n                ds = load_dataset('takala/financial_phrasebank', 'sentences_allagree', trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='sentence',\n                                                                          label_column=\"label\")\n            elif dataset == \"AmazonScience/massive\":\n                ds = load_dataset(dataset,\"en-US\")\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='utt',\n                                                                          label_column=\"intent\")\n            elif dataset == 'stanfordnlp/imdb':\n                ds = load_dataset(dataset, trust_remote_code=True)\n                classes = ['negative', 'positive']\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, classes=classes, text_column='text', label_column='label')\n                print(true_labels[0], classes)\n            elif dataset == 'ag_news':\n                ds = load_dataset(dataset, 
trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')\n            elif dataset == 'dair-ai/emotion':\n                ds = load_dataset(dataset, trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')\n            elif dataset == 'MoritzLaurer/cap_sotu':\n                ds = load_dataset(dataset, trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='labels')\n            elif dataset == 'cornell-movie-review-data/rotten_tomatoes':\n                ds = load_dataset(dataset, trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')\n            elif dataset == 'massive':\n                ds = load_dataset(\"AmazonScience/massive\", \"en-US\", trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='utt', label_column='intent')\n            else:\n                ds = load_dataset(dataset, trust_remote_code=True)\n                test_texts, classes, true_labels = self.prepare_nomapping(ds)\n            predicts = self.get_gliclass_predictions(test_texts, classes, batch_size=8)\n            results = self.evaluate(predicts, true_labels)\n            self.macro_scores.append(results['macro'])\n            print(results)\n        print('Average Score:', np.mean(self.macro_scores))\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run TestModel with arguments\")\n    parser.add_argument(\"--model\", type=str, required=True, help=\"Model name to use\")\n    parser.add_argument(\"--api_key\", type=str, required=False, default = None, help=\"API key for authentication\")\n\n    args = parser.parse_args()\n\n    gliclasstest = 
TestModel(args.model, args.api_key)\n    gliclasstest.process()"
  },
  {
    "path": "tests/test_data_processing.py",
    "content": "\"\"\"Tests for gliclass.data_processing module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.data_processing import pad_2d_tensor\n\n\nclass TestPad2DTensor:\n    \"\"\"Test suite for pad_2d_tensor function.\"\"\"\n\n    @pytest.fixture\n    def sample_tensors(self):\n        \"\"\"Fixture providing sample 2D tensors of varying sizes.\"\"\"\n        return [\n            torch.tensor([[1, 2], [3, 4]]),  # 2x2\n            torch.tensor([[5, 6, 7]]),  # 1x3\n            torch.tensor([[8], [9], [10]]),  # 3x1\n        ]\n\n    def test_pads_to_maximum_dimensions(self, sample_tensors):\n        \"\"\"Should pad all tensors to match the maximum rows and columns.\"\"\"\n        result = pad_2d_tensor(sample_tensors)\n\n        # batch_size=3, max_rows=3, max_cols=3\n        assert result.shape == (3, 3, 3)\n\n    def test_preserves_original_values(self, sample_tensors):\n        \"\"\"Should preserve all original tensor values in padded output.\"\"\"\n        result = pad_2d_tensor(sample_tensors)\n\n        # First tensor: check original values\n        assert result[0, 0, 0] == 1\n        assert result[0, 0, 1] == 2\n        assert result[0, 1, 0] == 3\n        assert result[0, 1, 1] == 4\n\n        # Second tensor\n        assert result[1, 0, 0] == 5\n        assert result[1, 0, 1] == 6\n        assert result[1, 0, 2] == 7\n\n        # Third tensor\n        assert result[2, 0, 0] == 8\n        assert result[2, 1, 0] == 9\n        assert result[2, 2, 0] == 10\n\n    def test_pads_with_zeros(self, sample_tensors):\n        \"\"\"Should fill padding positions with zeros.\"\"\"\n        result = pad_2d_tensor(sample_tensors)\n\n        # Check padded positions\n        assert result[0, 2, 0] == 0  # Row padding\n        assert result[0, 0, 2] == 0  # Column padding\n        assert result[1, 1, 0] == 0  # Row padding in second tensor\n        assert result[2, 0, 1] == 0  # Column padding in third tensor\n\n    def test_single_tensor(self):\n       
 \"\"\"Should handle a single tensor correctly.\"\"\"\n        single_tensor = [torch.tensor([[1, 2], [3, 4]])]\n\n        result = pad_2d_tensor(single_tensor)\n\n        assert result.shape == (1, 2, 2)\n        assert torch.allclose(result[0], single_tensor[0].long())\n\n    def test_uniform_size_tensors(self):\n        \"\"\"Should handle tensors that are already the same size.\"\"\"\n        uniform_tensors = [\n            torch.tensor([[1, 2], [3, 4]]),\n            torch.tensor([[5, 6], [7, 8]]),\n        ]\n\n        result = pad_2d_tensor(uniform_tensors)\n\n        assert result.shape == (2, 2, 2)\n        # No padding needed, should match originals exactly\n        assert torch.allclose(result[0], uniform_tensors[0].long())\n        assert torch.allclose(result[1], uniform_tensors[1].long())\n\n    def test_empty_tensor_handling(self):\n        \"\"\"Should handle tensors with zero dimensions.\"\"\"\n        tensors_with_empty = [\n            torch.tensor([[1, 2]]),\n            torch.tensor([[]]),  # Empty second dimension\n        ]\n\n        result = pad_2d_tensor(tensors_with_empty)\n\n        # Should not crash and should return valid tensor\n        assert isinstance(result, torch.Tensor)\n        assert result.shape[0] == 2  # batch size\n\n    def test_preserves_dtype(self):\n        \"\"\"Should preserve data type of input tensors.\"\"\"\n        float_tensors = [\n            torch.tensor([[1.5, 2.5]], dtype=torch.float32),\n            torch.tensor([[3.5]], dtype=torch.float32),\n        ]\n\n        result = pad_2d_tensor(float_tensors)\n\n        assert result.dtype == torch.float32\n\n    def test_varying_row_counts(self):\n        \"\"\"Should handle tensors with different numbers of rows.\"\"\"\n        tensors = [\n            torch.tensor([[1]]),  # 1 row\n            torch.tensor([[2], [3], [4], [5]]),  # 4 rows\n            torch.tensor([[6], [7]]),  # 2 rows\n        ]\n\n        result = pad_2d_tensor(tensors)\n\n        # Max 
rows = 4\n        assert result.shape == (3, 4, 1)\n\n    def test_varying_column_counts(self):\n        \"\"\"Should handle tensors with different numbers of columns.\"\"\"\n        tensors = [\n            torch.tensor([[1, 2, 3, 4, 5]]),  # 5 cols\n            torch.tensor([[6, 7]]),  # 2 cols\n            torch.tensor([[8]]),  # 1 col\n        ]\n\n        result = pad_2d_tensor(tensors)\n\n        # Max cols = 5\n        assert result.shape == (3, 1, 5)\n\n    def test_batch_consistency(self):\n        \"\"\"Should maintain batch order and size.\"\"\"\n        tensors = [\n            torch.tensor([[1]]),\n            torch.tensor([[2]]),\n            torch.tensor([[3]]),\n        ]\n\n        result = pad_2d_tensor(tensors)\n\n        assert result.shape[0] == 3\n        assert result[0, 0, 0] == 1\n        assert result[1, 0, 0] == 2\n        assert result[2, 0, 0] == 3\n"
  },
  {
    "path": "tests/test_loss_functions.py",
    "content": "\"\"\"Tests for gliclass.loss_functions module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.loss_functions import sequence_contrastive_loss, focal_loss_with_logits\n\n\nclass TestSequenceContrastiveLoss:\n    \"\"\"Test suite for sequence_contrastive_loss function.\"\"\"\n\n    @pytest.fixture\n    def sample_embeddings(self):\n        \"\"\"Sample embeddings for testing.\"\"\"\n        batch_size = 2\n        seq_len = 4\n        embed_dim = 8\n        return torch.randn(batch_size, seq_len, embed_dim)\n\n    @pytest.fixture\n    def sample_mask(self):\n        \"\"\"Sample mask indicating valid positions.\"\"\"\n        return torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)\n\n    def test_returns_scalar_loss(self, sample_embeddings, sample_mask):\n        \"\"\"Should return a scalar loss value.\"\"\"\n        loss = sequence_contrastive_loss(sample_embeddings, sample_mask)\n\n        assert isinstance(loss, torch.Tensor)\n        assert loss.dim() == 0  # Scalar\n\n    def test_loss_is_positive(self, sample_embeddings, sample_mask):\n        \"\"\"Should return positive loss value.\"\"\"\n        loss = sequence_contrastive_loss(sample_embeddings, sample_mask)\n\n        assert loss >= 0\n\n    def test_identical_sequences_low_loss(self):\n        \"\"\"Should give low loss for identical sequences.\"\"\"\n        embeddings = torch.ones(2, 4, 8)  # Identical embeddings\n        mask = torch.ones(2, 4, dtype=torch.float32)\n\n        loss = sequence_contrastive_loss(embeddings, mask)\n\n        # Identical sequences should have low contrastive loss\n        assert loss < 10.0  # Reasonable upper bound\n\n    def test_handles_masked_positions(self):\n        \"\"\"Should ignore masked-out positions.\"\"\"\n        embeddings = torch.randn(2, 4, 8)\n        mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]], dtype=torch.float32)\n\n        loss = sequence_contrastive_loss(embeddings, mask)\n\n        assert not 
torch.isnan(loss)\n        assert not torch.isinf(loss)\n\n    def test_gradient_flows_through_loss(self, sample_embeddings, sample_mask):\n        \"\"\"Should allow gradient to flow through.\"\"\"\n        sample_embeddings.requires_grad = True\n\n        loss = sequence_contrastive_loss(sample_embeddings, sample_mask)\n        loss.backward()\n\n        assert sample_embeddings.grad is not None\n\n\nclass TestFocalLossWithLogits:\n    \"\"\"Test suite for focal_loss_with_logits function.\"\"\"\n\n    @pytest.fixture\n    def sample_logits(self):\n        \"\"\"Sample logits for testing.\"\"\"\n        return torch.randn(2, 4)\n\n    @pytest.fixture\n    def sample_targets(self):\n        \"\"\"Sample binary targets.\"\"\"\n        return torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.float32)\n\n    def test_returns_tensor_with_reduction_none(self, sample_logits, sample_targets):\n        \"\"\"Should return per-element loss with reduction='none' (default).\"\"\"\n        loss = focal_loss_with_logits(sample_logits, sample_targets)\n\n        assert isinstance(loss, torch.Tensor)\n        assert loss.shape == sample_logits.shape\n\n    def test_returns_scalar_with_reduction_mean(self, sample_logits, sample_targets):\n        \"\"\"Should return scalar with reduction='mean'.\"\"\"\n        loss = focal_loss_with_logits(sample_logits, sample_targets, reduction=\"mean\")\n\n        assert loss.dim() == 0  # Scalar\n\n    def test_loss_is_positive(self, sample_logits, sample_targets):\n        \"\"\"Should return non-negative loss.\"\"\"\n        loss = focal_loss_with_logits(sample_logits, sample_targets, reduction=\"mean\")\n\n        assert loss >= 0\n\n    def test_perfect_predictions_low_loss(self):\n        \"\"\"Should give low loss for perfect predictions.\"\"\"\n        logits = torch.tensor([[10.0, -10.0], [-10.0, 10.0]])\n        targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])\n\n        loss = focal_loss_with_logits(logits, targets, 
reduction=\"mean\")\n\n        # Perfect predictions should have very low loss\n        assert loss < 0.1\n\n    def test_wrong_predictions_high_loss(self):\n        \"\"\"Should give higher loss for wrong predictions.\"\"\"\n        logits = torch.tensor([[10.0, -10.0], [-10.0, 10.0]])\n        targets = torch.tensor([[0.0, 1.0], [1.0, 0.0]])  # Opposite\n\n        loss = focal_loss_with_logits(logits, targets, reduction=\"mean\")\n\n        # Wrong predictions should have higher loss\n        assert loss > 1.0\n\n    def test_alpha_parameter_effect(self, sample_logits, sample_targets):\n        \"\"\"Should respect alpha weighting parameter.\"\"\"\n        loss_alpha_1 = focal_loss_with_logits(sample_logits, sample_targets, alpha=1.0, reduction=\"mean\")\n        loss_alpha_05 = focal_loss_with_logits(sample_logits, sample_targets, alpha=0.5, reduction=\"mean\")\n\n        # Different alpha should give different losses\n        assert not torch.allclose(loss_alpha_1, loss_alpha_05)\n\n    def test_gamma_parameter_effect(self, sample_logits, sample_targets):\n        \"\"\"Should respect gamma focusing parameter.\"\"\"\n        loss_gamma_0 = focal_loss_with_logits(sample_logits, sample_targets, gamma=0.0, reduction=\"mean\")\n        loss_gamma_2 = focal_loss_with_logits(sample_logits, sample_targets, gamma=2.0, reduction=\"mean\")\n\n        # Different gamma should give different losses\n        # gamma=0 is equivalent to BCE loss\n        assert not torch.allclose(loss_gamma_0, loss_gamma_2)\n\n    def test_reduction_sum(self, sample_logits, sample_targets):\n        \"\"\"Should reduce loss by sum.\"\"\"\n        loss = focal_loss_with_logits(sample_logits, sample_targets, reduction=\"sum\")\n\n        assert loss.dim() == 0  # Scalar\n\n    def test_reduction_none(self, sample_logits, sample_targets):\n        \"\"\"Should return per-element loss when reduction='none'.\"\"\"\n        loss = focal_loss_with_logits(sample_logits, sample_targets, 
reduction=\"none\")\n\n        assert loss.shape == sample_logits.shape\n\n    def test_handles_extreme_logits(self):\n        \"\"\"Should handle very large positive and negative logits.\"\"\"\n        logits = torch.tensor([[100.0, -100.0], [-100.0, 100.0]])\n        targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])\n\n        loss = focal_loss_with_logits(logits, targets, reduction=\"mean\")\n\n        assert not torch.isnan(loss)\n        assert not torch.isinf(loss)\n\n    def test_gradient_flows_through_loss(self, sample_logits, sample_targets):\n        \"\"\"Should allow gradient to flow through.\"\"\"\n        sample_logits.requires_grad = True\n\n        loss = focal_loss_with_logits(sample_logits, sample_targets, reduction=\"mean\")\n        loss.backward()\n\n        assert sample_logits.grad is not None\n\n    def test_all_zeros_targets(self):\n        \"\"\"Should handle all-zero targets.\"\"\"\n        logits = torch.randn(2, 4)\n        targets = torch.zeros(2, 4)\n\n        loss = focal_loss_with_logits(logits, targets, reduction=\"mean\")\n\n        assert not torch.isnan(loss)\n        assert loss >= 0\n\n    def test_all_ones_targets(self):\n        \"\"\"Should handle all-one targets.\"\"\"\n        logits = torch.randn(2, 4)\n        targets = torch.ones(2, 4)\n\n        loss = focal_loss_with_logits(logits, targets, reduction=\"mean\")\n\n        assert not torch.isnan(loss)\n        assert loss >= 0\n"
  },
  {
    "path": "tests/test_poolings.py",
    "content": "\"\"\"Tests for gliclass.poolings module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.poolings import (\n    GlobalMaxPooling1D,\n    GlobalAvgPooling1D,\n    GlobalSumPooling1D,\n    GlobalRMSPooling1D,\n    GlobalAbsMaxPooling1D,\n    GlobalAbsAvgPooling1D,\n    FirstTokenPooling1D,\n    LastTokenPooling1D,\n    PassPooling1D,\n)\n\n\nclass TestGlobalMaxPooling1D:\n    \"\"\"Test suite for GlobalMaxPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return GlobalMaxPooling1D()\n\n    @pytest.fixture\n    def sample_input(self):\n        \"\"\"Sample input tensor (batch_size, seq_len, hidden_dim).\"\"\"\n        return torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]])\n\n    def test_returns_max_across_sequence(self, pooling_layer, sample_input):\n        \"\"\"Should return maximum values across sequence dimension.\"\"\"\n        output = pooling_layer(sample_input)\n\n        expected = torch.tensor([[5.0, 6.0], [11.0, 12.0]])\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer, sample_input):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        output = pooling_layer(sample_input)\n\n        assert output.shape == (2, 2)  # (batch_size, hidden_dim)\n\n\nclass TestGlobalAvgPooling1D:\n    \"\"\"Test suite for GlobalAvgPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return GlobalAvgPooling1D()\n\n    def test_returns_average_across_sequence(self, pooling_layer):\n        \"\"\"Should return average values across sequence dimension.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[2.0, 3.0]])  # (1+3)/2, (2+4)/2\n        assert torch.allclose(output, expected)\n\n    def 
test_handles_attention_mask(self, pooling_layer):\n        \"\"\"Should average only over non-masked positions.\"\"\"\n        inputs = torch.tensor([[[2.0, 4.0], [4.0, 6.0], [99.0, 99.0]]])\n        mask = torch.tensor([[1, 1, 0]]).unsqueeze(-1)  # (batch, seq, 1)\n\n        output = pooling_layer(inputs, mask)\n\n        expected = torch.tensor([[3.0, 5.0]])  # (2+4)/2, (4+6)/2\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        inputs = torch.randn(2, 5, 10)  # (batch, seq, hidden)\n\n        output = pooling_layer(inputs)\n\n        assert output.shape == (2, 10)\n\n\nclass TestGlobalSumPooling1D:\n    \"\"\"Test suite for GlobalSumPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return GlobalSumPooling1D()\n\n    def test_returns_sum_across_sequence(self, pooling_layer):\n        \"\"\"Should return sum of values across sequence dimension.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[4.0, 6.0]])  # 1+3, 2+4\n        assert torch.allclose(output, expected)\n\n    def test_handles_attention_mask(self, pooling_layer):\n        \"\"\"Should sum only over non-masked positions.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [99.0, 99.0]]])\n        mask = torch.tensor([[1, 1, 0]]).unsqueeze(-1)  # (batch, seq, 1)\n\n        output = pooling_layer(inputs, mask)\n\n        expected = torch.tensor([[4.0, 6.0]])  # 1+3, 2+4 (masked position becomes 0)\n        assert torch.allclose(output, expected)\n\n\nclass TestFirstTokenPooling1D:\n    \"\"\"Test suite for FirstTokenPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return FirstTokenPooling1D()\n\n    def 
test_returns_first_token(self, pooling_layer):\n        \"\"\"Should return the first token representation.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[1.0, 2.0]])\n        assert torch.allclose(output, expected)\n\n    def test_works_with_batch(self, pooling_layer):\n        \"\"\"Should work with batched inputs.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[1.0, 2.0], [5.0, 6.0]])\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        inputs = torch.randn(4, 10, 16)\n\n        output = pooling_layer(inputs)\n\n        assert output.shape == (4, 16)\n\n\nclass TestLastTokenPooling1D:\n    \"\"\"Test suite for LastTokenPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return LastTokenPooling1D()\n\n    def test_returns_last_token(self, pooling_layer):\n        \"\"\"Should return the last token representation.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[5.0, 6.0]])\n        assert torch.allclose(output, expected)\n\n    def test_works_with_batch(self, pooling_layer):\n        \"\"\"Should work with batched inputs.\"\"\"\n        inputs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[3.0, 4.0], [7.0, 8.0]])\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        inputs = torch.randn(4, 10, 16)\n\n        output = 
pooling_layer(inputs)\n\n        assert output.shape == (4, 16)\n\n\nclass TestGlobalRMSPooling1D:\n    \"\"\"Test suite for GlobalRMSPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return GlobalRMSPooling1D()\n\n    def test_returns_rms_across_sequence(self, pooling_layer):\n        \"\"\"Should return RMS values across sequence dimension.\"\"\"\n        inputs = torch.tensor([[[3.0, 4.0], [0.0, 0.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.sqrt(torch.tensor([[4.5, 8.0]]))\n        assert torch.allclose(output, expected)\n\n    def test_handles_attention_mask(self, pooling_layer):\n        \"\"\"Should compute RMS only over non-masked positions.\"\"\"\n        inputs = torch.tensor([[[3.0, 4.0], [3.0, 4.0], [99.0, 99.0]]])\n        mask = torch.tensor([[1, 1, 0]]).unsqueeze(-1)\n\n        output = pooling_layer(inputs, mask)\n\n        expected = torch.tensor([[3.0, 4.0]])\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        inputs = torch.randn(2, 5, 10)\n\n        output = pooling_layer(inputs)\n\n        assert output.shape == (2, 10)\n\n\nclass TestGlobalAbsMaxPooling1D:\n    \"\"\"Test suite for GlobalAbsMaxPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return GlobalAbsMaxPooling1D()\n\n    def test_returns_abs_max_across_sequence(self, pooling_layer):\n        \"\"\"Should return maximum absolute values across sequence dimension.\"\"\"\n        inputs = torch.tensor([[[-5.0, 2.0], [3.0, -4.0], [1.0, 1.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[5.0, 4.0]])\n        assert torch.allclose(output, expected)\n\n    def test_handles_attention_mask(self, pooling_layer):\n        \"\"\"Should find abs max only over 
non-masked positions.\"\"\"\n        inputs = torch.tensor([[[-2.0, 3.0], [4.0, -1.0], [99.0, 99.0]]])\n        mask = torch.tensor([[1, 1, 0]]).unsqueeze(-1)\n\n        output = pooling_layer(inputs, mask)\n\n        expected = torch.tensor([[4.0, 3.0]])\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        inputs = torch.randn(2, 5, 10)\n\n        output = pooling_layer(inputs)\n\n        assert output.shape == (2, 10)\n\n\nclass TestGlobalAbsAvgPooling1D:\n    \"\"\"Test suite for GlobalAbsAvgPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return GlobalAbsAvgPooling1D()\n\n    def test_returns_abs_avg_across_sequence(self, pooling_layer):\n        \"\"\"Should return average of absolute values across sequence dimension.\"\"\"\n        inputs = torch.tensor([[[-2.0, 4.0], [2.0, -4.0]]])\n\n        output = pooling_layer(inputs)\n\n        expected = torch.tensor([[2.0, 4.0]])\n        assert torch.allclose(output, expected)\n\n    def test_handles_attention_mask(self, pooling_layer):\n        \"\"\"Should average abs values only over non-masked positions.\"\"\"\n        inputs = torch.tensor([[[-2.0, 4.0], [4.0, -2.0], [99.0, 99.0]]])\n        mask = torch.tensor([[1, 1, 0]]).unsqueeze(-1)\n\n        output = pooling_layer(inputs, mask)\n\n        expected = torch.tensor([[3.0, 3.0]])\n        assert torch.allclose(output, expected)\n\n    def test_output_shape(self, pooling_layer):\n        \"\"\"Should reduce sequence dimension.\"\"\"\n        inputs = torch.randn(2, 5, 10)\n\n        output = pooling_layer(inputs)\n\n        assert output.shape == (2, 10)\n\n\nclass TestPassPooling1D:\n    \"\"\"Test suite for PassPooling1D.\"\"\"\n\n    @pytest.fixture\n    def pooling_layer(self):\n        \"\"\"Create pooling layer for testing.\"\"\"\n        return 
PassPooling1D()\n\n    def test_returns_input_unchanged(self, pooling_layer):\n        \"\"\"Should return input tensor without modification.\"\"\"\n        inputs = torch.randn(2, 5, 10)\n\n        output = pooling_layer(inputs)\n\n        assert torch.allclose(output, inputs)\n\n    def test_ignores_attention_mask(self, pooling_layer):\n        \"\"\"Should ignore attention mask and return full input.\"\"\"\n        inputs = torch.randn(2, 5, 10)\n        mask = torch.tensor([[1, 1, 0, 0, 0], [1, 1, 1, 0, 0]])\n\n        output = pooling_layer(inputs, mask)\n\n        assert torch.allclose(output, inputs)\n\n    def test_maintains_shape(self, pooling_layer):\n        \"\"\"Should maintain input shape exactly.\"\"\"\n        inputs = torch.randn(3, 7, 12)\n\n        output = pooling_layer(inputs)\n\n        assert output.shape == inputs.shape\n"
  },
  {
    "path": "tests/test_scorers.py",
    "content": "\"\"\"Tests for gliclass.scorers module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.scorers import (\n    ScorerWeightedDot,\n    ScorerDot,\n    MLPScorer,\n    HopfieldScorer,\n    CrossAttnScorer,\n)\n\n\nclass TestScorerWeightedDot:\n    @pytest.fixture\n    def scorer(self):\n        return ScorerWeightedDot(hidden_size=128)\n\n    def test_forward_pass(self, scorer):\n        text_rep = torch.randn(4, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        scores = scorer(text_rep, label_rep)\n\n        assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n\n    def test_gradient_flow(self, scorer):\n        text_rep = torch.randn(4, 128, requires_grad=True)\n        label_rep = torch.randn(4, 10, 128, requires_grad=True)\n\n        scores = scorer(text_rep, label_rep)\n        loss = scores.sum()\n        loss.backward()\n\n        assert text_rep.grad is not None\n        assert label_rep.grad is not None\n\n\nclass TestScorerDot:\n    @pytest.fixture\n    def scorer(self):\n        return ScorerDot()\n\n    def test_forward_pass(self, scorer):\n        text_rep = torch.randn(4, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        scores = scorer(text_rep, label_rep)\n\n        assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n\n    def test_gradient_flow(self, scorer):\n        text_rep = torch.randn(4, 128, requires_grad=True)\n        label_rep = torch.randn(4, 10, 128, requires_grad=True)\n\n        scores = scorer(text_rep, label_rep)\n        loss = scores.sum()\n        loss.backward()\n\n        assert text_rep.grad is not None\n        assert label_rep.grad is not None\n\n\nclass TestMLPScorer:\n    @pytest.fixture\n    def scorer(self):\n        return MLPScorer(hidden_size=128)\n\n    def test_forward_pass(self, scorer):\n        text_rep = torch.randn(4, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        scores = scorer(text_rep, label_rep)\n\n  
      assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n\n    def test_different_batch_sizes(self, scorer):\n        for batch_size in [1, 2, 8]:\n            text_rep = torch.randn(batch_size, 128)\n            label_rep = torch.randn(batch_size, 10, 128)\n\n            scores = scorer(text_rep, label_rep)\n\n            assert scores.shape == (batch_size, 10)\n\n    def test_gradient_flow(self, scorer):\n        text_rep = torch.randn(4, 128, requires_grad=True)\n        label_rep = torch.randn(4, 10, 128, requires_grad=True)\n\n        scores = scorer(text_rep, label_rep)\n        loss = scores.sum()\n        loss.backward()\n\n        assert text_rep.grad is not None\n        assert label_rep.grad is not None\n\n\nclass TestHopfieldScorer:\n    @pytest.fixture\n    def scorer(self):\n        return HopfieldScorer(hidden_size=128)\n\n    def test_forward_pass(self, scorer):\n        text_rep = torch.randn(4, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        scores = scorer(text_rep, label_rep)\n\n        assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n\n    def test_multiple_iterations(self):\n        scorer = HopfieldScorer(hidden_size=128, num_iteration=3)\n        text_rep = torch.randn(4, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        scores = scorer(text_rep, label_rep)\n\n        assert scores.shape == (4, 10)\n\n    def test_gradient_flow(self, scorer):\n        text_rep = torch.randn(4, 128, requires_grad=True)\n        label_rep = torch.randn(4, 10, 128, requires_grad=True)\n\n        scores = scorer(text_rep, label_rep)\n        loss = scores.sum()\n        loss.backward()\n\n        assert text_rep.grad is not None\n        assert label_rep.grad is not None\n\n\nclass TestCrossAttnScorer:\n    @pytest.fixture\n    def scorer(self):\n        return CrossAttnScorer(hidden_size=128, num_heads=8)\n\n    def test_forward_pass_with_text_mask(self, scorer):\n        text_rep = 
torch.randn(4, 20, 128)\n        label_rep = torch.randn(4, 10, 128)\n        text_mask = torch.ones(4, 20, dtype=torch.bool)\n        text_mask[:, 15:] = 0\n\n        scores = scorer(text_rep, label_rep, text_mask=text_mask)\n\n        assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n\n    def test_forward_pass_without_text_mask(self, scorer):\n        text_rep = torch.randn(4, 20, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        scores = scorer(text_rep, label_rep)\n\n        assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n\n    def test_different_seq_lengths(self, scorer):\n        for seq_len in [10, 20, 50]:\n            text_rep = torch.randn(4, seq_len, 128)\n            label_rep = torch.randn(4, 10, 128)\n\n            scores = scorer(text_rep, label_rep)\n\n            assert scores.shape == (4, 10)\n\n    def test_gradient_flow(self, scorer):\n        text_rep = torch.randn(4, 20, 128, requires_grad=True)\n        label_rep = torch.randn(4, 10, 128, requires_grad=True)\n\n        scores = scorer(text_rep, label_rep)\n        loss = scores.sum()\n        loss.backward()\n\n        assert text_rep.grad is not None\n        assert label_rep.grad is not None\n\n    def test_eval_mode(self, scorer):\n        scorer.eval()\n        text_rep = torch.randn(4, 20, 128)\n        label_rep = torch.randn(4, 10, 128)\n\n        with torch.no_grad():\n            scores = scorer(text_rep, label_rep)\n\n        assert scores.shape == (4, 10)\n        assert not torch.isnan(scores).any()\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "\"\"\"Tests for gliclass.utils module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.utils import is_module_available, retrieval_augmented_text, default_f1_reward\n\n\nclass TestIsModuleAvailable:\n    \"\"\"Test suite for is_module_available function.\"\"\"\n\n    def test_detects_installed_module(self):\n        \"\"\"Should return True for installed modules.\"\"\"\n        assert is_module_available(\"torch\") is True\n        assert is_module_available(\"pytest\") is True\n\n    def test_detects_missing_module(self):\n        \"\"\"Should return False for non-existent modules.\"\"\"\n        assert is_module_available(\"nonexistent_module_12345\") is False\n\n    def test_handles_submodules(self):\n        \"\"\"Should work with submodule paths.\"\"\"\n        assert is_module_available(\"torch.nn\") is True\n\n\nclass TestRetrievalAugmentedText:\n    \"\"\"Test suite for retrieval_augmented_text function.\"\"\"\n\n    def test_with_structured_examples(self):\n        \"\"\"Should concatenate input text with structured examples.\"\"\"\n        text = \"This is a test.\"\n        examples = [\n            {\"text\": \"Example 1\", \"true_labels\": [\"label1\"], \"all_labels\": [\"label1\", \"label2\"]},\n            {\"text\": \"Example 2\", \"true_labels\": [\"label2\"], \"all_labels\": [\"label1\", \"label2\"]},\n        ]\n\n        result = retrieval_augmented_text(text, examples)\n\n        assert isinstance(result, str)\n        assert text in result\n\n    def test_empty_examples_returns_original_text(self):\n        \"\"\"Should return original text when no examples provided.\"\"\"\n        text = \"This is a test.\"\n        examples = []\n\n        result = retrieval_augmented_text(text, examples)\n\n        assert result == text\n\n    def test_includes_true_label_markers(self):\n        \"\"\"Should include TRUE_LABEL markers for positive labels.\"\"\"\n        text = \"Query text\"\n        examples = [{\"text\": \"Example\", 
\"true_labels\": [\"tech\"], \"all_labels\": [\"tech\", \"sports\"]}]\n\n        result = retrieval_augmented_text(text, examples)\n\n        assert \"<<TRUE_LABEL>>\" in result\n        assert \"tech\" in result\n\n\nclass TestDefaultF1Reward:\n    \"\"\"Test suite for default_f1_reward function.\"\"\"\n\n    @pytest.fixture\n    def sample_inputs(self):\n        \"\"\"Sample inputs matching the function signature.\"\"\"\n        batch_size = 2\n        num_labels = 4\n        return {\n            \"probs\": torch.rand(batch_size, num_labels),\n            \"actions\": torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.long),\n            \"original_targets\": torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.long),\n            \"valid_mask\": torch.ones(batch_size, num_labels),\n        }\n\n    def test_returns_tensor(self, sample_inputs):\n        \"\"\"Should return a torch tensor.\"\"\"\n        reward = default_f1_reward(**sample_inputs)\n\n        assert isinstance(reward, torch.Tensor)\n\n    def test_output_shape(self, sample_inputs):\n        \"\"\"Should return (N, 1) shaped tensor.\"\"\"\n        reward = default_f1_reward(**sample_inputs)\n\n        assert reward.shape == (2, 1)\n\n    def test_perfect_predictions(self):\n        \"\"\"Should give F1=1.0 for perfect predictions.\"\"\"\n        probs = torch.rand(1, 4)\n        actions = torch.tensor([[1, 0, 1, 0]], dtype=torch.long)\n        targets = torch.tensor([[1, 0, 1, 0]], dtype=torch.long)\n        valid_mask = torch.ones(1, 4)\n\n        reward = default_f1_reward(probs, actions, targets, valid_mask)\n\n        assert torch.allclose(reward, torch.tensor([[1.0]]))\n\n    def test_zero_f1_for_wrong_predictions(self):\n        \"\"\"Should give F1=0.0 when predictions and targets don't overlap.\"\"\"\n        probs = torch.rand(1, 4)\n        actions = torch.tensor([[1, 1, 0, 0]], dtype=torch.long)\n        targets = torch.tensor([[0, 0, 1, 1]], dtype=torch.long)\n        valid_mask = 
torch.ones(1, 4)\n\n        reward = default_f1_reward(probs, actions, targets, valid_mask)\n\n        assert torch.allclose(reward, torch.tensor([[0.0]]))\n\n    def test_handles_valid_mask(self):\n        \"\"\"Should respect valid_mask to ignore certain positions.\"\"\"\n        probs = torch.rand(1, 4)\n        actions = torch.tensor([[1, 1, 1, 1]], dtype=torch.long)\n        targets = torch.tensor([[1, 1, 0, 0]], dtype=torch.long)\n        valid_mask = torch.tensor([[1, 1, 0, 0]])  # Mask out last two\n\n        reward = default_f1_reward(probs, actions, targets, valid_mask)\n\n        # Should get F1=1.0 since masked positions are ignored\n        assert torch.allclose(reward, torch.tensor([[1.0]]))\n"
  },
  {
    "path": "train.py",
    "content": "import os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\nimport numpy as np\nimport argparse\nimport json\n\nfrom sklearn.metrics import precision_recall_fscore_support, accuracy_score\nimport transformers\nfrom transformers import AutoTokenizer, AutoConfig\nfrom torch.utils.data import WeightedRandomSampler\nfrom packaging import version\n\nimport random\nimport torch\n\nfrom gliclass import GLiClassModelConfig, GLiClassModel\nfrom gliclass.training import TrainingArguments, Trainer\nfrom gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset, AugmentationConfig\n\nclass CustomTrainer(Trainer):\n    \"\"\"Trainer with weighted random sampling support.\"\"\"\n    \n    def __init__(self, *args, use_weighted_sampling=False, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.use_weighted_sampling = use_weighted_sampling\n    \n    def _get_train_sampler(self, train_dataset) -> torch.utils.data.Sampler:\n        if not self.use_weighted_sampling:\n            return super()._get_train_sampler()\n        \n        weights = train_dataset.get_diversity()\n        return WeightedRandomSampler(\n            weights=weights,\n            num_samples=len(train_dataset),\n            replacement=True\n        )\n    \ndef compute_metrics(p, problem_type='multi_label_classification'):\n    \"\"\"Compute evaluation metrics.\n    \n    Args:\n        p: Predictions tuple (predictions, labels)\n        problem_type: Type of classification problem\n        \n    Returns:\n        Dictionary of metrics\n    \"\"\"\n    predictions, labels = p\n    labels = labels.reshape(-1)\n    \n    if problem_type == 'single_label_classification':\n        preds = np.argmax(predictions, axis=1)\n        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')\n        accuracy = accuracy_score(labels, preds)\n        return {\n            'accuracy': accuracy,\n            'precision': precision,\n     
       'recall': recall,\n            'f1': f1,\n        }\n\n    elif problem_type == 'multi_label_classification':\n        predictions = predictions.reshape(-1)\n        preds = (predictions > 0.5).astype(int)\n        labels = np.where(labels > 0.5, 1, 0)\n        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')\n        accuracy = accuracy_score(labels, preds)\n        return {\n            'accuracy': accuracy,\n            'precision': precision,\n            'recall': recall,\n            'f1': f1,\n        }\n    else:\n        raise NotImplementedError(f\"{problem_type} is not implemented.\")\n\n\ndef load_dataset(data_path: str) -> list:\n    \"\"\"Load dataset from JSON file.\n    \n    Args:\n        data_path: Path to JSON data file\n        \n    Returns:\n        List of data samples\n    \"\"\"\n    with open(data_path, 'r') as f:\n        data = json.load(f)\n    return data\n\n\ndef main(args):\n    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n    # Load or create model\n    if args.model_name is not None:\n        model = GLiClassModel.from_pretrained(\n            args.model_name, \n            focal_loss_alpha=args.focal_loss_alpha,\n            focal_loss_gamma=args.focal_loss_gamma,\n            focal_loss_reduction=args.focal_loss_reduction\n        )\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)\n        encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)\n\n        label_model_config = None\n        if args.label_model_name is not None:\n            label_model_config = AutoConfig.from_pretrained(args.label_model_name)\n\n        glicalss_config = GLiClassModelConfig(\n            encoder_config=encoder_config,\n            encoder_model=args.encoder_model_name,\n            label_model_name=args.label_model_name,\n            
label_model_config=label_model_config,\n            class_token_index=len(tokenizer),\n            text_token_index=len(tokenizer)+1,\n            example_token_index=len(tokenizer)+2,\n            pooling_strategy=args.pooler_type,\n            class_token_pooling=args.class_token_pooling,\n            scorer_type=args.scorer_type,\n            use_lstm=args.use_lstm,\n            focal_loss_alpha=args.focal_loss_alpha,\n            focal_loss_gamma=args.focal_loss_gamma,\n            focal_loss_reduction=args.focal_loss_reduction,\n            contrastive_loss_coef=args.contrastive_loss_coef,\n            normalize_features=args.normalize_features,\n            extract_text_features=args.extract_text_features,\n            architecture_type=args.architecture_type,\n            prompt_first=args.prompt_first,\n            squeeze_layers=args.squeeze_layers,\n            layer_wise=args.layer_wise,\n            encoder_layer_id=args.encoder_layer_id,\n            shuffle_labels=args.shuffle_labels,\n            dropout=args.dropout,\n            use_segment_embeddings=args.use_segment_embeddings,\n        )\n\n        model = GLiClassModel(glicalss_config, from_pretrained=True).to(dtype=torch.float32)\n\n        if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:\n            new_words = [\"<<LABEL>>\", \"<<SEP>>\", \"<<EXAMPLE>>\"]\n            tokenizer.add_tokens(new_words, special_tokens=True)\n            model.resize_token_embeddings(len(tokenizer))\n\n    model.to(device)\n\n    # Get labels tokenizer if needed\n    if model.config.label_model_name is not None:\n        labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)\n    else:\n        labels_tokenizer = None\n\n    model.config.problem_type = args.problem_type\n\n    # Load current training data\n    data = load_dataset(args.data_path)\n    print(f'Dataset size: {len(data)}')\n    \n    random.shuffle(data)    \n    print('Dataset is 
shuffled...')\n\n    train_data = data[:int(len(data) * 0.9)]\n    test_data = data[int(len(data) * 0.9):]\n    print('Dataset is splitted...')\n\n    # Create augmentation config with all parameters\n    augment_config = AugmentationConfig(\n        enabled=args.enable_augmentation,\n        random_label_removal_prob=args.random_label_removal_prob,\n        random_label_addition_prob=args.random_label_addition_prob,\n        random_text_addition_prob=args.random_text_addition_prob,\n        random_add_description_prob=args.random_add_description_prob,\n        random_add_synonyms_prob=args.random_add_synonyms_prob,\n        random_add_examples_prob=args.random_add_examples_prob,\n        max_num_examples=args.max_num_examples\n    )\n    \n    if args.labels_desc_path is not None:\n        labels_descriptions = load_dataset(args.labels_desc_path)\n        label_to_description = {item.get(\"label\"): item for item in labels_descriptions}\n    else:\n        label_to_description = {}\n\n    train_dataset = GLiClassDataset(train_data, tokenizer, augment_config, \n                                    label_to_description, args.max_length, \n                                    args.problem_type, args.architecture_type, \n                                    args.prompt_first, labels_tokenizer=labels_tokenizer)\n    \n    # Disable augmentation for test dataset\n    test_augment_config = AugmentationConfig(enabled=False)\n    test_dataset = GLiClassDataset(test_data, tokenizer, test_augment_config, \n                                        label_to_description,\n                                        args.max_length, args.problem_type, \n                                        args.architecture_type, args.prompt_first,\n                                        labels_tokenizer = labels_tokenizer)\n\n    # Load previous dataset for EWC if provided\n    prev_dataset = None\n    if args.use_ewc and args.prev_data_path is not None:\n        print(f'Loading previous dataset 
for EWC from: {args.prev_data_path}')\n        prev_data = load_dataset(args.prev_data_path)\n        print(f'Previous dataset size: {len(prev_data)}')\n        \n        # Use a subset if specified\n        if args.ewc_fisher_samples is not None and args.ewc_fisher_samples < len(prev_data):\n            random.shuffle(prev_data)\n            prev_data = prev_data[:args.ewc_fisher_samples]\n            print(f'Using {len(prev_data)} samples for Fisher estimation')\n        \n        prev_dataset = GLiClassDataset(prev_data, tokenizer, test_augment_config, \n                                        label_to_description,\n                                        args.max_length, args.problem_type, \n                                        args.architecture_type, args.prompt_first,\n                                        labels_tokenizer = labels_tokenizer)\n\n    data_collator = DataCollatorWithPadding(device=device)\n\n    # Create training arguments with EWC parameters\n    training_args = TrainingArguments(\n        output_dir=args.save_path,\n        learning_rate=args.encoder_lr,\n        weight_decay=args.encoder_weight_decay,\n        others_lr=args.others_lr,\n        others_weight_decay=args.others_weight_decay,\n        lr_scheduler_type=args.lr_scheduler_type,\n        warmup_ratio=args.warmup_ratio,\n        per_device_train_batch_size=args.batch_size,\n        per_device_eval_batch_size=args.batch_size,\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        num_train_epochs=args.num_epochs,\n        save_steps=args.save_steps,\n        save_total_limit=args.save_total_limit,\n        dataloader_num_workers=args.num_workers,\n        logging_steps=100,\n        use_cpu=False,\n        report_to=\"none\",\n        fp16=args.fp16,\n        # EWC parameters\n        use_ewc=args.use_ewc,\n        ewc_lambda=args.ewc_lambda,\n        ewc_fisher_samples=args.ewc_fisher_samples,\n        ewc_normalize_fisher=args.ewc_normalize_fisher,\n  
      ewc_gamma=args.ewc_gamma,\n    )\n\n    # Create compute_metrics function with problem_type closure\n    def compute_metrics_fn(p):\n        return compute_metrics(p, args.problem_type)\n\n    # Create trainer with EWC support\n    # Handle version differences between transformers v4 and v5\n    trainer_kwargs = {\n        \"model\": model,\n        \"args\": training_args,\n        \"train_dataset\": train_dataset,\n        \"eval_dataset\": test_dataset,\n        \"data_collator\": data_collator,\n        \"compute_metrics\": compute_metrics_fn,\n        \"prev_dataset\": prev_dataset,  # Pass previous dataset for EWC\n    }\n\n    if version.parse(transformers.__version__) < version.parse(\"5.0.0\"):\n        trainer_kwargs[\"tokenizer\"] = tokenizer\n    else:\n        trainer_kwargs[\"processing_class\"] = tokenizer\n\n    trainer = CustomTrainer(**trainer_kwargs)\n    \n    # Print EWC status\n    if args.use_ewc:\n        if args.prev_data_path is not None:\n            print(f'\\nEWC enabled with lambda={args.ewc_lambda}')\n        else:\n            print('\\nWarning: EWC is enabled but no previous data path provided. 
EWC will not be used.')\n    \n    trainer.train()\n    \n    # Save final model\n    final_output_dir = os.path.join(args.save_path, 'final_model')\n    model.save_pretrained(final_output_dir)\n    tokenizer.save_pretrained(final_output_dir)\n    print(f'Final model saved to {final_output_dir}')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Train GLiClass model with optional EWC for continual learning')\n    \n    # Model arguments\n    parser.add_argument('--model_name', type=str, default=None,\n                        help='Pretrained model name or path')\n    parser.add_argument('--encoder_model_name', type=str, default='microsoft/deberta-v3-small',\n                        help='Encoder model name')\n    parser.add_argument('--label_model_name', type=str, default=\"BAAI/bge-small-en-v1.5\",\n                        help='Label model name')\n    \n    # Path arguments\n    parser.add_argument('--save_path', type=str, default='models/',\n                        help='Path to save trained model')\n    parser.add_argument('--data_path', type=str, default='data/zero-cats.json',\n                        help='Path to training data JSON file')\n    parser.add_argument('--prev_data_path', type=str, default=None,\n                        help='Path to previous task data for EWC (required if use_ewc=True)')\n    parser.add_argument('--labels_desc_path', type=str, default = None)\n\n    # Model architecture arguments\n    parser.add_argument('--problem_type', type=str, default='multi_label_classification',\n                        choices=['single_label_classification', 'multi_label_classification'])\n    parser.add_argument('--pooler_type', type=str, default='avg')\n    parser.add_argument('--scorer_type', type=str, default='simple')\n    parser.add_argument('--architecture_type', type=str, default='uni-encoder')\n    parser.add_argument('--class_token_pooling', type=str, default='first')\n    
parser.add_argument('--normalize_features', type=bool, default=False)\n    parser.add_argument('--extract_text_features', type=bool, default=False)\n    parser.add_argument('--prompt_first', type=bool, default=True)\n    parser.add_argument('--use_lstm', type=bool, default=False)\n    parser.add_argument('--squeeze_layers', type=bool, default=False)\n    parser.add_argument('--layer_wise', type=bool, default=False)\n    parser.add_argument('--encoder_layer_id', type=int, default=-1)\n    parser.add_argument('--dropout', type=float, default=0.3)\n    parser.add_argument('--shuffle_labels', type=bool, default=True)\n    parser.add_argument('--use_segment_embeddings', type=bool, default=False)\n\n    # Training arguments\n    parser.add_argument('--num_epochs', type=int, default=3)\n    parser.add_argument('--batch_size', type=int, default=8)\n    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)\n    parser.add_argument('--encoder_lr', type=float, default=1e-5)\n    parser.add_argument('--others_lr', type=float, default=3e-5)\n    parser.add_argument('--encoder_weight_decay', type=float, default=0.01)\n    parser.add_argument('--others_weight_decay', type=float, default=0.01)\n    parser.add_argument('--warmup_ratio', type=float, default=0.05)\n    parser.add_argument('--lr_scheduler_type', type=str, default='linear')\n    parser.add_argument('--max_length', type=int, default=1024)\n    parser.add_argument('--save_steps', type=int, default=1000)\n    parser.add_argument('--save_total_limit', type=int, default=3)\n    parser.add_argument('--num_workers', type=int, default=12)\n    parser.add_argument('--fp16', type=bool, default=False)\n    \n    # Augmentation parameters\n    parser.add_argument('--enable_augmentation', type=bool, default=True)\n    parser.add_argument('--random_label_removal_prob', type=float, default=0.05)\n    parser.add_argument('--random_label_addition_prob', type=float, default=0.05)\n    
parser.add_argument('--random_text_addition_prob', type=float, default=0.05)\n    parser.add_argument('--random_add_description_prob', type=float, default=0.05)\n    parser.add_argument('--random_add_synonyms_prob', type=float, default=0.05)\n    parser.add_argument('--random_add_examples_prob', type=float, default=0.1)\n    parser.add_argument('--max_num_examples', type=int, default=5)\n\n\n    # Loss arguments\n    parser.add_argument('--focal_loss_alpha', type=float, default=-1)\n    parser.add_argument('--focal_loss_gamma', type=float, default=-1)\n    parser.add_argument('--focal_loss_reduction', type=str, default='none',\n                        choices=['none', 'mean', 'sum'])\n    parser.add_argument('--contrastive_loss_coef', type=float, default=0.)\n    \n    # EWC arguments\n    parser.add_argument('--use_ewc', action='store_true',\n                        help='Enable Elastic Weight Consolidation for continual learning')\n    parser.add_argument('--ewc_lambda', type=float, default=100.0,\n                        help='Lambda parameter for EWC penalty (higher = more regularization)')\n    parser.add_argument('--ewc_fisher_samples', type=int, default=None,\n                        help='Number of samples to use for Fisher information estimation (None = use all)')\n    parser.add_argument('--ewc_normalize_fisher', type=bool, default=True,\n                        help='Whether to normalize Fisher information values')\n    parser.add_argument('--ewc_gamma', type=float, default=0.95,\n                        help='Decay factor for Online EWC (0 < gamma < 1)')\n    \n    args = parser.parse_args()\n\n    # Validate EWC arguments\n    if args.use_ewc and args.prev_data_path is None:\n        print(\"Warning: --use_ewc is set but --prev_data_path is not provided.\")\n        print(\"EWC requires previous task data to compute Fisher information.\")\n        print(\"Training will proceed without EWC.\")\n    \n    main(args)"
  },
  {
    "path": "train_rl.py",
    "content": "import os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\nimport numpy as np\nimport argparse\nimport json\n\nfrom sklearn.metrics import precision_recall_fscore_support, accuracy_score\nimport transformers\nfrom transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification\nfrom packaging import version\n\nimport random\nimport torch\n\nfrom gliclass import GLiClassModelConfig, GLiClassModel, ZeroShotClassificationPipeline\nfrom gliclass.training import TrainingArguments, Trainer, RLTrainerConfig, RLTrainer\nfrom gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset\nfrom gliclass.utils import default_f1_reward\n\ndef accuracy_reward(probs, actions, targets, valid_mask):\n    probs = probs * valid_mask\n    predicts = torch.argmax(probs, dim=-1)\n    true_labels = torch.argmax(targets, dim=-1)\n    correct = (predicts == true_labels).float().unsqueeze(1)\n    return correct\n\ndef recall_reward(\n    probs: torch.Tensor,\n    actions: torch.Tensor,\n    original_targets: torch.Tensor,\n    valid_mask: torch.Tensor\n) -> torch.Tensor:\n    valid_preds = actions * valid_mask\n    valid_targets = original_targets * valid_mask\n\n    TP = torch.sum((valid_preds * valid_targets), dim=-1)\n    FN = torch.sum(((1 - valid_preds) * valid_targets), dim=-1)\n\n    eps = 1e-8\n    recall = TP / (TP + FN + eps)\n    return recall.detach().unsqueeze(1)\n\ndef compute_metrics(p):\n    predictions, labels = p\n    labels = labels.reshape(-1)\n    if args.problem_type == 'single_label_classification':\n        preds = np.argmax(predictions, axis=1)\n        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')\n        accuracy = accuracy_score(labels, preds)\n        return {\n            'accuracy': accuracy,\n            'precision': precision,\n            'recall': recall,\n            'f1': f1,\n        }\n\n    elif args.problem_type == 'multi_label_classification':\n        
predictions = predictions.reshape(-1)\n        preds = (predictions > 0.5).astype(int)\n        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')\n        accuracy = accuracy_score(labels, preds)\n        return {\n            'accuracy': accuracy,\n            'precision': precision,\n            'recall': recall,\n            'f1': f1,\n        }\n    else:\n        raise NotImplementedError(f\"{args.problem_type} is not implemented.\")\n\ndef main(args):\n    device = torch.device('cuda:0') if torch.cuda.is_available else torch.device('cpu')\n\n    if args.model_name is not None:\n        model = GLiClassModel.from_pretrained(args.model_name, focal_loss_alpha=args.focal_loss_alpha,\n                                                                focal_loss_gamma=args.focal_loss_gamma,\n                                                                focal_loss_reduction=args.focal_loss_reduction)\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n    else:\n        tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)\n        encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)\n\n        if args.label_model_name is not None:\n            label_model_config = AutoConfig.from_pretrained(args.label_model_name)\n\n        glicalss_config = GLiClassModelConfig(\n            encoder_config=encoder_config,\n            encoder_model=args.encoder_model_name,\n            label_model_name=args.label_model_name,\n            label_model_config=label_model_config,\n            class_token_index=len(tokenizer),\n            text_token_index=len(tokenizer)+1,\n            pooling_strategy=args.pooler_type,\n            scorer_type=args.scorer_type,\n            use_lstm=args.use_lstm,\n            focal_loss_alpha=args.focal_loss_alpha,\n            focal_loss_gamma=args.focal_loss_gamma,\n            focal_loss_reduction=args.focal_loss_reduction,\n            
labels_smoothing=args.labels_smoothing,\n            entropy_beta=args.entropy_beta,\n            kl_beta=args.kl_beta,\n            contrastive_loss_coef=args.contrastive_loss_coef,\n            normalize_features=args.normalize_features,\n            extract_text_features=args.extract_text_features,\n            architecture_type=args.architecture_type,\n            prompt_first=args.prompt_first,\n            squeeze_layers=args.squeeze_layers\n        )\n\n        glicalss_config.problem_type = args.problem_type\n\n        model = GLiClassModel(glicalss_config, from_pretrained=True)\n\n        if args.architecture_type in  {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:\n            new_words = [\"<<LABEL>>\", \"<<SEP>>\"]\n            tokenizer.add_tokens(new_words, special_tokens=True)\n            model.resize_token_embeddings(len(tokenizer))\n\n    if args.set_value_model:\n        value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)\n        value_model.resize_token_embeddings(len(tokenizer))\n    else:\n        value_model = None\n\n    if args.reference_model is not None:\n        refrence_model = GLiClassModel.from_pretrained(args.reference_model)\n        reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)\n        reference_pipe = ZeroShotClassificationPipeline(refrence_model, reference_tokenizer, \n                                                                classification_type='multi-label', \n                                                                progress_bar=False, device=device)\n    else:\n        reference_pipe = None\n\n    if args.label_model_name is not None:\n        labels_tokenizer = AutoTokenizer.from_pretrained(args.label_model_name)\n    else:\n        labels_tokenizer = None\n\n    model.to(device)\n        \n    with open(args.data_path, 'r') as f:\n        data = json.load(f)[:]\n    init_ld = len(data)*1\n\n    print('Dataset 
size:', len(data))\n    random.shuffle(data)    \n    print('Dataset is shuffled...')\n\n    train_data = data[:int(len(data)*0.9)]\n    test_data = data[int(len(data)*0.9):]\n\n    print('Dataset is splitted...')\n\n    train_dataset = GLiClassDataset(train_data, tokenizer, args.max_length, \n                                    args.problem_type, args.architecture_type, \n                                    args.prompt_first, labels_tokenizer=labels_tokenizer)\n    test_dataset = GLiClassDataset(test_data, tokenizer, args.max_length, args.problem_type, \n                                        args.architecture_type, args.prompt_first,\n                                        labels_tokenizer = labels_tokenizer)\n\n    data_collator = DataCollatorWithPadding(device=device)\n\n    compute_metrics_func = compute_metrics if args.use_compute_metrics else None\n\n    training_args = RLTrainerConfig(\n        output_dir=args.save_path,\n        learning_rate=args.encoder_lr,\n        weight_decay=args.encoder_weight_decay,\n        others_lr=args.others_lr,\n        others_weight_decay=args.others_weight_decay,\n        lr_scheduler_type=args.lr_scheduler_type,\n        warmup_ratio=args.warmup_ratio,\n        per_device_train_batch_size=args.batch_size,\n        per_device_eval_batch_size=args.batch_size,\n        num_train_epochs=args.num_epochs,\n        evaluation_strategy=\"epoch\",\n        save_steps = args.save_steps,\n        save_total_limit=args.save_total_limit,\n        dataloader_num_workers = args.num_workers,\n        logging_steps=100,\n        use_cpu = False,\n        report_to=\"none\",\n        fp16=args.fp16,\n        cliprange=args.clip_range,\n        num_rl_iters=args.num_rl_iters\n        )\n\n    # Handle version differences between transformers v4 and v5\n    trainer_kwargs = {\n        \"model\": model,\n        \"value_model\": value_model,\n        \"reference_model\": reference_pipe,\n        \"args\": training_args,\n        
\"train_dataset\": train_dataset,\n        \"eval_dataset\": test_dataset,\n        \"data_collator\": data_collator,\n        \"compute_metrics\": compute_metrics_func,\n        \"reward_components\": {\n            'micro_f1': default_f1_reward,\n        },\n    }\n\n    if version.parse(transformers.__version__) < version.parse(\"5.0.0\"):\n        trainer_kwargs[\"tokenizer\"] = tokenizer\n    else:\n        trainer_kwargs[\"processing_class\"] = tokenizer\n\n    trainer = RLTrainer(**trainer_kwargs)\n    trainer.train()\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--model_name', type=str, default= \"knowledgator/gliclass-modern-base-v2.0-init\")\n    parser.add_argument('--encoder_model_name', type=str, default = 'microsoft/deberta-v3-small')\n    parser.add_argument('--label_model_name', type=str, default = \"BAAI/bge-small-en-v1.5\")\n    parser.add_argument('--reference_model', type=str, default = None)\n    parser.add_argument('--set_value_model', type=bool, default = True)\n    parser.add_argument('--save_path', type=str, default = 'models/')\n    parser.add_argument('--data_path', type=str, default = 'data/zero-cats.json')\n    parser.add_argument('--problem_type', type=str, default='multi_label_classification')\n    parser.add_argument('--pooler_type', type=str, default='avg')\n    parser.add_argument('--scorer_type', type=str, default='simple')\n    parser.add_argument('--architecture_type', type=str, default='uni-encoder')\n    parser.add_argument('--normalize_features', type=bool, default=False)\n    parser.add_argument('--extract_text_features', type=bool, default=False)\n    parser.add_argument('--prompt_first', type=bool, default=True)\n    parser.add_argument('--use_lstm', type=bool, default=False)\n    parser.add_argument('--squeeze_layers', type=bool, default=False)\n    parser.add_argument('--num_epochs', type=int, default=1)\n    parser.add_argument('--batch_size', type=int, default=32)\n    
parser.add_argument('--encoder_lr', type=float, default=2e-6)\n    parser.add_argument('--others_lr', type=float, default=3e-6)\n    parser.add_argument('--encoder_weight_decay', type=float, default=0.01)\n    parser.add_argument('--others_weight_decay', type=float, default=0.01)\n    parser.add_argument('--warmup_ratio', type=float, default=0.05)\n    parser.add_argument('--lr_scheduler_type', type=str, default='linear')\n    parser.add_argument('--focal_loss_alpha', type=float, default=-1)\n    parser.add_argument('--focal_loss_gamma', type=float, default=-1)\n    parser.add_argument('--focal_loss_reduction', type=str, default='none', choices=['none', 'mean', 'sum'])\n    parser.add_argument('--labels_smoothing', type=float, default=-1)\n    parser.add_argument('--entropy_beta', type=float, default=-1)\n    parser.add_argument('--kl_beta', type=float, default=0.1)\n    parser.add_argument('--clip_range', type=float, default=0.2)\n    parser.add_argument('--num_rl_iters', type=int, default=2)\n    parser.add_argument('--contrastive_loss_coef', type=float, default=0.)\n    parser.add_argument('--max_length', type=int, default=2048)\n    parser.add_argument('--save_steps', type=int, default=300)\n    parser.add_argument('--save_total_limit', type=int, default=3)\n    parser.add_argument('--num_workers', type=int, default=12)\n    parser.add_argument('--fp16', type=bool, default=False)\n    parser.add_argument('--use_compute_metrics', type=bool, default=False)\n    args = parser.parse_args()\n\n    main(args)\n"
  }
]