Repository: Knowledgator/GLiClass
Branch: main
Commit: 6ccf83fe5130
Files: 35
Total size: 364.2 KB

Directory structure:
gitextract_2n3b6llp/

├── .github/
│   └── workflows/
│       ├── release.yaml
│       └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── gliclass/
│   ├── __init__.py
│   ├── config.py
│   ├── data_processing.py
│   ├── layers.py
│   ├── loss_functions.py
│   ├── model.py
│   ├── ops.py
│   ├── pipeline.py
│   ├── poolings.py
│   ├── scorers.py
│   ├── serve/
│   │   ├── __init__.py
│   │   ├── __main__.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── memory.py
│   │   └── server.py
│   ├── training.py
│   └── utils.py
├── notebooks/
│   └── finetuning.ipynb
├── pyproject.toml
├── serve_configs/
│   └── serve_config.yaml
├── test_gliclass.py
├── tests/
│   ├── test_data_processing.py
│   ├── test_loss_functions.py
│   ├── test_poolings.py
│   ├── test_scorers.py
│   └── test_utils.py
├── train.py
└── train_rl.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/release.yaml
================================================
name: Release GLiClass to PyPI

on:
  push:
    tags:
      - 'v*'  # Trigger on version tags (e.g., v1.0.0, v2.1.3)

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build:
    name: Build distribution 📦
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v6
      with:
        persist-credentials: false
    - name: Set up Python
      uses: actions/setup-python@v6
      with:
        python-version: "3.x"
    - name: Install pypa/build
      run: >-
        python3 -m
        pip install
        build
        --user
    - name: Build a binary wheel and a source tarball
      run: python3 -m build
    - name: Store the distribution packages
      uses: actions/upload-artifact@v5
      with:
        name: python-package-distributions
        path: dist/

  publish-to-pypi:
    name: >-
      Publish Python 🐍 distribution 📦 to PyPI
    if: startsWith(github.ref, 'refs/tags/')  # only publish to PyPI on tag pushes
    needs:
    - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/project/gliclass/
    permissions:
      id-token: write  # IMPORTANT: mandatory for trusted publishing

    steps:
    - name: Checkout code
      uses: actions/checkout@v6
      with:
        fetch-depth: 0  # Fetch all history to check branches
    - name: Verify tag is on main branch
      run: |
        if ! git branch -r --contains ${{ github.ref_name }} | grep -q 'origin/main'; then
          echo "Error: Tag ${{ github.ref_name }} is not on the main branch"
          exit 1
        fi
        echo "✓ Tag ${{ github.ref_name }} is on main branch"
    - name: Download all the dists
      uses: actions/download-artifact@v6
      with:
        name: python-package-distributions
        path: dist/
    - name: Publish distribution 📦 to PyPI
      uses: pypa/gh-action-pypi-publish@release/v1

================================================
FILE: .github/workflows/tests.yml
================================================
name: Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test:
    name: pytest (Python ${{ matrix.python-version }})
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Cache pip
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-py${{ matrix.python-version }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
          pip install pytest pytest-asyncio

      - name: Run pytest
        run: pytest -v --tb=short

  lint:
    name: ruff
    runs-on: ubuntu-latest
    steps:
      - name: Check out repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install ruff
        run: pip install ruff

      - name: ruff check
        run: ruff check gliclass

      - name: ruff format --check
        run: ruff format --check gliclass


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

#custom
models/
wandb/
gradio_cached_examples/
test.ipynb
demo1.py
.gradio/
uv.lock

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# ⭐ GLiClass: Generalist and Lightweight Model for Sequence Classification

**GLiClass** is an efficient zero-shot sequence classification model inspired by the [GLiNER](https://github.com/urchade/GLiNER/tree/main) framework. It achieves performance comparable to traditional cross-encoder models while being significantly more computationally efficient: because it classifies a text against all candidate labels in a single forward pass, it returns results approximately **10 times faster**.

<p align="center">
    <a href="https://medium.com/@knowledgrator/pushing-zero-shot-classification-to-the-limit-696a2403032f">📄 Blog</a>
    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>
    <a href="https://discord.gg/dkyeAgs9DG">📢 Discord</a>
    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>
    <a href="https://huggingface.co/spaces/knowledgator/GLiClass_SandBox">📺 Demo</a>
    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>
    <a href="https://huggingface.co/models?sort=trending&search=gliclass">🤗 Available models</a>
    <span>&nbsp;&nbsp;•&nbsp;&nbsp;</span>
    <a href="https://colab.research.google.com/github/Knowledgator/GLiClass/blob/main/notebooks/finetuning.ipynb">
        <img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />
    </a>
</p>

### 🚀 Quick Start

Install GLiClass easily using pip:

```bash
pip install gliclass
```

#### Install from Source

Clone and install directly from GitHub:

```bash
git clone https://github.com/Knowledgator/GLiClass
cd GLiClass

python -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate

pip install .
```

Verify your installation:

```python
import gliclass
print(gliclass.__version__)
```

### 🧑‍💻 Usage Example

```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")

pipeline = ZeroShotClassificationPipeline(
    model, tokenizer, classification_type='multi-label', device='cuda:0'
)

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(f"{result['label']} => {result['score']:.3f}")
```

### 🔥 New Features

#### Hierarchical Labels

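GLiClass now supports hierarchical label structures. Pass a dictionary that maps each parent category to its child labels; flat results name each label in dot notation (`parent.child`):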

```python
hierarchical_labels = {
    "sentiment": ["positive", "negative", "neutral"],
    "topic": ["product", "service", "shipping"]
}

text = "The product quality is amazing but delivery was slow"
results = pipeline(text, hierarchical_labels, threshold=0.5)[0]

for result in results:
    print(f"{result['label']} => {result['score']:.3f}")
# Output:
# sentiment.positive => 0.892
# topic.product => 0.921
# topic.shipping => 0.763
```

Get hierarchical output matching your input structure:

```python
results = pipeline(text, hierarchical_labels, return_hierarchical=True)[0]
print(results)
# Output:
# {
#     "sentiment": {"positive": 0.892, "negative": 0.051, "neutral": 0.124},
#     "topic": {"product": 0.921, "service": 0.153, "shipping": 0.763}
# }
```
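The mapping between the flat dot-notation labels and the nested output can be sketched in plain Python. `flatten_labels` and `to_hierarchical` are hypothetical helpers written for illustration; they are not functions exported by `gliclass`, and the sketch assumes a single level of nesting as in the example above.

```python
from typing import Dict, List


def flatten_labels(hierarchy: Dict[str, List[str]], sep: str = ".") -> List[str]:
    """Turn {"sentiment": ["positive", ...]} into ["sentiment.positive", ...]."""
    return [
        f"{parent}{sep}{child}"
        for parent, children in hierarchy.items()
        for child in children
    ]


def to_hierarchical(scores: Dict[str, float], sep: str = ".") -> Dict[str, Dict[str, float]]:
    """Regroup flat {"parent.child": score} pairs into {"parent": {"child": score}}."""
    nested: Dict[str, Dict[str, float]] = {}
    for flat_label, score in scores.items():
        # Split on the first separator only, so child names may themselves contain dots.
        parent, child = flat_label.split(sep, 1)
        nested.setdefault(parent, {})[child] = score
    return nested


hierarchical_labels = {
    "sentiment": ["positive", "negative", "neutral"],
    "topic": ["product", "service", "shipping"],
}
flat = flatten_labels(hierarchical_labels)
print(flat[:2])  # ['sentiment.positive', 'sentiment.negative']
print(to_hierarchical({"sentiment.positive": 0.892, "topic.product": 0.921}))
# {'sentiment': {'positive': 0.892}, 'topic': {'product': 0.921}}
```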

#### Few-Shot Examples

Improve classification accuracy with in-context examples using the `<<EXAMPLE>>` token:

```python
examples = [
    {
        "text": "Love this item, great quality!",
        "labels": ["positive", "product"]
    },
    {
        "text": "Customer support was unhelpful",
        "labels": ["negative", "service"]
    }
]

text = "Fast delivery and the item works perfectly!"
labels = ["positive", "negative", "product", "service", "shipping"]

results = pipeline(text, labels, examples=examples, threshold=0.5)[0]

for result in results:
    print(f"{result['label']} => {result['score']:.3f}")
```

#### Task Description Prompts

Add custom prompts to guide the classification task:

```python
text = "The battery life on this phone is incredible"
labels = ["positive", "negative", "neutral"]

results = pipeline(
    text,
    labels,
    prompt="Classify the sentiment of this product review:",
    threshold=0.5
)[0]
```

Use per-text prompts for batch processing:

```python
texts = ["Review about electronics", "Review about clothing"]
prompts = [
    "Analyze this electronics review:",
    "Analyze this clothing review:"
]

results = pipeline(texts, labels, prompt=prompts)
```
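When `prompt` is a list, it presumably needs exactly one entry per text. The sketch below shows one way a single prompt could be shared across a batch and a prompt list checked against it; `normalize_prompts` is a hypothetical helper, not part of the `gliclass` API.

```python
from typing import List, Optional, Union


def normalize_prompts(
    texts: Union[str, List[str]], prompt: Optional[Union[str, List[str]]]
) -> List[Optional[str]]:
    """Return exactly one prompt (or None) per input text."""
    if isinstance(texts, str):
        texts = [texts]
    if prompt is None or isinstance(prompt, str):
        # A single prompt (or no prompt at all) is shared by every text.
        return [prompt] * len(texts)
    if len(prompt) != len(texts):
        raise ValueError(f"Got {len(prompt)} prompts for {len(texts)} texts")
    return list(prompt)


texts = ["Review about electronics", "Review about clothing"]
print(normalize_prompts(texts, "Classify the sentiment:"))
# ['Classify the sentiment:', 'Classify the sentiment:']
```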

#### Long Document Classification

Process long documents with automatic text chunking:

```python
from gliclass import ZeroShotClassificationWithChunkingPipeline

chunking_pipeline = ZeroShotClassificationWithChunkingPipeline(
    model,
    tokenizer,
    text_chunk_size=8192,
    text_chunk_overlap=256,
    labels_chunk_size=8
)

long_document = "..."  # Very long text
labels = ["category1", "category2", "category3"]

results = chunking_pipeline(long_document, labels, threshold=0.5)
```
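The parameter names suggest a sliding window over the text with overlapping chunks, plus fixed-size groups of labels. The sketch below illustrates that general technique on any sequence (the README does not say whether `text_chunk_size` counts characters or tokens). It is an assumption about the approach, not the pipeline's actual code, and merging chunk scores by taking each label's maximum is likewise only one plausible choice.

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def sliding_windows(tokens: Sequence[T], size: int, overlap: int) -> List[List[T]]:
    """Split tokens into windows of `size`; consecutive windows share `overlap` tokens."""
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")
    step = size - overlap
    # Stop once the remaining tail is already covered by the previous window's overlap.
    last_start = max(len(tokens) - overlap, 1)
    return [list(tokens[start:start + size]) for start in range(0, last_start, step)]


def chunk_labels(labels: Sequence[str], size: int) -> List[List[str]]:
    """Group candidate labels into batches of at most `size`."""
    return [list(labels[i:i + size]) for i in range(0, len(labels), size)]


def merge_chunk_scores(per_chunk: List[Dict[str, float]]) -> Dict[str, float]:
    """Keep each label's highest score across all chunks."""
    merged: Dict[str, float] = {}
    for scores in per_chunk:
        for label, score in scores.items():
            merged[label] = max(score, merged.get(label, float("-inf")))
    return merged


print(sliding_windows(list(range(10)), size=4, overlap=1))
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
print(chunk_labels(["a", "b", "c"], size=2))  # [['a', 'b'], ['c']]
print(merge_chunk_scores([{"x": 0.2, "y": 0.7}, {"x": 0.9}]))  # {'x': 0.9, 'y': 0.7}
```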

### 🌟 Retrieval-Augmented Classification (RAC)

Newer models trained with retrieval-augmented classification, such as [this model](https://huggingface.co/knowledgator/gliclass-base-v2.0-rac-init), accept labeled examples that improve classification accuracy:

```python
example = {
    "text": "A new machine learning platform automates complex data workflows but faces integration issues.",
    "all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
    "true_labels": ["AI", "integration", "automation"]
}

text = "The new AI-powered tool streamlines data analysis but has limited integration capabilities."
labels = ["AI", "automation", "data_analysis", "usability", "integration"]

results = pipeline(text, labels, threshold=0.1, rac_examples=[example])[0]

for predict in results:
    print(f"{predict['label']} => {predict['score']:.3f}")
```

### 🚀 Production Serving

Deploy GLiClass with Ray Serve for production workloads with dynamic batching and memory-aware processing.

#### Installation

```bash
pip install "gliclass[serve]"
```

#### Quick Start

```bash
# Default model
python -m gliclass.serve

# Specify model and port
python -m gliclass.serve --model knowledgator/gliclass-edge-v3.0 --port 8000

# With config file
python -m gliclass.serve --config serve_configs/serve_config.yaml
```

#### Python Client

```python
from gliclass.serve import GLiClassClient

client = GLiClassClient(url="http://localhost:8000/gliclass")

result = client.classify(
    text="This is a great product!",
    labels=["positive", "negative", "neutral"],
    threshold=0.3,
)
print(result)  # [{"label": "positive", "score": 0.95}, ...]
```

#### HTTP API

The HTTP endpoint processes one text per request.

```bash
curl -X POST http://localhost:8000/gliclass \
  -H "Content-Type: application/json" \
  -d '{
    "texts": "This is a great product!",
    "labels": ["positive", "negative", "neutral"],
    "threshold": 0.3
  }'

# Response: [{"label": "positive", "score": 0.95}, ...]
```

**Note:** For batch processing multiple texts, use the `ZeroShotClassificationPipeline` directly instead of the serving API.

See `serve_configs/serve_config.yaml` for full configuration options.
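Without the `gliclass[serve]` client installed, the same request can be built with the Python standard library. This sketch mirrors the curl call above, copying its URL and JSON field names; `build_request` and `classify` are hypothetical helpers, and actually sending the request requires a running server.

```python
import json
import urllib.request
from typing import List


def build_request(
    text: str,
    labels: List[str],
    threshold: float = 0.3,
    url: str = "http://localhost:8000/gliclass",
) -> urllib.request.Request:
    """Build a POST request with the same JSON body as the curl example."""
    payload = {"texts": text, "labels": labels, "threshold": threshold}
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def classify(text: str, labels: List[str], threshold: float = 0.3, timeout: float = 30.0):
    """Send one text to the server and return the decoded JSON response."""
    req = build_request(text, labels, threshold)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


req = build_request("This is a great product!", ["positive", "negative", "neutral"])
print(req.get_method(), req.full_url)  # POST http://localhost:8000/gliclass
# classify("This is a great product!", ["positive", "negative", "neutral"])  # needs a live server
```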

### 🎯 Key Use Cases

- **Sentiment Analysis:** Rapidly classify texts as positive, negative, or neutral.
- **Document Classification:** Efficiently organize and categorize large document collections.
- **Search Results Re-ranking:** Improve relevance and precision by reranking search outputs.
- **News Categorization:** Automatically tag and organize news articles into predefined categories.
- **Fact Checking:** Quickly validate and categorize statements based on factual accuracy.

### 🛠️ How to Train

Prepare your training data as follows:

```json
[
  {"text": "Sample text.", "all_labels": ["sports", "science", "business"], "true_labels": ["sports"]},
  ...
]
```

Optionally, specify confidence scores explicitly:

```json
[
  {"text": "Sample text.", "all_labels": ["sports", "science"], "true_labels": {"sports": 0.9}},
  ...
]
```

Refer to the `train.py` script to set up training from scratch or to fine-tune existing models.
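Both label formats reduce to one target value per candidate label. The sketch below shows one plausible conversion: a plain `true_labels` list becomes targets of 1.0, while a dict supplies its own confidence. `record_to_targets` is hypothetical and is not taken from `train.py`; in particular, giving unlisted labels a target of 0.0 is an assumption.

```python
from typing import Dict, List, Union

Record = Dict[str, Union[str, List[str], Dict[str, float]]]


def record_to_targets(record: Record) -> List[float]:
    """Return one target per entry of record["all_labels"], in the same order."""
    all_labels = record["all_labels"]
    true_labels = record["true_labels"]
    if isinstance(true_labels, dict):
        confidences = dict(true_labels)  # explicit confidence scores
    else:
        confidences = {label: 1.0 for label in true_labels}  # a plain list means full confidence
    unknown = set(confidences) - set(all_labels)
    if unknown:
        raise ValueError(f"true_labels not in all_labels: {sorted(unknown)}")
    for label, score in confidences.items():
        if not 0.0 <= score <= 1.0:
            raise ValueError(f"confidence for {label!r} must be in [0, 1], got {score}")
    # Labels missing from true_labels are treated as negatives (assumption).
    return [confidences.get(label, 0.0) for label in all_labels]


print(record_to_targets({"text": "Sample text.", "all_labels": ["sports", "science", "business"],
                         "true_labels": ["sports"]}))  # [1.0, 0.0, 0.0]
print(record_to_targets({"text": "Sample text.", "all_labels": ["sports", "science"],
                         "true_labels": {"sports": 0.9}}))  # [0.9, 0.0]
```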

### ⚙️ Advanced Configuration

#### Architecture Types

GLiClass supports multiple architecture types:

- **uni-encoder**: Single encoder for both text and labels (default, most efficient)
- **bi-encoder**: Separate encoders for text and labels
- **bi-encoder-fused**: Bi-encoder with label embeddings fused into text encoding
- **encoder-decoder**: Encoder-decoder architecture for sequence-to-sequence tasks

```python
from gliclass import GLiClassBiEncoder

# Load a bi-encoder model
model = GLiClassBiEncoder.from_pretrained("knowledgator/gliclass-biencoder-v1.0")
```

#### Pooling Strategies

Configure how token embeddings are pooled:

- `first`: First token (CLS token)
- `avg`: Average pooling
- `max`: Max pooling
- `last`: Last token
- `sum`: Sum pooling
- `rms`: Root mean square pooling
- `abs_max`: Max of absolute values
- `abs_avg`: Average of absolute values

```python
from gliclass import GLiClassModelConfig

config = GLiClassModelConfig(
    pooling_strategy='avg',
    class_token_pooling='average'  # or 'first'
)
```
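To make the strategy names concrete, here is a plain-Python sketch of each pooling applied to the non-padding token vectors of one sequence. It reads the one-line descriptions above literally (for example, `abs_max` is taken as the maximum of absolute values). GLiClass's own implementation in `gliclass/poolings.py` works on batched tensors and may differ in details such as masking.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]


def _columns(tokens: List[Vector]) -> List[Vector]:
    # Transpose so each entry holds one hidden dimension across all tokens.
    return [list(col) for col in zip(*tokens)]


POOLINGS: Dict[str, Callable[[List[Vector]], Vector]] = {
    "first": lambda t: list(t[0]),
    "last": lambda t: list(t[-1]),
    "avg": lambda t: [sum(c) / len(c) for c in _columns(t)],
    "max": lambda t: [max(c) for c in _columns(t)],
    "sum": lambda t: [sum(c) for c in _columns(t)],
    "rms": lambda t: [math.sqrt(sum(x * x for x in c) / len(c)) for c in _columns(t)],
    "abs_max": lambda t: [max(abs(x) for x in c) for c in _columns(t)],
    "abs_avg": lambda t: [sum(abs(x) for x in c) / len(c) for c in _columns(t)],
}


def pool(tokens: List[Vector], mask: List[int], strategy: str) -> Vector:
    """Pool token vectors, ignoring positions where mask == 0 (padding)."""
    kept = [vec for vec, m in zip(tokens, mask) if m]
    return POOLINGS[strategy](kept)


tokens = [[1.0, -4.0], [3.0, 2.0], [9.0, 9.0]]  # the last row is padding
mask = [1, 1, 0]
print(pool(tokens, mask, "avg"))      # [2.0, -1.0]
print(pool(tokens, mask, "abs_max"))  # [3.0, 4.0]
```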

#### Scoring Mechanisms

Choose different scoring mechanisms for classification:

- `simple`: Dot product (fastest)
- `weighted-dot`: Weighted dot product with learned projections
- `mlp`: Multi-layer perceptron scorer
- `hopfield`: Hopfield network-based scorer

```python
config = GLiClassModelConfig(
    scorer_type='mlp'
)
```
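The `simple` scorer is described as a dot product. The sketch below illustrates that idea under stated assumptions: one pooled text vector, one vector per label, and an independent sigmoid per label followed by threshold filtering for multi-label output. The sigmoid and the result format are assumptions modeled on the pipeline examples above, not GLiClass internals, and the vectors are made-up toy values.

```python
import math
from typing import Dict, List


def simple_scores(text_vec: List[float], label_vecs: Dict[str, List[float]]) -> Dict[str, float]:
    """Dot product of the text vector with each label vector, squashed to (0, 1) by a sigmoid."""
    scores = {}
    for label, label_vec in label_vecs.items():
        logit = sum(a * b for a, b in zip(text_vec, label_vec))
        scores[label] = 1.0 / (1.0 + math.exp(-logit))
    return scores


def above_threshold(scores: Dict[str, float], threshold: float = 0.5) -> List[dict]:
    """Multi-label selection: keep every label whose score reaches the threshold, best first."""
    kept = [{"label": label, "score": s} for label, s in scores.items() if s >= threshold]
    return sorted(kept, key=lambda r: r["score"], reverse=True)


# Toy 2-dimensional embeddings, made up for illustration.
text_vec = [0.5, 1.0]
label_vecs = {"travel": [2.0, 1.0], "sport": [-1.0, -0.5], "dreams": [0.0, 0.0]}
for result in above_threshold(simple_scores(text_vec, label_vecs)):
    print(f"{result['label']} => {result['score']:.3f}")
# travel => 0.881
# dreams => 0.500
```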

---

### Flash Attention Backends

GLiClass supports optional flash attention backends for faster inference.

#### Install

```bash
pip install flashdeberta   # DeBERTa v2
pip install turbot5        # T5 / mT5
```

---

#### FlashDeBERTa (DeBERTa v2)

Enable via environment variable:

```bash
export USE_FLASHDEBERTA=1
```

If `flashdeberta` is installed, DeBERTa v2 models will use `FlashDebertaV2Model`.
Otherwise, GLiClass falls back to `DebertaV2Model`.

---

#### TurboT5 (T5 / mT5)

Enable via environment variable:

```bash
export TURBOT5_ATTN_TYPE=triton-basic
```

If `turbot5` is installed, T5 / mT5 models will use `FlashT5EncoderModel`.
Otherwise, GLiClass falls back to `T5EncoderModel`.

Notes:
- Flash backends are **optional**
- Once the package is installed and the environment variable above is set, the backend is picked up automatically
- No code changes required
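The fallback behaviour described above follows a common optional-dependency pattern: check the environment variable, check whether the package can be imported, and otherwise fall back to the standard `transformers` class. The sketch below only returns class names and illustrates that pattern; it is not GLiClass's actual loading code. Treating any `USE_FLASHDEBERTA` value other than empty or `"0"` as enabled, and requiring `TURBOT5_ATTN_TYPE` to be set for TurboT5, are assumptions.

```python
import importlib.util
import os
from typing import Callable, Mapping, Optional


def _installed(package: str) -> bool:
    # find_spec returns None when a top-level package cannot be imported.
    return importlib.util.find_spec(package) is not None


def deberta_backend(env: Optional[Mapping[str, str]] = None,
                    is_installed: Callable[[str], bool] = _installed) -> str:
    env = os.environ if env is None else env
    wanted = env.get("USE_FLASHDEBERTA", "") not in ("", "0")
    if wanted and is_installed("flashdeberta"):
        return "FlashDebertaV2Model"
    return "DebertaV2Model"


def t5_backend(env: Optional[Mapping[str, str]] = None,
               is_installed: Callable[[str], bool] = _installed) -> str:
    env = os.environ if env is None else env
    if env.get("TURBOT5_ATTN_TYPE") and is_installed("turbot5"):
        return "FlashT5EncoderModel"
    return "T5EncoderModel"


# Inject `is_installed` so the outcome does not depend on the local environment.
print(deberta_backend({"USE_FLASHDEBERTA": "1"}, is_installed=lambda pkg: True))   # FlashDebertaV2Model
print(deberta_backend({"USE_FLASHDEBERTA": "1"}, is_installed=lambda pkg: False))  # DebertaV2Model
```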

## 📚 Citations

If you find GLiClass useful in your research or project, please cite our paper:


```bibtex
@misc{stepanov2025gliclassgeneralistlightweightmodel,
      title={GLiClass: Generalist Lightweight Model for Sequence Classification Tasks}, 
      author={Ihor Stepanov and Mykhailo Shtopko and Dmytro Vodianytskyi and Oleksandr Lukashov and Alexander Yavorskyi and Mykyta Yaroshenko},
      year={2025},
      eprint={2508.07662},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2508.07662}, 
}
```


================================================
FILE: demo.py
================================================
"""
GLiClass Enhanced Demo with Advanced Features

Features:
- Task description prompts
- Hierarchical label inputs (JSON format)
- Few-shot examples
- Hierarchical output structure
- Label descriptions
"""

import json
from typing import Dict, List, Any, Union, Optional
import gradio as gr
import torch
from transformers import AutoTokenizer

from gliclass import GLiClassModel, ZeroShotClassificationPipeline

# Initialize model and pipeline
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model_path = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

pipeline = ZeroShotClassificationPipeline(
    model, tokenizer, 
    classification_type='multi-label', 
    device=device
)

# ============== Example Texts ==============

TEXT_PRODUCT_REVIEW = """
I recently purchased the Sony WH-1000XM4 Wireless Noise-Canceling Headphones from Amazon and I must say, I'm thoroughly impressed. The package arrived in New York within 2 days, thanks to Amazon Prime's expedited shipping.

The headphones themselves are remarkable. The noise-canceling feature works like a charm in the bustling city environment, and the 30-hour battery life means I don't have to charge them every day. Connecting them to my Samsung Galaxy S21 was a breeze, and the sound quality is second to none.

I also appreciated the customer service from Amazon when I had a question about the warranty. They responded within an hour and provided all the information I needed.

However, the headphones did not come with a hard case, which was listed in the product description. I contacted Amazon, and they offered a 10% discount on my next purchase as an apology.

Overall, I'd give these headphones a 4.5/5 rating and highly recommend them to anyone looking for top-notch quality in both product and service.
"""

TEXT_TECH_COMPANIES = """
Apple Inc. is an American multinational technology company headquartered in Cupertino, California. Apple is the world's largest technology company by revenue, with US$394.3 billion in 2022 revenue. As of March 2023, Apple is the world's biggest company by market capitalization.

Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975 to develop and sell BASIC interpreters for the Altair 8800. During his career at Microsoft, Gates held the positions of chairman, chief executive officer, president and chief software architect.

Apple was founded as Apple Computer Company on April 1, 1976, by Steve Wozniak, Steve Jobs (1955–2011) and Ronald Wayne to develop and sell Wozniak's Apple I personal computer.
"""

TEXT_SCIENTIFIC = """
Several studies have reported its pharmacological activities, including anti-inflammatory, antimicrobial, and antitumoral effects. 
The effect of E-anethole was studied in the osteosarcoma MG-63 cell line, and the antiproliferative activity was evaluated by an MTT assay. 
It showed a GI50 value of 60.25 μM with apoptosis induction through the mitochondrial-mediated pathway. Additionally, it induced cell cycle arrest at the G0/G1 phase, up-regulated the expression of p53, caspase-3, and caspase-9, and down-regulated Bcl-xL expression.
"""

TEXT_RESTAURANT_REVIEW = """
We visited La Maison last Friday for our anniversary dinner. The ambiance was absolutely stunning - dim lighting, soft jazz music, and elegant table settings. Our waiter, Marcus, was incredibly attentive without being intrusive.

For appetizers, we had the truffle bruschetta and the soup of the day. Both were divine! The main courses - filet mignon for me and lobster risotto for my wife - were cooked to perfection. 

The only downside was the wait time for our desserts, which took about 25 minutes. However, the chocolate soufflé was worth the wait!

Price was on the higher side ($180 for two), but the quality justified the cost. Will definitely return!
"""

TEXT_NEWS_POLITICS = """
The Senate passed a landmark bipartisan infrastructure bill late Thursday night, allocating $1.2 trillion for roads, bridges, broadband internet, and clean energy initiatives. The vote was 69-30, with 19 Republican senators joining all Democrats in support.

President Biden called the passage "a historic investment in America's future" and urged the House to act quickly. However, progressive Democrats have signaled they won't vote for the infrastructure bill unless it's paired with a larger social spending package.

Senate Minority Leader criticized portions of the bill related to climate spending, calling them "unnecessary green new deal provisions," while environmental groups praised the clean energy investments as "a step in the right direction, but not nearly enough."
"""

TEXT_SPORTS = """
In a thrilling overtime finish, the Lakers defeated the Celtics 118-112 in Game 7 of the NBA Finals. LeBron James delivered a historic performance with 42 points, 16 rebounds, and 10 assists, securing his fifth championship ring and fourth Finals MVP award.

The game was tied at 102 with 30 seconds remaining in regulation when Marcus Smart hit a contested three-pointer. However, James answered with a driving layup at the buzzer to force overtime.

In the extra period, the Lakers outscored Boston 16-10, with Anthony Davis contributing two crucial blocks in the final minute. "This is what you dream about as a kid," James said in the post-game interview. "Playing against the Celtics, Game 7, everything on the line."
"""

TEXT_MOVIE_REVIEW = """
Christopher Nolan's "Oppenheimer" is a masterwork of biographical cinema that demands to be seen on the largest screen possible. Cillian Murphy delivers a career-defining performance as J. Robert Oppenheimer, capturing both the brilliance and moral anguish of the father of the atomic bomb.

The film's nonlinear structure, weaving between the Manhattan Project, the 1954 security hearing, and the 1959 Lewis Strauss confirmation hearing, could have been confusing. Instead, Nolan crafts a compelling narrative that builds to a devastating emotional climax.

At three hours, some viewers may find the pacing challenging, particularly in the courtroom sequences. However, the technical achievements - Ludwig Göransson's haunting score, Hoyte van Hoytema's IMAX cinematography - make this an unmissable theatrical experience. Rating: 9/10
"""

TEXT_TECH_STARTUP = """
San Francisco-based AI startup Anthropic announced today it has raised $450 million in Series C funding, valuing the company at $5 billion. The round was led by Spark Capital, with participation from Google and existing investors.

Founded in 2021 by former OpenAI researchers Dario and Daniela Amodei, Anthropic has positioned itself as a leader in AI safety research. The company's Claude assistant has gained significant market share in the enterprise segment.

"This funding will accelerate our research into interpretable and steerable AI systems," said CEO Dario Amodei. "We believe safety and capability go hand in hand." The company plans to double its research team and expand internationally, with offices planned in London and Tokyo.
"""

TEXT_HEALTH_WELLNESS = """
A new study published in the Journal of the American Medical Association suggests that intermittent fasting may offer significant benefits beyond weight loss. Researchers followed 500 participants over two years and found improvements in cardiovascular health markers, insulin sensitivity, and cognitive function.

Participants who followed a 16:8 fasting protocol (eating within an 8-hour window) showed a 15% reduction in LDL cholesterol and a 20% improvement in fasting glucose levels compared to the control group.

However, experts caution that intermittent fasting isn't suitable for everyone. "Pregnant women, people with a history of eating disorders, and those with certain medical conditions should consult their doctor first," said Dr. Sarah Chen, the study's lead author. "It's not a magic solution, but for many people, it can be a sustainable approach to improving metabolic health."
"""

TEXT_TRAVEL = """
Hidden among the limestone karsts of Ha Long Bay, Cat Ba Island offers travelers an authentic Vietnamese experience away from the tourist crowds. We spent five days exploring this gem and discovered why it's becoming a favorite among backpackers and adventure seekers.

The island's national park features challenging hikes through tropical rainforest, with the trek to the peak of Ngu Lam offering panoramic views of the bay. We also kayaked through hidden lagoons and explored caves that few tourists ever see.

Accommodation ranges from basic hostels ($8/night) to comfortable eco-resorts ($60/night). The seafood is incredibly fresh - we had the best grilled squid of our lives at a family-run restaurant in Cat Ba Town for just $5. Pro tip: rent a motorbike to explore the quieter beaches on the island's east side.
"""

TEXT_COOKING_RECIPE = """
This Thai green curry comes together in just 30 minutes and tastes better than takeout. The secret is making your own curry paste - it takes an extra 10 minutes but the flavor difference is remarkable.

For the paste, blend together: 10 green chilies, 4 garlic cloves, 2 shallots, 1 stalk lemongrass, 1 inch galangal, handful of cilantro stems, 1 tsp cumin, 1 tsp coriander, zest of 1 lime, and 2 tbsp fish sauce. 

Heat coconut oil in a wok, fry the paste for 2 minutes until fragrant. Add chicken (or tofu), cook until browned. Pour in coconut milk, add bamboo shoots, Thai eggplant, and basil. Simmer for 15 minutes. Season with palm sugar and more fish sauce to taste.

Serve over jasmine rice with extra chilies on the side. This recipe serves 4 and can be made ahead - the flavors actually improve overnight.
"""

TEXT_FINANCIAL_ADVICE = """
With inflation running at 4.2% and the Fed signaling more rate hikes, many investors are wondering how to position their portfolios. Here's what our analysis suggests for Q4 2024.

Fixed income is finally attractive again. With 10-year Treasury yields above 4.5%, bonds offer meaningful real returns for the first time in years. We recommend increasing allocation to investment-grade corporate bonds and TIPS for inflation protection.

For equities, we're cautiously optimistic on value stocks, particularly in the energy and financial sectors. Tech valuations remain stretched despite recent pullbacks. International developed markets, especially Japan and Europe, offer better risk-reward at current levels.

Remember: past performance doesn't guarantee future results. This is general information, not personalized advice. Consult a financial advisor before making investment decisions.
"""

TEXT_ENVIRONMENTAL = """
The Great Barrier Reef experienced its sixth mass bleaching event in a decade this summer, with aerial surveys showing 91% of reefs affected. Scientists warn that without dramatic action on climate change, the world's largest coral ecosystem may not survive beyond 2050.

"We're witnessing the collapse of one of Earth's most biodiverse ecosystems in real time," said Dr. Terry Hughes of James Cook University. Water temperatures reached 2°C above the February average, causing corals to expel the symbiotic algae that give them color and nutrients.

Some researchers are experimenting with heat-resistant coral varieties and cloud-brightening technology to shade reefs. However, most scientists agree these are stopgap measures. "The only real solution is rapid decarbonization," Hughes said. "Everything else is just buying time."
"""

TEXT_EDUCATION = """
The debate over standardized testing in American schools has intensified following a new report showing significant post-pandemic learning gaps. The National Assessment of Educational Progress found that fourth-grade math scores dropped to levels not seen since 2005.

Proponents of testing argue that standardized assessments are essential for identifying struggling students and holding schools accountable. "Without data, we're flying blind," said Education Secretary Miguel Cardona. "Tests help us direct resources where they're needed most."

Critics counter that high-stakes testing narrows the curriculum and increases student stress without improving outcomes. "We're testing kids more than ever, but educational outcomes aren't improving," said education researcher Dr. Pasi Sahlberg. "Countries like Finland, which use minimal standardized testing, consistently outperform the US."
"""

TEXT_FASHION = """
Milan Fashion Week wrapped up yesterday with several surprising trends that will likely dominate fall/winter 2025. After years of quiet luxury and minimalism, designers are embracing bold maximalism - think dramatic volumes, clashing prints, and unapologetic color.

Prada's collection featured oversized coats with exaggerated shoulders paired with flowing silk pants, while Gucci returned to its pattern-mixing roots under new creative direction. Versace went full baroque with gold-embroidered gowns that would feel at home in a Renaissance painting.

Sustainability remained a talking point, with Stella McCartney showcasing a collection made entirely from recycled ocean plastic. However, critics noted that the industry still has far to go. "One sustainable collection doesn't offset the environmental impact of fast fashion," noted fashion journalist Vanessa Friedman. "The industry needs systemic change, not just good PR."
"""

TEXT_LEGAL_CASE = """
The Supreme Court agreed Monday to hear a case that could reshape the boundaries of free speech on social media platforms. The case, NetChoice v. Paxton, challenges Texas and Florida laws that prohibit large social media companies from removing certain political content.

Tech companies argue that the First Amendment protects their right to moderate content as they see fit, similar to how newspapers decide what to publish. "Forcing platforms to host speech they find objectionable is compelled speech, which the Constitution forbids," said NetChoice counsel Paul Clement.

Texas and Florida counter that social media platforms function as common carriers or public utilities and should be subject to similar non-discrimination requirements. "These companies have become the modern public square," said Texas Attorney General Ken Paxton. "They shouldn't be able to silence voices based on political viewpoint."
"""

TEXT_GAMING = """
After three years in development hell, "Hollow Eclipse" has finally launched - and it's everything fans hoped for. This action RPG from indie studio Moonlight Games delivers a haunting 40-hour adventure that rivals titles from studios with ten times the budget.

The combat system strikes a perfect balance between accessibility and depth. Basic attacks and dodges are simple to execute, but mastering the "shadow merge" mechanic - which lets you temporarily possess enemies - adds layers of strategy. Boss fights are challenging without feeling unfair, though the final boss may take even experienced players dozens of attempts.

Where the game truly shines is its atmosphere. The decaying gothic city of Velmoor is rendered in stunning hand-drawn art, and the ambient soundtrack creates constant unease. The story tackles themes of grief and memory with surprising emotional maturity. Minor technical issues (occasional frame drops, one softlock) can't diminish this achievement. Score: 9.5/10
"""

TEXT_REAL_ESTATE = """
The housing market is sending mixed signals as we enter 2025. Existing home sales fell for the third consecutive month, down 4.1% in November, yet prices continue to climb in most metropolitan areas. The median home price hit $416,000, up 3.8% year-over-year.

Low inventory remains the central issue. Many homeowners are reluctant to sell because they've locked in sub-3% mortgage rates and don't want to trade up to today's 7% rates. This "lock-in effect" has created a severe shortage of listings, particularly in the starter home category.

"We're seeing bidding wars even in this high-rate environment because there's simply nothing to buy," said economist Lawrence Yun. First-time buyers are particularly squeezed, with affordability at its worst level since 1984. Some markets, including Austin and Phoenix, are showing price corrections, but coastal cities remain stubbornly expensive.
"""

TEXT_MENTAL_HEALTH = """
Workplace burnout has reached epidemic proportions, with a new Gallup survey finding that 76% of employees experience burnout at least sometimes. But recognizing burnout isn't always straightforward - it often manifests differently than simple exhaustion.

The three hallmarks of burnout are: emotional exhaustion (feeling drained and unable to cope), depersonalization (becoming cynical and detached from work), and reduced personal accomplishment (feeling ineffective regardless of actual performance).

Recovery requires more than a vacation. "You can't just rest your way out of burnout," says psychologist Dr. Christina Maslach, who pioneered burnout research. "You need to address the root causes - usually workload, lack of control, insufficient recognition, or values conflicts." Strategies include setting firm boundaries, delegating tasks, and having honest conversations with managers about sustainable workloads. In severe cases, professional support from a therapist can help.
"""

TEXT_ASTRONOMY = """
NASA's James Webb Space Telescope has detected what may be signs of biological activity in the atmosphere of K2-18b, an exoplanet 120 light-years away. The discovery has electrified the scientific community, though researchers caution against jumping to conclusions.

The telescope's spectrometers identified dimethyl sulfide (DMS), a molecule produced almost exclusively by living organisms on Earth. Webb also confirmed the presence of methane and carbon dioxide, consistent with a water-rich atmosphere.

"This is tantalizing, but not definitive proof of life," said lead researcher Dr. Nikku Madhusudhan. "DMS could potentially be produced by unknown geological processes. We need more observations." K2-18b is a "Hycean" world - a planet with a hydrogen-rich atmosphere and potentially a liquid water ocean beneath. If confirmed, this would be humanity's first detection of a potential biosignature beyond our solar system.
"""


def parse_labels_input(labels_input: str) -> Union[List[str], Dict[str, Any]]:
    """
    Parse labels input - supports both comma-separated and JSON hierarchical format.
    
    Examples:
    - "positive, negative, neutral" -> ["positive", "negative", "neutral"]
    - '{"sentiment": ["positive", "negative"], "topic": ["food", "service"]}' -> dict
    """
    labels_input = labels_input.strip()
    
    # Try parsing as JSON first (for hierarchical labels)
    if labels_input.startswith('{'):
        try:
            return json.loads(labels_input)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format for hierarchical labels: {e}")
    
    # Otherwise, treat as comma-separated flat labels
    labels = [label.strip() for label in labels_input.split(',') if label.strip()]
    return labels


def parse_examples_input(examples_input: str) -> Optional[List[Dict[str, Any]]]:
    """
    Parse few-shot examples input (JSON format).
    
    Expected format:
    [
        {"text": "Example text 1", "labels": ["label1", "label2"]},
        {"text": "Example text 2", "labels": ["label3"]}
    ]
    """
    if not examples_input or not examples_input.strip():
        return None
    
    try:
        examples = json.loads(examples_input.strip())
        if not isinstance(examples, list):
            raise ValueError("Examples must be a JSON array")
        
        for i, ex in enumerate(examples):
            if not isinstance(ex, dict):
                raise ValueError(f"Example {i+1} must be a JSON object")
            if 'text' not in ex:
                raise ValueError(f"Example {i+1} missing 'text' field")
            if 'labels' not in ex and 'true_labels' not in ex:
                raise ValueError(f"Example {i+1} missing 'labels' (or 'true_labels') field")
        
        return examples
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format for examples: {e}")


def format_output(
    results: Union[List[Dict], Dict], 
    hierarchical: bool = False,
    output_format: str = "visual"
) -> Union[Dict[str, float], str]:
    """Format classification output for Gradio display."""
    
    if output_format == "json":
        return format_as_json(results, hierarchical)
    
    if hierarchical and isinstance(results, dict):
        # Format hierarchical output as readable string
        return format_hierarchical_dict(results)
    
    if isinstance(results, list):
        return {result['label']: float(result['score']) for result in results}
    
    return results


def format_as_json(results: Union[List[Dict], Dict], hierarchical: bool = False) -> str:
    """Format results as pretty-printed JSON string."""
    if hierarchical and isinstance(results, dict):
        # Already in hierarchical dict format
        return json.dumps(results, indent=2, ensure_ascii=False)
    
    if isinstance(results, list):
        # Convert list of predictions to structured format
        output = {
            "predictions": [
                {"label": r["label"], "score": round(r["score"], 4)}
                for r in results
            ],
            "scores": {r["label"]: round(r["score"], 4) for r in results}
        }
        return json.dumps(output, indent=2, ensure_ascii=False)
    
    return json.dumps(results, indent=2, ensure_ascii=False)


def format_hierarchical_dict(d: Dict, indent: int = 0) -> str:
    """Format hierarchical dict for display with visual score bars."""
    lines = []
    prefix = "  " * indent
    
    for key, value in d.items():
        if isinstance(value, dict):
            lines.append(f"{prefix}**{key}**:")
            lines.append(format_hierarchical_dict(value, indent + 1))
        else:
            score_bar = "█" * int(value * 20) + "░" * (20 - int(value * 20))
            lines.append(f"{prefix}{key}: {score_bar} {value:.3f}")
    
    return "\n".join(lines)


def classification(
    text: str,
    labels_input: str,
    threshold: float,
    multi_label: bool,
    prompt: str,
    examples_input: str,
    hierarchical_output: bool,
    output_format: str = "visual"
) -> Union[Dict[str, float], str]:
    """
    Perform classification with all advanced features.
    """
    try:
        # Parse labels (flat or hierarchical)
        labels = parse_labels_input(labels_input)
        
        # Parse few-shot examples
        examples = parse_examples_input(examples_input) if examples_input else None
        
        # Set classification type
        pipeline.pipe.classification_type = 'multi-label' if multi_label else 'single-label'
        
        # Prepare prompt
        task_prompt = prompt.strip() if prompt and prompt.strip() else None
        
        # Run classification
        results = pipeline(
            text, 
            labels, 
            threshold=threshold,
            examples=examples,
            prompt=task_prompt,
            return_hierarchical=hierarchical_output
        )[0]  # Single text, get first result
        
        # Format output based on the selected format. format_output also handles
        # list results when hierarchical output is requested for flat labels.
        return format_output(results, hierarchical_output, output_format)
            
    except Exception as e:
        return f"Error: {str(e)}"


# ============== Example Configurations ==============

EXAMPLES = [
    # Example 1: Basic flat labels with prompt
    [
        TEXT_PRODUCT_REVIEW,
        "product review, electronics, positive feedback, negative feedback, customer service, shipping",
        0.5,
        True,
        "Classify this customer review by topic and sentiment:",
        "",
        False,
        "visual"
    ],
    # Example 2: Hierarchical labels for restaurant review
    [
        TEXT_RESTAURANT_REVIEW,
        '''{
    "sentiment": ["positive", "negative", "mixed"],
    "aspects": ["food quality", "service", "ambiance", "price", "wait time"],
    "recommendation": ["would recommend", "would not recommend"]
}''',
        0.4,
        True,
        "Analyze this restaurant review:",
        "",
        True,
        "visual"
    ],
    # Example 3: News article with few-shot examples
    [
        TEXT_NEWS_POLITICS,
        "politics, business, technology, sports, entertainment, science, health",
        0.5,
        True,
        "Classify this news article by category:",
        '''[
    {"text": "The Federal Reserve raised interest rates by 0.25% today, citing persistent inflation concerns.", "labels": ["politics", "business"]},
    {"text": "Scientists discover a new superconducting material that works at room temperature.", "labels": ["science", "technology"]}
]''',
        False,
        "visual"
    ],
    # Example 4: Scientific classification with hierarchical output
    [
        TEXT_SCIENTIFIC,
        '''{
    "domain": ["biology", "chemistry", "medicine", "physics"],
    "research_type": ["experimental", "theoretical", "review"],
    "application": ["therapeutic", "diagnostic", "basic research"]
}''',
        0.3,
        True,
        "Classify this scientific abstract:",
        "",
        True,
        "visual"
    ],
    # Example 5: Sports article - single label
    [
        TEXT_SPORTS,
        "basketball, football, soccer, tennis, baseball, hockey, golf",
        0.5,
        False,
        "What sport is this article about?",
        "",
        False,
        "visual"
    ],
    # Example 6: Movie review with detailed sentiment (JSON output)
    [
        TEXT_MOVIE_REVIEW,
        '''{
    "overall_sentiment": ["positive", "negative", "mixed"],
    "aspects_praised": ["acting", "direction", "cinematography", "music", "story", "pacing"],
    "aspects_criticized": ["acting", "direction", "cinematography", "music", "story", "pacing"],
    "recommendation": ["must watch", "worth watching", "skip it"]
}''',
        0.35,
        True,
        "Analyze this movie review in detail:",
        "",
        True,
        "json"
    ],
    # Example 7: Tech startup news
    [
        TEXT_TECH_STARTUP,
        "funding announcement, product launch, acquisition, IPO, partnership, hiring, layoffs, legal",
        0.4,
        True,
        "What type of tech news is this?",
        "",
        False,
        "visual"
    ],
    # Example 8: Health article with hierarchical categories
    [
        TEXT_HEALTH_WELLNESS,
        '''{
    "topic": ["nutrition", "exercise", "mental health", "sleep", "medical research"],
    "content_type": ["research findings", "practical advice", "expert opinion", "warning"],
    "audience": ["general public", "healthcare professionals", "patients"]
}''',
        0.4,
        True,
        "Categorize this health article:",
        "",
        True,
        "visual"
    ],
    # Example 9: Travel content (JSON output)
    [
        TEXT_TRAVEL,
        "destination guide, hotel review, restaurant review, adventure travel, budget travel, luxury travel, travel tips",
        0.4,
        True,
        "What type of travel content is this?",
        "",
        False,
        "json"
    ],
    # Example 10: Recipe classification
    [
        TEXT_COOKING_RECIPE,
        '''{
    "cuisine": ["Thai", "Italian", "Mexican", "Indian", "Chinese", "Japanese", "French", "American"],
    "difficulty": ["easy", "medium", "hard"],
    "meal_type": ["breakfast", "lunch", "dinner", "dessert", "snack"],
    "dietary": ["vegetarian friendly", "vegan friendly", "gluten free", "dairy free", "contains meat"]
}''',
        0.35,
        True,
        "Classify this recipe:",
        "",
        True,
        "visual"
    ],
    # Example 11: Financial content with examples
    [
        TEXT_FINANCIAL_ADVICE,
        "investment advice, market analysis, personal finance, retirement planning, tax advice, economic news",
        0.4,
        True,
        "Categorize this financial content:",
        '''[
    {"text": "Here are 5 ways to maximize your 401k contributions before year end.", "labels": ["personal finance", "retirement planning", "tax advice"]},
    {"text": "The S&P 500 rose 2% today following strong jobs report.", "labels": ["market analysis", "economic news"]}
]''',
        False,
        "visual"
    ],
    # Example 12: Environmental news (JSON output)
    [
        TEXT_ENVIRONMENTAL,
        '''{
    "topic": ["climate change", "biodiversity", "pollution", "conservation", "renewable energy"],
    "tone": ["alarming", "hopeful", "neutral", "urgent"],
    "focus": ["problem description", "solutions", "policy", "research findings"]
}''',
        0.35,
        True,
        "Analyze this environmental article:",
        "",
        True,
        "json"
    ],
    # Example 13: Education debate
    [
        TEXT_EDUCATION,
        "education policy, standardized testing, curriculum, teacher issues, student welfare, technology in education, higher education",
        0.4,
        True,
        "What education topics does this article cover?",
        "",
        False,
        "visual"
    ],
    # Example 14: Fashion news with hierarchy
    [
        TEXT_FASHION,
        '''{
    "content_type": ["trend report", "designer profile", "collection review", "industry news", "sustainability"],
    "season": ["spring/summer", "fall/winter"],
    "market_segment": ["luxury", "fast fashion", "sustainable fashion", "streetwear"]
}''',
        0.4,
        True,
        "Classify this fashion article:",
        "",
        True,
        "visual"
    ],
    # Example 15: Legal case (JSON output)
    [
        TEXT_LEGAL_CASE,
        "constitutional law, criminal law, civil rights, corporate law, intellectual property, free speech, privacy",
        0.4,
        True,
        "What areas of law does this case involve?",
        "",
        False,
        "json"
    ],
    # Example 16: Gaming review with detailed analysis
    [
        TEXT_GAMING,
        '''{
    "genre": ["action", "RPG", "adventure", "puzzle", "strategy", "simulation", "sports"],
    "platform_feel": ["indie", "AAA", "mid-tier"],
    "strengths": ["gameplay", "story", "graphics", "music", "replayability"],
    "weaknesses": ["bugs", "difficulty", "length", "graphics", "story"],
    "recommendation": ["must play", "worth playing", "wait for sale", "skip"]
}''',
        0.35,
        True,
        "Analyze this game review:",
        "",
        True,
        "visual"
    ],
    # Example 17: Real estate market analysis
    [
        TEXT_REAL_ESTATE,
        "market analysis, buying advice, selling advice, investment, rental market, mortgage rates, housing policy",
        0.4,
        True,
        "What real estate topics are covered?",
        "",
        False,
        "visual"
    ],
    # Example 18: Mental health with few-shot (JSON output)
    [
        TEXT_MENTAL_HEALTH,
        '''{
    "topic": ["burnout", "anxiety", "depression", "stress management", "work-life balance"],
    "content_type": ["educational", "self-help advice", "research summary", "personal story"],
    "actionability": ["provides concrete steps", "general awareness", "seeks professional help"]
}''',
        0.35,
        True,
        "Categorize this mental health content:",
        '''[
    {"text": "Feeling overwhelmed? Try the 5-4-3-2-1 grounding technique: notice 5 things you see, 4 you hear...", "labels": ["topic.anxiety", "topic.stress management", "content_type.self-help advice", "actionability.provides concrete steps"]},
    {"text": "A new study links social media use exceeding 3 hours daily with increased rates of depression in teens.", "labels": ["topic.depression", "content_type.research summary", "actionability.general awareness"]}
]''',
        True,
        "json"
    ],
    # Example 19: Astronomy discovery
    [
        TEXT_ASTRONOMY,
        "exoplanets, astrobiology, cosmology, solar system, space exploration, telescopes, astrophysics",
        0.4,
        True,
        "What astronomy topics are discussed?",
        "",
        False,
        "visual"
    ],
    # Example 20: Tech companies - single label
    [
        TEXT_TECH_COMPANIES,
        "company profile, product announcement, financial report, industry analysis, biography, opinion piece",
        0.5,
        False,
        "What is the primary type of this article?",
        "",
        False,
        "visual"
    ],
]


# ============== Gradio Interface ==============

with gr.Blocks(
    title="GLiClass Advanced Demo",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="slate",
    )
) as demo:
    
    gr.Markdown("""
    # 🏷️ GLiClass Advanced Zero-Shot Classification
    
    Enhanced demo featuring **prompts**, **hierarchical labels**, **few-shot examples**, and **structured outputs**.
    """)
    
    with gr.Accordion("📖 How to Use This Demo", open=False):
        gr.Markdown("""
        ## Features Overview
        
        ### 1. Task Description Prompts
        Add a natural language description of the classification task to guide the model.
        
        **Example:** `"Classify this customer review by sentiment and topic:"`
        
        ---
        
        ### 2. Hierarchical Labels (JSON Format)
        Structure your labels in categories for organized classification:
        
        ```json
        {
            "sentiment": ["positive", "negative", "neutral"],
            "topic": ["product", "service", "shipping"],
            "urgency": ["high", "medium", "low"]
        }
        ```
        
        Or use simple comma-separated labels: `positive, negative, neutral`
        
        ---
        
        ### 3. Few-Shot Examples
        Provide examples to guide the model's understanding:
        
        ```json
        [
            {"text": "Great product, love it!", "labels": ["positive", "product"]},
            {"text": "Shipping was delayed by 2 weeks", "labels": ["negative", "shipping"]}
        ]
        ```
        
        ---
        
        ### 4. Hierarchical Output
        When enabled with hierarchical labels, returns structured scores matching your input format.
        """)
    
    with gr.Accordion("💻 Code Example", open=False):
        gr.Code(
            '''from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1")

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

# Basic usage
text = "The product quality is amazing but delivery was slow"
labels = ["positive", "negative", "product", "shipping"]
results = pipeline(text, labels, threshold=0.5)[0]

# With hierarchical labels
hierarchical_labels = {
    "sentiment": ["positive", "negative", "neutral"],
    "topic": ["product", "service", "shipping"]
}

results = pipeline(
    text, 
    hierarchical_labels,
    prompt="Classify this review:",
    return_hierarchical=True
)[0]

# With few-shot examples
examples = [
    {"text": "Love this item!", "labels": ["sentiment.positive", "topic.product"]},
    {"text": "Terrible customer support", "labels": ["sentiment.negative", "topic.service"]}
]

results = pipeline(
    text,
    hierarchical_labels, 
    examples=examples,
    prompt="Classify customer feedback:"
)[0]
''',
            language="python",
        )
    
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                value=EXAMPLES[0][0],
                label="📝 Text Input",
                placeholder="Enter the text you want to classify...",
                lines=8
            )
            
            prompt_input = gr.Textbox(
                value=EXAMPLES[0][4],
                label="💡 Task Description Prompt (Optional)",
                placeholder="E.g., 'Classify this customer review by sentiment and topic:'",
                lines=1
            )
        
        with gr.Column(scale=1):
            labels_input = gr.Textbox(
                value=EXAMPLES[0][1],
                label="🏷️ Labels (comma-separated or JSON)",
                placeholder='positive, negative\n\nOR\n\n{"category": ["label1", "label2"]}',
                lines=6
            )
            
            with gr.Row():
                threshold = gr.Slider(
                    0, 1,
                    value=0.5,
                    step=0.01,
                    label="Threshold",
                    info="Confidence threshold for predictions"
                )
            
            with gr.Row():
                multi_label = gr.Checkbox(
                    value=True,
                    label="Multi-label",
                    info="Allow multiple labels per text"
                )
                hierarchical_output = gr.Checkbox(
                    value=False,
                    label="Hierarchical Output",
                    info="Return structured output matching label hierarchy"
                )
            
            with gr.Row():
                output_format = gr.Radio(
                    choices=["visual", "json"],
                    value="visual",
                    label="Output Format",
                    info="Visual: charts/bars | JSON: raw data"
                )
    
    with gr.Accordion("🎯 Few-Shot Examples (Optional)", open=False):
        examples_input = gr.Textbox(
            value="",
            label="Examples (JSON format)",
            placeholder='''[
    {"text": "Example text 1", "labels": ["label1", "label2"]},
    {"text": "Example text 2", "labels": ["label3"]}
]''',
            lines=5
        )
        gr.Markdown("""
        *Provide labeled examples to guide the model. Each example needs a `text` field and a `labels` array.*
        """)
    
    submit_btn = gr.Button("🚀 Classify", variant="primary", size="lg")
    
    output = gr.Label(label="📊 Classification Results")
    output_text = gr.Textbox(
        label="📊 Hierarchical Results", 
        visible=False, 
        lines=10
    )
    output_json = gr.Code(
        label="📊 JSON Output",
        language="json",
        visible=False,
        lines=15
    )
    
    # Dynamic output visibility based on format and hierarchical toggle
    def update_output_visibility(hierarchical: bool, fmt: str):
        if fmt == "json":
            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
        elif hierarchical:
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
    
    hierarchical_output.change(
        fn=update_output_visibility,
        inputs=[hierarchical_output, output_format],
        outputs=[output, output_text, output_json]
    )
    
    output_format.change(
        fn=update_output_visibility,
        inputs=[hierarchical_output, output_format],
        outputs=[output, output_text, output_json]
    )
    
    # Classification function wrapper for different outputs
    def classify_wrapper(text, labels, threshold, multi_label, prompt, examples, hierarchical, fmt):
        result = classification(text, labels, threshold, multi_label, prompt, examples, hierarchical, fmt)
        
        if fmt == "json":
            return None, None, result
        elif hierarchical or isinstance(result, str):
            return None, result, None
        else:
            return result, None, None
    
    # Event handlers
    submit_btn.click(
        fn=classify_wrapper,
        inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],
        outputs=[output, output_text, output_json]
    )
    
    input_text.submit(
        fn=classify_wrapper,
        inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],
        outputs=[output, output_text, output_json]
    )
    
    gr.Markdown("### 📚 Example Configurations")
    
    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],
        outputs=[output, output_text, output_json],
        fn=classify_wrapper,
        cache_examples=False,
        examples_per_page=5
    )
    
    gr.Markdown("""
    ---
    
    ### 🔧 Tips for Best Results
    
    | Feature | Best Practice |
    |---------|---------------|
    | **Prompts** | Be specific about the task, e.g., "Classify by sentiment:" vs "Analyze:" |
    | **Labels** | Use descriptive labels; "customer service issue" > "service" |
    | **Hierarchical** | Group related labels under categories for organized results |
    | **Examples** | 2-3 diverse examples can improve accuracy |
    | **Threshold** | Start at 0.5, lower for more predictions, raise for higher precision |
    """)


if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True, share=True)
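The few-shot examples in this demo (Example 18 and the bundled code snippet) refer to hierarchical labels by dotted names such as `topic.anxiety` or `sentiment.positive`. The following is a minimal sketch of that naming convention. It assumes the pipeline flattens a `{category: [labels]}` dict into `category.label` strings and regroups scores by category. `flatten_labels` and `regroup_scores` are hypothetical helpers, not part of the GLiClass API:

```python
from typing import Dict, List


def flatten_labels(hierarchy: Dict[str, List[str]]) -> List[str]:
    """Hypothetical helper: {"sentiment": ["positive"]} -> ["sentiment.positive"]."""
    return [f"{category}.{label}" for category, labels in hierarchy.items() for label in labels]


def regroup_scores(flat_scores: Dict[str, float]) -> Dict[str, Dict[str, float]]:
    """Hypothetical helper: fold dotted 'category.label' scores back into a nested dict."""
    nested: Dict[str, Dict[str, float]] = {}
    for name, score in flat_scores.items():
        # Split on the first dot only, so a label such as "v1.0" stays intact
        category, _, label = name.partition(".")
        nested.setdefault(category, {})[label] = score
    return nested


hierarchical_labels = {
    "sentiment": ["positive", "negative", "neutral"],
    "topic": ["product", "service", "shipping"],
}
flat = flatten_labels(hierarchical_labels)
print(flat[:2])  # ['sentiment.positive', 'sentiment.negative']
print(regroup_scores({"sentiment.positive": 0.91, "topic.shipping": 0.72}))
# {'sentiment': {'positive': 0.91}, 'topic': {'shipping': 0.72}}
```

This sketch assumes category names contain no dots. The exact structure returned by `return_hierarchical=True` may differ from this nested dict.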

================================================
FILE: gliclass/__init__.py
================================================
from .model import GLiClassModel, GLiClassBiEncoder, GLiClassUniEncoder, GLiClassEncoderDecoderCLS
from .config import GLiClassModelConfig
from .pipeline import (
    ZeroShotClassificationPipeline,
    BiEncoderZeroShotClassificationPipeline,
    ZeroShotClassificationWithChunkingPipeline,
)

__version__ = "0.1.19"

# Serve module (optional import)
try:
    from . import serve
except ImportError:
    serve = None
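The package imports `serve` optimistically and falls back to `None` on `ImportError`, so `import gliclass` still works when the serving dependencies are missing. `config.py` makes a similar choice with `is_module_available("turbot5")`. The following is a hedged sketch of such an availability check built on the standard library's `importlib.util.find_spec`. The real `gliclass.utils.is_module_available` is not shown in this excerpt and may be implemented differently:

```python
import importlib.util


def is_module_available(name: str) -> bool:
    """Return True if `name` can be imported, without actually importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # For dotted names, find_spec imports the parent package and raises if it is
        # missing. It raises ValueError if the module is loaded but has no __spec__.
        return False


print(is_module_available("json"))                          # True
print(is_module_available("surely_not_a_real_module_xyz"))  # False
```

Unlike a bare `try: import ...`, `find_spec` does not execute the top-level module. A missing optional backend can therefore be detected cheaply before choosing between, for example, two `T5Config` implementations.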


================================================
FILE: gliclass/config.py
================================================
from transformers import AutoConfig
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
from transformers.configuration_utils import PretrainedConfig

from .utils import is_module_available

IS_TURBOT5 = is_module_available("turbot5")

if IS_TURBOT5:
    from turbot5.model.config import T5Config
else:
    from transformers import T5Config


logger = logging.get_logger(__name__)


class GLiClassModelConfig(PretrainedConfig):
    model_type = "GLiClass"
    is_composition = True

    def __init__(
        self,
        encoder_config=None,
        encoder_model=None,
        label_model_config=None,
        label_model_name=None,
        class_token_index=-1,
        text_token_index=-1,
        example_token_index=-1,
        ignore_index=-100,
        hidden_size=None,
        projector_hidden_act="gelu",
        vocab_size=None,
        problem_type="single_label_classification",
        max_num_classes=25,
        use_lstm=False,
        initializer_range=0.03,
        scorer_type="simple",
        scorer_num_heads=16,
        scorer_mlp_hidden_size=1024,
        scorer_attn_dropout=0.1,
        pooling_strategy="first",
        class_token_pooling="first",
        focal_loss_alpha=0.5,
        focal_loss_gamma=2,
        focal_loss_reduction=None,
        logit_scale_init_value=2.6592,
        normalize_features=False,
        extract_text_features=False,
        max_labels_alloc: str = "dynamic",
        contrastive_loss_coef=0,
        architecture_type="uni-encoder",
        prompt_first=False,
        squeeze_layers=False,
        layer_wise=False,
        encoder_layer_id=-1,
        embed_class_token=True,
        dropout=0.1,
        use_segment_embeddings=False,
        **kwargs,
    ):
        if isinstance(encoder_config, dict):
            encoder_config["model_type"] = encoder_config.get("model_type", "deberta-v2")
            if encoder_config["model_type"] == "t5":
                encoder_config = T5Config(**encoder_config)
            elif encoder_config["model_type"] in CONFIG_MAPPING:
                encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config)
            else:
                _name = encoder_model or kwargs.get("encoder_model_name")
                if _name:
                    encoder_config = AutoConfig.from_pretrained(_name, trust_remote_code=True)
                else:
                    encoder_config = PretrainedConfig(**encoder_config)
        elif encoder_config is None:
            encoder_config = CONFIG_MAPPING["deberta-v2"]()

        self.encoder_config = encoder_config
        self.encoder_model_name = encoder_model

        if label_model_name is not None:
            if isinstance(label_model_config, dict):
                label_model_config["model_type"] = label_model_config.get("model_type", "deberta-v2")
                label_model_config = CONFIG_MAPPING[label_model_config["model_type"]](**label_model_config)
            elif label_model_config is None:
                label_model_config = CONFIG_MAPPING["deberta-v2"]()

            self.label_model_config = label_model_config
        else:
            self.label_model_config = None
        self.label_model_name = label_model_name

        if hidden_size is None:
            self.hidden_size = self.encoder_config.hidden_size
        else:
            self.hidden_size = hidden_size

        if vocab_size is None:
            self.vocab_size = self.encoder_config.vocab_size
        else:
            self.vocab_size = vocab_size

        if class_token_index == -1:
            self.class_token_index = self.vocab_size
        else:
            self.class_token_index = class_token_index

        if text_token_index == -1:
            self.text_token_index = self.vocab_size + 1
        else:
            self.text_token_index = text_token_index

        if example_token_index == -1:
            self.example_token_index = self.vocab_size + 2
        else:
            self.example_token_index = example_token_index

        self.ignore_index = ignore_index
        self.projector_hidden_act = projector_hidden_act
        self.problem_type = problem_type
        self.max_num_classes = max_num_classes
        self.initializer_range = initializer_range
        self.scorer_type = scorer_type
        self.scorer_num_heads = scorer_num_heads
        self.scorer_mlp_hidden_size = scorer_mlp_hidden_size
        self.scorer_attn_dropout = scorer_attn_dropout
        self.pooling_strategy = pooling_strategy
        self.class_token_pooling = class_token_pooling
        self.use_lstm = use_lstm
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.focal_loss_reduction = focal_loss_reduction
        self.contrastive_loss_coef = contrastive_loss_coef
        self.logit_scale_init_value = logit_scale_init_value
        self.normalize_features = normalize_features
        self.extract_text_features = extract_text_features
        self.max_labels_alloc = max_labels_alloc
        self.architecture_type = architecture_type
        self.prompt_first = prompt_first
        self.squeeze_layers = squeeze_layers
        self.layer_wise = layer_wise
        self.encoder_layer_id = encoder_layer_id
        self.embed_class_token = embed_class_token
        self.pad_token_id = self.encoder_config.pad_token_id
        self.dropout = dropout
        self.use_segment_embeddings = use_segment_embeddings
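        # PretrainedConfig.__init__ re-reads problem_type and pad_token_id from kwargs
        # (defaulting to None), so pass the resolved values through explicitly.
        kwargs.setdefault("pad_token_id", self.encoder_config.pad_token_id)
        super().__init__(problem_type=problem_type, **kwargs)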
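When `class_token_index`, `text_token_index`, or `example_token_index` is left at `-1`, the config places that token directly after the vocabulary: `vocab_size`, `vocab_size + 1`, and `vocab_size + 2`. Here `vocab_size` itself defaults to `encoder_config.vocab_size`. The following is a standalone sketch of that defaulting rule. The function name is hypothetical. The assumption that three matching special tokens are appended to the tokenizer, with the embedding matrix resized to fit, is not shown in this file:

```python
from typing import Tuple


def resolve_special_token_indices(
    vocab_size: int,
    class_token_index: int = -1,
    text_token_index: int = -1,
    example_token_index: int = -1,
) -> Tuple[int, int, int]:
    """Mirror GLiClassModelConfig's rule: -1 means 'place just past the vocabulary'."""
    class_idx = vocab_size if class_token_index == -1 else class_token_index
    text_idx = vocab_size + 1 if text_token_index == -1 else text_token_index
    example_idx = vocab_size + 2 if example_token_index == -1 else example_token_index
    return class_idx, text_idx, example_idx


print(resolve_special_token_indices(128000))                       # (128000, 128001, 128002)
print(resolve_special_token_indices(128000, class_token_index=5))  # (5, 128001, 128002)
```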


================================================
FILE: gliclass/data_processing.py
================================================
import copy
import random
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


@dataclass
class AugmentationConfig:
    """Configuration for data augmentation."""

    enabled: bool = True

    # Probability for each augmentation type
    random_label_removal_prob: float = 0.15
    random_label_addition_prob: float = 0.10
    random_text_addition_prob: float = 0.05
    random_add_description_prob: float = 0.25
    random_add_synonyms_prob: float = 0.1
    random_add_examples_prob: float = 0.25
    max_num_examples: int = 5


class DataAugmenter:
    def __init__(self, config, examples, labels, label2description=None):
        self.config = config
        self.examples = examples
        self.labels = sorted(labels)
        self.max_examples = self.config.max_num_examples
        self.label2description = label2description or {}

    def remove_labels(self, true_labels, all_labels):
        if len(all_labels) <= 1:
            return true_labels, all_labels
        k = random.randint(1, len(all_labels))
        all_labels = random.sample(all_labels, k=k)
        true_labels = [lbl for lbl in true_labels if lbl in all_labels]
        return true_labels, all_labels

    def add_random_labels(self, all_labels):
        if not self.labels:
            return all_labels
        num_add = len(all_labels) + 1
        k = random.randint(1, min(num_add, len(self.labels)))
        add_labels = random.sample(self.labels, k=k)
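        # Skip labels that are already present so no class appears twice in the prompt
        all_labels.extend([lbl for lbl in add_labels if lbl not in all_labels])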
        return all_labels

    def add_random_text(self, text, all_labels):
        if not self.examples:
            return text
        example = random.sample(self.examples, k=1)[0]
        curr_labels = example["all_labels"]
        joint_labels = set(all_labels) & set(curr_labels)
        if len(joint_labels):
            return text
        else:
            if random.randint(0, 1):
                text = example["text"] + " " + text
            else:
                text = text + " " + example["text"]
            return text

    def add_random_synonyms(self, all_labels):
        """Replace some labels with their synonyms if available."""
        if not self.label2description:
            return all_labels

        augmented_labels = []
        for label in all_labels:
            if label in self.label2description:
                label_info = self.label2description[label]
                synonyms = label_info.get("synonyms", [])

                if synonyms and random.random() < 0.5:
                    augmented_labels.append(random.choice(synonyms))
                else:
                    augmented_labels.append(label)
            else:
                augmented_labels.append(label)

        return augmented_labels

    def add_random_descriptions(self, item):
        """Prepend or append short descriptions of up to three labels to the item's text."""
        if not self.label2description or not item["all_labels"]:
            return item

        max_labels = min(3, len(item["all_labels"]))
        labels_to_describe = random.sample(item["all_labels"], k=random.randint(1, max_labels))

        descriptions = []
        for label in labels_to_describe:
            if label in self.label2description:
                label_info = self.label2description[label]
                desc_list = label_info.get("descriptions", [])
                if desc_list:
                    descriptions.append(f"{label}: {random.choice(desc_list)}")

        if descriptions:
            desc_text = " ".join(descriptions)
            if random.random() < 0.5:
                item["text"] = desc_text + " " + item["text"]
            else:
                item["text"] = item["text"] + " " + desc_text

        return item

    def add_random_examples(self, item):
        """Attach one or two few-shot examples that share at least one label with the item."""
        if not item["all_labels"]:
            return item

        candidate_examples = item.get("examples", [])

        item_label_set = set(item["all_labels"])

        if not candidate_examples:
            for example in self.examples:
                example_label_set = set(example["true_labels"])
                example_text = example["text"]

                overlap = item_label_set & example_label_set

                # Only consider examples with at least one overlapping label
                if overlap:
                    candidate_examples.append({"text": example_text, "labels": list(example_label_set)})

        if not candidate_examples:
            return item

        # Shuffle the candidates and keep at most max_examples of them
        random.shuffle(candidate_examples)
        top_candidates = candidate_examples[: self.max_examples]

        num_examples = random.randint(1, min(2, len(top_candidates)))
        selected_examples = random.sample(top_candidates, k=num_examples)

        item["examples"] = selected_examples

        return item

    def augment(self, item):
        if not self.config.enabled:
            return item

        text = copy.deepcopy(item["text"])
        true_labels = copy.deepcopy(item["true_labels"])
        all_labels = copy.deepcopy(item["all_labels"])

        # Create augmented item
        aug_item = {"text": text, "true_labels": true_labels, "all_labels": all_labels}

        # Copy any additional fields
        for key in item:
            if key not in aug_item:
                aug_item[key] = copy.deepcopy(item[key])

        if random.random() < self.config.random_label_removal_prob:
            aug_item["true_labels"], aug_item["all_labels"] = self.remove_labels(
                aug_item["true_labels"], aug_item["all_labels"]
            )

        if random.random() < self.config.random_label_addition_prob:
            aug_item["all_labels"] = self.add_random_labels(aug_item["all_labels"])

        if random.random() < self.config.random_text_addition_prob:
            aug_item["text"] = self.add_random_text(aug_item["text"], aug_item["all_labels"])

        if random.random() < self.config.random_add_synonyms_prob:
            aug_item["all_labels"] = self.add_random_synonyms(aug_item["all_labels"])

        if random.random() < self.config.random_add_description_prob:
            aug_item = self.add_random_descriptions(aug_item)

        if random.random() < self.config.random_add_examples_prob:
            aug_item = self.add_random_examples(aug_item)

        return aug_item


class GLiClassDataset(Dataset):
    def __init__(
        self,
        examples,
        tokenizer,
        augment_config,
        label2description=None,
        max_length=512,
        problem_type="multi_label_classification",
        architecture_type="uni-encoder",
        add_description=True,
        prompt_first=False,
        get_negatives=False,
        max_labels=50,
        labels_tokenizer=None,
        shuffle_labels=True,
    ):
        self.tokenizer = tokenizer
        self.labels_tokenizer = labels_tokenizer
        self.label2description = label2description or {}
        self.augment_config = augment_config
        self.max_length = max_length
        self._data = examples
        self.add_description = add_description
        self.problem_type = problem_type
        self.architecture_type = architecture_type
        self.prompt_first = prompt_first
        self.dataset_labels = self.collect_dataset_labels()
        self.get_negatives = get_negatives
        self.max_labels = max_labels
        self.shuffle_labels = shuffle_labels

        self.sep_token = "<<SEP>>"
        self.label_token = "<<LABEL>>"
        self.example_token = "<<EXAMPLE>>"
        self.augmenter = DataAugmenter(augment_config, examples, self.dataset_labels, label2description)
        print("Total labels: ", len(self.dataset_labels))

    def get_diversity(self):
        return [item.get("_diversity", {}).get("overall_diversity", 0.5) for item in self._data]

    def collect_dataset_labels(self):
        dataset_labels = set()
        for example in self._data:
            dataset_labels.update(set(example["all_labels"]))
        return dataset_labels

    def prepare_labels(self, example, label2idx, problem_type):
        if problem_type == "single_label_classification":
            labels = label2idx[example["true_labels"][0]]
        elif problem_type == "multi_label_classification":
            if isinstance(example["true_labels"], dict):
                labels = [example["true_labels"].get(label, 0.0) for label in example["all_labels"]]
            else:
                labels = [1.0 if label in example["true_labels"] else 0.0 for label in example["all_labels"]]
        else:
            raise NotImplementedError(f"{problem_type} is not implemented.")
        return torch.tensor(labels)

    def prepare_prompt(self, item, label_token_first=True):
        prompt_texts = []
        for label in item["all_labels"]:
            if label_token_first:
                label_tag = f"{self.label_token}{label!s}"
            else:
                label_tag = f"{label!s}{self.label_token}"
            prompt_texts.append(label_tag)
        prompt_texts.append(self.sep_token)
        prompt = item.get("prompt", "")
        prompt_texts.append(prompt)
        return prompt_texts

    def format_examples(self, item):
        examples = item.get("examples", [])
        if not examples:
            return ""
        examples = random.sample(examples, k=random.randint(1, len(examples)))
        parts = []
        for example in examples:
            parts.append(self.example_token)
            parts.append(example.get("text", ""))
            parts.append(" \nLabels:\n ")
            parts.append(", ".join(example.get("labels", example.get("true_labels", []))))
        parts.append(self.sep_token)
        return "".join(parts)

    def tokenize(self, texts):
        tokenized_inputs = self.tokenizer(texts, truncation=True, max_length=self.max_length, padding="longest")
        return tokenized_inputs

    def tokenize_labels(self, labels):
        tokenized_inputs = self.labels_tokenizer(labels, truncation=True, max_length=self.max_length, padding="longest")
        return tokenized_inputs

    def tokenize_and_prepare_labels_for_uniencoder(self, example):
        if self.shuffle_labels:
            random.shuffle(example["all_labels"])
        input_text = self.prepare_prompt(example)
        examples_text = self.format_examples(example)
        if self.prompt_first:
            input_text = "".join(input_text) + str(example["text"]) + examples_text
        else:
            input_text = str(example["text"]) + "".join(input_text) + examples_text
        label2idx = {label: idx for idx, label in enumerate(example["all_labels"])}

        tokenized_inputs = self.tokenize(input_text)
        tokenized_inputs["labels"] = self.prepare_labels(example, label2idx, self.problem_type)
        tokenized_inputs["labels_text"] = example["all_labels"]
        tokenized_inputs["input_texts"] = example["text"]
        return tokenized_inputs

    def tokenize_and_prepare_labels_for_encoder_decoder(self, example):
        if self.shuffle_labels:
            random.shuffle(example["all_labels"])
        class_texts = self.prepare_prompt(example, label_token_first=True)
        class_texts = "".join(class_texts)
        examples_text = self.format_examples(example)

        label2idx = {label: idx for idx, label in enumerate(example["all_labels"])}

        input_text = str(example["text"]) + examples_text
        tokenized_inputs = self.tokenize(input_text)
        tokenized_classes = self.tokenize(class_texts)
        tokenized_inputs["class_input_ids"] = tokenized_classes["input_ids"]
        tokenized_inputs["class_attention_mask"] = tokenized_classes["attention_mask"]
        tokenized_inputs["labels"] = self.prepare_labels(example, label2idx, self.problem_type)
        return tokenized_inputs

    def tokenize_and_prepare_labels_for_biencoder(self, example):
        if self.shuffle_labels:
            random.shuffle(example["all_labels"])

        def prepare_prompt(labels):
            # One label placeholder per class, followed by the separator token
            return self.label_token * len(labels) + self.sep_token

        input_text = example["text"]
        class_texts = example["all_labels"]

        if self.architecture_type == "bi-encoder-fused":
            prompt = prepare_prompt(class_texts)
            if self.prompt_first:
                input_text = f"{prompt} {input_text}"
            else:
                input_text = f"{input_text} {prompt}"

        tokenized_inputs = self.tokenize(input_text)
        tokenized_classes = self.tokenize_labels(class_texts)

        tokenized_inputs["class_input_ids"] = torch.tensor(tokenized_classes["input_ids"])
        tokenized_inputs["class_attention_mask"] = torch.tensor(tokenized_classes["attention_mask"])

        label2idx = {label: idx for idx, label in enumerate(example["all_labels"])}

        tokenized_inputs["labels_mask"] = torch.ones(len(class_texts))
        tokenized_inputs["labels"] = self.prepare_labels(example, label2idx, self.problem_type)
        return tokenized_inputs

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        example = self._data[idx]

        example = self.augmenter.augment(example)

        if self.architecture_type == "uni-encoder":
            model_inputs = self.tokenize_and_prepare_labels_for_uniencoder(example)
        elif self.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
            model_inputs = self.tokenize_and_prepare_labels_for_encoder_decoder(example)
        elif self.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
            model_inputs = self.tokenize_and_prepare_labels_for_biencoder(example)
        else:
            raise NotImplementedError(f"Architecture type '{self.architecture_type}' is not implemented.")
        return model_inputs


def pad_2d_tensor(key_data):
    """
    Pad a list of 2D tensors to have the same size along both dimensions.

    :param key_data: List of 2D tensors to pad.
    :return: Tensor of padded tensors stacked along a new batch dimension.
    """
    if not key_data:
        raise ValueError("The input list 'key_data' should not be empty.")

    # Determine the maximum size along both dimensions
    max_rows = max(tensor.shape[0] for tensor in key_data)
    max_cols = max(tensor.shape[1] for tensor in key_data)

    tensors = []

    for tensor in key_data:
        rows, cols = tensor.shape
        row_padding = max_rows - rows
        col_padding = max_cols - cols
        # Pad the tensor along both dimensions
        padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode="constant", value=0)
        tensors.append(padded_tensor)

    # Stack the tensors into a single tensor along a new batch dimension
    padded_tensors = torch.stack(tensors)

    return padded_tensors


class DataCollatorWithPadding:
    def __init__(self, device="cuda:0", config=None):
        self.device = device
        self._max_labels_alloc = getattr(config, "max_labels_alloc", "dynamic") if config is not None else "dynamic"

    def _resolve_max_num_classes(self, batch):
        if self._max_labels_alloc == "dynamic":
            first = batch[0]
            if "labels_text" in first:
                return max(len(item["labels_text"]) for item in batch)
            if "labels_mask" in first:
                return max(item["labels_mask"].shape[0] for item in batch)
            first_labels = first.get("labels")
            if isinstance(first_labels, torch.Tensor) and first_labels.dim() >= 1:
                return max(item["labels"].shape[0] for item in batch)
            return None
        if isinstance(self._max_labels_alloc, int):
            return self._max_labels_alloc
        return None  # 'fixed': model uses config.max_num_classes

    def __call__(self, batch):
        keys = batch[0].keys()
        padded_batch = {key: [] for key in keys}

        for key in keys:
            key_data = [item[key] for item in batch]
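            if isinstance(key_data[0], torch.Tensor):
                if key_data[0].dim() == 0:
                    # Scalar targets (e.g. single-label class indices) are stacked into a 1D tensor
                    padded_batch[key] = torch.stack(key_data)
                elif key_data[0].dim() == 1:
                    padded_batch[key] = pad_sequence(key_data, batch_first=True)
                elif key_data[0].dim() == 2:
                    padded_batch[key] = pad_2d_tensor(key_data)
                else:
                    raise ValueError(f"Unsupported tensor rank {key_data[0].dim()} for key '{key}'")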
            elif isinstance(key_data[0], list):
                data_el = "string"
                if len(key_data[0]):
                    data_el = key_data[0][0]
                if isinstance(data_el, str):
                    padded_batch[key] = key_data
                else:
                    max_length = max(len(seq) for seq in key_data)
                    padded_batch[key] = torch.tensor([seq + [0] * (max_length - len(seq)) for seq in key_data])
            elif type(key_data[0]) in {int, float}:
                padded_batch[key] = torch.tensor(key_data)
            elif isinstance(key_data[0], str):
                padded_batch[key] = key_data
            else:
                raise TypeError(f"Unsupported data type: {type(key_data[0])}")

        padded_batch["max_num_classes"] = self._resolve_max_num_classes(batch)
        return padded_batch
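

# Illustrative sketch (hypothetical helper, not part of the collator's API): how a batch is
# sized under the "dynamic" label allocation above. Each 1-D tensor is right-padded with
# zeros by pad_sequence to the longest item, and the class budget is the largest number of
# label strings in any example. The field names mirror the keys the collator inspects; the
# values are made up for illustration.
def _example_dynamic_label_padding():
    import torch
    from torch.nn.utils.rnn import pad_sequence

    batch = [
        {"input_ids": torch.tensor([1, 2, 3]), "labels_text": ["sports", "politics"]},
        {"input_ids": torch.tensor([4, 5]), "labels_text": ["tech", "science", "health"]},
    ]
    # 1-D tensors: pad to the longest sequence in the batch (padding value 0).
    input_ids = pad_sequence([item["input_ids"] for item in batch], batch_first=True)
    # "dynamic" mode: allocate as many class slots as the largest label list needs.
    max_num_classes = max(len(item["labels_text"]) for item in batch)
    return input_ids.tolist(), max_num_classes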


================================================
FILE: gliclass/layers.py
================================================
# Copyright 2020 Microsoft and the Hugging Face Inc. team and Knowledgator.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from transformers.activations import ACT2FN

from .config import GLiClassModelConfig


class LstmSeq2SeqEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, x, mask, hidden=None):
        # Packing the input sequence
        lengths = mask.sum(dim=1).cpu()
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # Passing packed sequence through LSTM
        packed_output, hidden = self.lstm(packed_x, hidden)

        # Unpacking the output sequence; total_length keeps the padded length of the input
        output, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=x.shape[1])

        return output
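

# Illustrative sketch (hypothetical helper): pad_packed_sequence only restores sequences up
# to the longest *unpadded* length in the batch. Passing total_length keeps the output
# aligned with the original padded input, so it still matches the attention mask.
# The sizes below are arbitrary assumptions for the demo.
def _example_packed_lstm_lengths():
    import torch
    from torch import nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True, bidirectional=True)
    x = torch.randn(2, 6, 4)  # (batch, seq_len, features)
    mask = torch.tensor([[1, 1, 1, 0, 0, 0],  # real lengths 3 and 4, padded to 6
                         [1, 1, 1, 1, 0, 0]])
    lengths = mask.sum(dim=1).cpu()
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    packed_out, _ = lstm(packed)

    trimmed, _ = pad_packed_sequence(packed_out, batch_first=True)
    aligned, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.shape[1])
    # Bidirectional output features = 2 * hidden_size.
    return tuple(trimmed.shape), tuple(aligned.shape)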


class FeaturesProjector(nn.Module):
    def __init__(self, config: GLiClassModelConfig):
        super().__init__()

        self.linear_1 = nn.Linear(config.encoder_config.hidden_size, config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.dropout = nn.Dropout(config.dropout)
        self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)

    def forward(self, features):
        hidden_states = self.linear_1(features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
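

# Illustrative sketch (hypothetical helper): FeaturesProjector is a two-layer MLP that maps
# encoder features (encoder hidden size) to config.hidden_size and back, so its output has
# the same shape as its input. The sizes and the GELU activation are assumptions standing in
# for config values such as projector_hidden_act.
def _example_projector_round_trip(encoder_hidden=8, projector_hidden=16):
    import torch
    from torch import nn

    projector = nn.Sequential(
        nn.Linear(encoder_hidden, projector_hidden),
        nn.GELU(),
        nn.Dropout(0.1),
        nn.Linear(projector_hidden, encoder_hidden),
    )
    projector.eval()  # disable dropout for a deterministic check
    features = torch.randn(2, 5, encoder_hidden)  # (batch, tokens, hidden)
    return tuple(projector(features).shape)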


class BiEncoderProjector(nn.Module):
    def __init__(self, config: GLiClassModelConfig):
        super().__init__()

        self.linear_1 = nn.Linear(config.label_model_config.hidden_size, config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)

    def forward(self, features):
        hidden_states = self.linear_1(features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
class DropoutContext:
    def __init__(self):
        self.dropout = 0
        self.mask = None
        self.scale = 1
        self.reuse_mask = True


# Copied from transformers.models.deberta.modeling_deberta.get_mask
def get_mask(input, local_context):
    if not isinstance(local_context, DropoutContext):
        dropout = local_context
        mask = None
    else:
        dropout = local_context.dropout
        dropout *= local_context.scale
        mask = local_context.mask if local_context.reuse_mask else None

    if dropout > 0 and mask is None:
        mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)

    if isinstance(local_context, DropoutContext) and local_context.mask is None:
        local_context.mask = mask

    return mask, dropout


# Copied from transformers.models.deberta.modeling_deberta.XDropout
class XDropout(torch.autograd.Function):
    """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""

    @staticmethod
    def forward(ctx, input, local_ctx):
        mask, dropout = get_mask(input, local_ctx)
        ctx.scale = 1.0 / (1 - dropout)
        if dropout > 0:
            ctx.save_for_backward(mask)
            return input.masked_fill(mask, 0) * ctx.scale
        else:
            return input

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.scale > 1:
            (mask,) = ctx.saved_tensors
            return grad_output.masked_fill(mask, 0) * ctx.scale, None
        else:
            return grad_output, None

    @staticmethod
    def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: float | DropoutContext) -> torch._C.Value:
        from torch.onnx import symbolic_opset12

        dropout_p = local_ctx
        if isinstance(local_ctx, DropoutContext):
            dropout_p = local_ctx.dropout
        # StableDropout only calls this function when training.
        train = True
        # TODO: We should check if the opset_version being used to export
        # is > 12 here, but there's no good way to do that. As-is, if the
        # opset_version < 12, export will fail with a CheckerError.
        # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
        # if opset_version < 12:
        #   return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
        return symbolic_opset12.dropout(g, input, dropout_p, train)
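

# Illustrative sketch (hypothetical helper): the arithmetic XDropout performs. The mask marks
# *dropped* positions (True = drop), kept values are rescaled by 1 / (1 - p), and the backward
# pass applies the same mask and scale to the incoming gradient. A fixed mask is assumed here
# instead of a sampled one so the result is deterministic.
def _example_mask_dropout(p=0.5):
    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
    drop_mask = torch.tensor([True, False, True, False])
    scale = 1.0 / (1 - p)
    y = x.masked_fill(drop_mask, 0) * scale
    y.sum().backward()  # gradient is scale where kept, 0 where dropped
    return y.detach().tolist(), x.grad.tolist()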


# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(nn.Module):
    """
    Optimized dropout module for stabilizing the training.

    Args:
        drop_prob (float): the dropout probability
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob
        self.count = 0
        self.context_stack = None

    def forward(self, x):
        """
        Call the module.

        Args:
            x (`torch.Tensor`): The input tensor to apply dropout to
        """
        if self.training and self.drop_prob > 0:
            return XDropout.apply(x, self.get_context())
        return x

    def clear_context(self):
        self.count = 0
        self.context_stack = None

    def init_context(self, reuse_mask=True, scale=1):
        if self.context_stack is None:
            self.context_stack = []
        self.count = 0
        for c in self.context_stack:
            c.reuse_mask = reuse_mask
            c.scale = scale

    def get_context(self):
        if self.context_stack is not None:
            if self.count >= len(self.context_stack):
                self.context_stack.append(DropoutContext())
            ctx = self.context_stack[self.count]
            ctx.dropout = self.drop_prob
            self.count += 1
            return ctx
        else:
            return self.drop_prob
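

# Illustrative sketch (hypothetical helper): with a context stack, StableDropout keeps one
# DropoutContext per dropout call and, when reuse_mask is True, applies the same sampled mask
# again on a later pass. The mask sampling mirrors get_mask(); p and the input are assumptions.
def _example_reused_dropout_mask(p=0.3, seed=0):
    import torch

    torch.manual_seed(seed)
    x = torch.ones(4, 8)
    # bernoulli_(1 - p) draws 1 for kept positions, so 1 - draw marks positions to drop.
    mask = (1 - torch.empty_like(x).bernoulli_(1 - p)).to(torch.bool)
    scale = 1.0 / (1 - p)
    first_pass = x.masked_fill(mask, 0) * scale
    second_pass = x.masked_fill(mask, 0) * scale  # same mask reused, same output
    return torch.equal(first_pass, second_pass), bool((first_pass[mask] == 0).all())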


class SelfAttentionBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Inputs are (batch, seq_len, d_model), matching the (batch, seq_len) masks built by Fuser.
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, key_padding_mask=None):
        # key_padding_mask: (batch, seq_len) bool tensor, True marks padding positions to ignore.
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask, key_padding_mask=key_padding_mask)
        return self.norm(x + self.dropout(attn_output))


class CrossAttentionBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None, key_padding_mask=None):
        attn_output, _ = self.cross_attn(query, key, value, attn_mask=mask, key_padding_mask=key_padding_mask)
        return self.norm(query + self.dropout(attn_output))


class Fuser(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [SelfAttentionBlock(d_model, num_heads, dropout), CrossAttentionBlock(d_model, num_heads, dropout)]
                )
                for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, query, key, query_mask=None, key_mask=None):
        # query: (batch, Lq, d_model), key: (batch, Lk, d_model).
        # Masks use 1 for real tokens and 0 for padding; nn.MultiheadAttention expects the inverse
        # (True = ignore). Masking keys only keeps padded query positions from being fully masked.
        query_padding_mask = ~query_mask.bool() if query_mask is not None else None
        key_padding_mask = ~key_mask.bool() if key_mask is not None else None

        value = self.fc(key)

        for self_attn, cross_attn in self.layers:
            query = self_attn(query, key_padding_mask=query_padding_mask)
            query = cross_attn(query, key, value, key_padding_mask=key_padding_mask)

        return query
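

# Illustrative sketch (hypothetical helper): nn.MultiheadAttention's key_padding_mask takes a
# (batch, key_len) boolean tensor where True means "ignore this key", so a 1/0 attention mask
# must be inverted first. Changing the values at a masked key position leaves the output
# unchanged. Sizes, seed, and the perturbation value are arbitrary assumptions.
def _example_key_padding_mask(seed=0):
    import torch
    from torch import nn

    torch.manual_seed(seed)
    attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    attn.eval()
    query = torch.randn(1, 3, 8)
    key = torch.randn(1, 4, 8)
    attention_mask = torch.tensor([[1, 1, 1, 0]])  # last key is padding
    key_padding_mask = ~attention_mask.bool()

    out_a, _ = attn(query, key, key, key_padding_mask=key_padding_mask)
    key_perturbed = key.clone()
    key_perturbed[:, -1] = 100.0  # only the padded position changes
    out_b, _ = attn(query, key_perturbed, key_perturbed, key_padding_mask=key_padding_mask)
    return torch.allclose(out_a, out_b)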


class LayerwiseAttention(nn.Module):
    def __init__(self, num_layers, hidden_size, output_size=None):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.output_size = output_size if output_size is not None else hidden_size

        # Squeeze operation
        self.squeeze = nn.Linear(hidden_size, 1)

        # Excitation operation
        self.W1 = nn.Linear(num_layers, num_layers // 2)
        self.W2 = nn.Linear(num_layers // 2, num_layers)

        # Final projection
        self.output_projection = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, encoder_outputs):
        # encoder_outputs is a list of tensors, each of shape [B, L, D]
        _B, _L, _D = encoder_outputs[0].shape

        # Stack all layers along a new layer dimension
        U = torch.stack(encoder_outputs, dim=1)  # [B, K, L, D]

        # Squeeze operation
        Z = self.squeeze(U).squeeze(-1)  # [B, K, L]
        Z = Z.mean(dim=2)  # [B, K]

        # Excitation operation
        s = self.W2(F.relu(self.W1(Z)))  # [B, K]
        s = torch.sigmoid(s)  # [B, K]

        # Apply attention weights
        U_weighted = U * s.unsqueeze(-1).unsqueeze(-1)  # [B, K, L, D]

        # Sum across layers
        U_sum = U_weighted.sum(dim=1)  # [B, L, D]

        # Final projection
        output = self.output_projection(U_sum)  # [B, L, output_size]

        return output
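

# Illustrative sketch (hypothetical helper): the squeeze-and-excitation weighting used by
# LayerwiseAttention. Each of the K layer outputs is squeezed to one score, a bottleneck MLP
# maps the K scores to K sigmoid gates, and the gated layers are summed. The final output
# projection is omitted; sizes and seed are assumptions.
def _example_layer_gating(num_layers=4, seq_len=5, hidden=8, seed=0):
    import torch
    from torch import nn

    torch.manual_seed(seed)
    layers = [torch.randn(2, seq_len, hidden) for _ in range(num_layers)]  # K x [B, L, D]
    squeeze = nn.Linear(hidden, 1)
    w1 = nn.Linear(num_layers, num_layers // 2)
    w2 = nn.Linear(num_layers // 2, num_layers)

    stacked = torch.stack(layers, dim=1)  # [B, K, L, D]
    scores = squeeze(stacked).squeeze(-1).mean(dim=2)  # [B, K]
    gates = torch.sigmoid(w2(torch.relu(w1(scores))))  # [B, K], each in (0, 1)
    mixed = (stacked * gates[..., None, None]).sum(dim=1)  # [B, L, D]
    in_range = bool(((gates > 0) & (gates < 1)).all())
    return tuple(gates.shape), tuple(mixed.shape), in_range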


================================================
FILE: gliclass/loss_functions.py
================================================
import torch
import torch.nn.functional as F


def sequence_contrastive_loss(embeddings, mask):
    # embeddings shape: (B, L, D)
    # mask shape: (B, L)
    B, L, _D = embeddings.shape

    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=-1)

    # Cosine-similarity matrix (embeddings are L2-normalized above); no temperature scaling is applied
    sim_matrix = torch.matmul(embeddings, embeddings.transpose(1, 2))

    # Create labels for cross entropy (diagonal indices)
    labels = torch.arange(L, device=embeddings.device).unsqueeze(0).expand(B, -1)

    # Compute loss for each element in the batch
    loss = F.cross_entropy(sim_matrix.reshape(B * L, L), labels.reshape(-1), reduction="none")

    # Apply mask to loss
    loss = loss.view(B, L) * mask

    # Compute mean loss over non-padded elements (clamped to avoid division by zero)
    loss = loss.sum() / mask.sum().clamp(min=1)

    return loss
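

# Illustrative sketch (hypothetical helper): a worked check of the objective in
# sequence_contrastive_loss. With orthonormal class embeddings the cosine-similarity matrix is
# the identity, so each row has logit 1 on its own (diagonal) target and 0 elsewhere, and the
# cross-entropy per position is log(e + L - 1) - 1, since no temperature is applied.
def _example_contrastive_identity(num_classes=3):
    import math
    import torch
    import torch.nn.functional as F

    embeddings = F.normalize(torch.eye(num_classes).unsqueeze(0), p=2, dim=-1)  # (1, L, L)
    sim = torch.matmul(embeddings, embeddings.transpose(1, 2))  # identity matrix
    targets = torch.arange(num_classes)  # each class should match itself
    loss = F.cross_entropy(sim[0], targets).item()
    expected = math.log(math.e + num_classes - 1) - 1
    return loss, expected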


def focal_loss_with_logits(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
    label_smoothing: float = 0.0,
    ignore_index: int = -100,  # default value for ignored index
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range (0,1) to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.
        label_smoothing (float): Amount of label smoothing; targets are moved toward 0.5
                as ``t * (1 - s) + 0.5 * s``, where 0.0 means no smoothing.
        ignore_index (int): Specifies a target value that is ignored and does not contribute
                to the input gradient. Default: ``-100``.

    Returns:
        Loss tensor with the reduction option applied.
    """
    # Create a mask to ignore specified index
    valid_mask = targets != ignore_index

    # Apply label smoothing if needed
    if label_smoothing != 0:
        with torch.no_grad():
            targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing

    # Apply sigmoid activation to inputs
    p = torch.sigmoid(inputs)

    # Compute the binary cross-entropy loss without reduction
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # Apply the valid mask to the loss
    loss = loss * valid_mask

    # Apply focal loss modulation if gamma is greater than 0
    if gamma > 0:
        p_t = p * targets + (1 - p) * (1 - targets)
        loss = loss * ((1 - p_t) ** gamma)

    # Apply alpha weighting if alpha is specified
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # Apply reduction method
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.sum() / valid_mask.sum()  # Normalize by the number of valid (non-ignored) elements
    elif reduction == "sum":
        return loss.sum()
    else:
        raise ValueError(
            f"Invalid value for argument 'reduction': '{reduction}'. Supported reduction modes: 'none', 'mean', 'sum'"
        )
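

# Illustrative sketch (hypothetical helper): focal_loss_with_logits for a single element,
# worked by hand. For logit 0 and target 1: p = 0.5, BCE = log 2, p_t = 0.5, the modulating
# factor (1 - p_t) ** gamma is 0.25 with gamma = 2, and alpha_t = alpha = 0.25, so the loss is
# 0.25 * 0.25 * log 2. With gamma = 0 and alpha = -1 it reduces to plain BCE.
def _example_focal_term(logit=0.0, target=1.0, alpha=0.25, gamma=2.0):
    import torch
    import torch.nn.functional as F

    x = torch.tensor([logit])
    t = torch.tensor([target])
    p = torch.sigmoid(x)
    bce = F.binary_cross_entropy_with_logits(x, t, reduction="none")
    p_t = p * t + (1 - p) * (1 - t)
    loss = bce * (1 - p_t) ** gamma  # focal modulation
    if alpha >= 0:
        loss = (alpha * t + (1 - alpha) * (1 - t)) * loss  # class-balance weighting
    return loss.item()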


================================================
FILE: gliclass/model.py
================================================
import os
import warnings
from typing import Tuple
from pathlib import Path
from dataclasses import dataclass

import torch
import transformers
from torch import nn
from packaging import version
from transformers import AutoModel, AutoConfig, PreTrainedModel
from transformers.utils import logging
from transformers.modeling_outputs import SequenceClassifierOutput

# Import initialization module (transformers 5.0+) or fallback to torch.nn.init
try:
    from transformers import initialization as init
except ImportError:
    # transformers < 5.0 doesn't have this module, use torch.nn.init instead
    from torch.nn import init
from .utils import MissedPackageException, is_module_available
from .config import GLiClassModelConfig
from .layers import FeaturesProjector, BiEncoderProjector, LayerwiseAttention, LstmSeq2SeqEncoder
from .scorers import SCORER2OBJECT
from .poolings import POOLING2OBJECT
from .loss_functions import focal_loss_with_logits, sequence_contrastive_loss

IS_LLM2VEC = is_module_available("llm2vec")
IS_PEFT = is_module_available("peft")
IS_TURBOT5 = is_module_available("turbot5")
IS_FLASHDEBERTA = is_module_available("flashdeberta")

logger = logging.get_logger(__name__)

if IS_LLM2VEC:
    from llm2vec.models import GemmaBiModel, LlamaBiModel, Qwen2BiModel, MistralBiModel

    DECODER_MODEL_MAPPING = {
        "MistralConfig": MistralBiModel,
        "LlamaConfig": LlamaBiModel,
        "GemmaConfig": GemmaBiModel,
        "Qwen2Config": Qwen2BiModel,
    }
else:
    DECODER_MODEL_MAPPING = {}

if IS_TURBOT5:
    from turbot5.model.modeling import T5EncoderModel as FlashT5EncoderModel
from transformers import T5EncoderModel, UMT5EncoderModel

if IS_FLASHDEBERTA:
    from flashdeberta import FlashDebertaV2Model
from transformers import DebertaV2Model

if IS_PEFT:
    from peft import LoraConfig, get_peft_model


@dataclass
class GLiClassOutput(SequenceClassifierOutput):
    text_embeddings: torch.Tensor | None = None
    class_embeddings: torch.Tensor | None = None


class GLiClassPreTrainedModel(PreTrainedModel):
    config_class = GLiClassModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_sdpa = False
    _keys_to_ignore_on_load_unexpected = ["position_embeddings"]

    def _initialize_weights(self, module, is_remote_code: bool = False):
        """
        Initialize weights if not already initialized.

        This method is called by transformers 5.0+ during post_init().
        It uses the _is_hf_initialized flag to prevent reinitializing weights
        that were already loaded from a checkpoint.

        For transformers 4.x, this method is not called, maintaining backward compatibility.
        """
        if getattr(module, "_is_hf_initialized", False):
            return

        self._init_weights(module)
        module._is_hf_initialized = True

    def _init_weights(self, module):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.encoder_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            init.normal_(module.class_embedding, mean=0.0, std=std)

        if hasattr(module, "segment_embeddings"):
            init.normal_(module.segment_embeddings.weight, mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if "weight_ih" in name or "weight_hh" in name:
                    init.normal_(param, mean=0.0, std=std)
                elif "bias" in name:
                    init.zeros_(param)
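

# Illustrative sketch (hypothetical helper): what _init_weights does to common layers, written
# with torch.nn.init directly. Linear and Embedding weights are drawn from N(0, std), biases are
# zeroed, and the embedding row at padding_idx is reset to zero. std=0.02 is an assumed value
# standing in for config.initializer_range.
def _example_init_weights(std=0.02):
    import torch
    from torch import nn

    model = nn.Sequential(nn.Embedding(10, 4, padding_idx=0), nn.Linear(4, 4))

    def init_module(module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    model.apply(init_module)  # visits every submodule, like post_init() does
    embedding, linear = model[0], model[1]
    return bool((embedding.weight[0] == 0).all()), bool((linear.bias == 0).all())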


class GLiClassBaseModel(nn.Module):
    def __init__(self, config: GLiClassModelConfig, device="cpu", **kwargs):
        super().__init__()
        self.config = config
        self.text_projector = FeaturesProjector(config)
        self.classes_projector = FeaturesProjector(config)

        if config.pooling_strategy not in POOLING2OBJECT:
            raise NotImplementedError(f"{config.pooling_strategy} is not implemented pooling type.")
        else:
            self.pooler = POOLING2OBJECT[config.pooling_strategy]()

        if config.scorer_type not in SCORER2OBJECT:
            raise NotImplementedError(
                f"{config.scorer_type} is not an implemented scorer type. Choose one of: {list(SCORER2OBJECT)}"
            )
        else:
            self.scorer = SCORER2OBJECT[config.scorer_type](
                config.hidden_size,
                num_heads=config.scorer_num_heads,
                scorer_mlp_hidden_size=config.scorer_mlp_hidden_size,
                attn_dropout=config.scorer_attn_dropout,
            )

        if config.use_lstm:
            self.lstm = LstmSeq2SeqEncoder(config.hidden_size, config.hidden_size // 2, bidirectional=True)

        if config.squeeze_layers:
            self.layer_wise_attention = LayerwiseAttention(
                config.encoder_config.num_hidden_layers, config.encoder_config.hidden_size
            )

        drop_out = getattr(config, "dropout", 0.0)
        # self.dropout = StableDropout(drop_out)
        self.dropout = nn.Dropout(drop_out)

        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        self.epsilon = 1e-8
        self.vocab_size = config.vocab_size
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.num_labels = -1

        self.device = torch.device(device)

    def _extract_class_features(self, token_embeds, input_ids, attention_mask, max_num_classes=None):
        batch_size, _sequence_length, embed_dim = token_embeds.shape

        class_token_mask = input_ids == self.config.class_token_index
        num_class_tokens = torch.sum(class_token_mask, dim=-1, keepdim=True)

        # max_num_classes from caller (CPU int) avoids GPU→CPU sync via .item()
        max_embed_dim = max_num_classes if max_num_classes is not None else self.config.max_num_classes

        # Get class token pooling method from config (default to "first" for backward compatibility)
        class_token_pooling = getattr(self.config, "class_token_pooling", "first")

        if class_token_pooling == "average":
            # Average all tokens belonging to each class label
            classes_embedding, classes_embedding_mask = self._extract_class_features_averaged(
                token_embeds,
                input_ids,
                attention_mask,
                class_token_mask,
                num_class_tokens,
                max_embed_dim,
                batch_size,
                embed_dim,
            )
        else:
            # Original behavior: use only the class token (or token after it)
            classes_embedding, classes_embedding_mask = self._extract_class_features_first(
                token_embeds,
                input_ids,
                attention_mask,
                class_token_mask,
                num_class_tokens,
                max_embed_dim,
                batch_size,
                embed_dim,
            )

        # Text features extraction
        if self.config.extract_text_features:
            text_token_mask = input_ids == self.config.text_token_index
            text_token_indices = text_token_mask.int().argmax(dim=-1)  # (batch,)
            max_text_length = input_ids.shape[-1]  # static, no GPU→CPU sync

            # (batch, max_text_length): source position in token_embeds for each target slot
            aranged_target_idx = (
                torch.arange(max_text_length, device=token_embeds.device).unsqueeze(0).expand(batch_size, -1)
            )
            valid_mask = aranged_target_idx < (input_ids.shape[-1] - text_token_indices).unsqueeze(1)

            source_indices = (text_token_indices.unsqueeze(1) + aranged_target_idx).clamp(max=input_ids.shape[-1] - 1)
            batch_arange = torch.arange(batch_size, device=token_embeds.device).unsqueeze(1)

            # Gather then zero-out invalid positions — no nonzero/scatter needed
            text_tokens_embeddings = token_embeds[batch_arange, source_indices] * valid_mask.unsqueeze(-1).to(
                token_embeds.dtype
            )
            text_tokens_mask = attention_mask[batch_arange, source_indices] * valid_mask
        else:
            text_tokens_embeddings = token_embeds
            text_tokens_mask = attention_mask
        return classes_embedding, classes_embedding_mask, text_tokens_embeddings, text_tokens_mask

    def _extract_class_features_first(
        self,
        token_embeds,
        input_ids,
        attention_mask,
        class_token_mask,
        num_class_tokens,
        max_embed_dim,
        batch_size,
        embed_dim,
    ):
        """Extract only the class token embedding (or token after it). Fully vectorized."""
        class_cum = class_token_mask.long().cumsum(dim=-1)  # (batch, seq)
        k_range = torch.arange(max_embed_dim, device=token_embeds.device).view(1, -1, 1)

        # select_mask[b, k, s] = True at the position of the k-th class token
        select_mask = class_token_mask.unsqueeze(1) & ((class_cum.unsqueeze(1) - 1) == k_range)

        if not self.config.embed_class_token:
            # Shift right by 1: select the token immediately after each class token
            shifted = torch.zeros_like(select_mask)
            shifted[:, :, 1:] = select_mask[:, :, :-1]
            select_mask = shifted

        classes_embedding = torch.einsum("bks,bsd->bkd", select_mask.to(token_embeds.dtype), token_embeds)

        arange_k = torch.arange(max_embed_dim, device=token_embeds.device).unsqueeze(0)
        classes_embedding_mask = (arange_k < num_class_tokens).to(attention_mask.dtype)

        return classes_embedding, classes_embedding_mask

    def _extract_class_features_averaged(
        self,
        token_embeds,
        input_ids,
        attention_mask,
        class_token_mask,
        num_class_tokens,
        max_embed_dim,
        batch_size,
        embed_dim,
    ):
        """Average all tokens belonging to each class label. Fully vectorized."""
        # class_cum[b, s] = cumulative count of class tokens up to position s
        class_cum = class_token_mask.long().cumsum(dim=-1)  # (batch, seq)

        if self.config.extract_text_features:
            text_token_mask = input_ids == self.config.text_token_index
        else:
            text_token_mask = torch.zeros_like(class_token_mask)
        # text_cum[b, s] >= 1 at and after the text token → use as exclusion boundary
        text_cum = text_token_mask.long().cumsum(dim=-1)  # (batch, seq)

        # span_mask[b, k, s] = True if token s belongs to the span of class k
        k_range = torch.arange(max_embed_dim, device=token_embeds.device).view(1, -1, 1)
        span_mask = (
            (class_cum.unsqueeze(1) == (k_range + 1))  # in the span of class k
            & (text_cum.unsqueeze(1) == 0)  # before the text boundary
            & attention_mask.unsqueeze(1).bool()  # real token (not padding)
        )
        if not self.config.embed_class_token:
            span_mask = span_mask & ~class_token_mask.unsqueeze(1)

        span_float = span_mask.to(token_embeds.dtype)  # (batch, max_embed_dim, seq)
        class_counts = span_float.sum(dim=-1, keepdim=True).clamp(min=1)
        classes_embedding = torch.einsum("bks,bsd->bkd", span_float, token_embeds) / class_counts

        arange_k = torch.arange(max_embed_dim, device=token_embeds.device).unsqueeze(0)
        classes_embedding_mask = (arange_k < num_class_tokens).to(attention_mask.dtype)

        return classes_embedding, classes_embedding_mask

    def get_loss(self, logits, labels, classes_embedding=None, classes_embedding_mask=None):
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    # regression task
                    loss_fn = nn.MSELoss()
                    logits = logits.view(-1).to(labels.dtype)
                    loss = loss_fn(logits, labels.view(-1))
                elif labels.dim() == 1 or labels.size(-1) == 1:
                    label_index = (labels >= 0).nonzero()
                    labels = labels.long()
                    if label_index.size(0) > 0:
                        labeled_logits = torch.gather(
                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
                        )
                        labels = torch.gather(labels, 0, label_index.view(-1))
                        loss_fct = nn.CrossEntropyLoss()
                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
                    else:
                        loss = torch.tensor(0).to(logits)
                else:
                    log_softmax = nn.LogSoftmax(-1)
                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
            elif self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                all_losses = focal_loss_with_logits(
                    logits,
                    labels,
                    self.config.focal_loss_alpha,
                    self.config.focal_loss_gamma,
                    self.config.focal_loss_reduction,
                )
                if classes_embedding_mask is not None:
                    all_losses = all_losses * classes_embedding_mask.float()
                loss = all_losses.mean()

            if self.config.contrastive_loss_coef > 0 and classes_embedding is not None:
                contrastive_loss = sequence_contrastive_loss(classes_embedding, classes_embedding_mask)
                loss = loss + contrastive_loss * self.config.contrastive_loss_coef
        return loss
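

# Illustrative sketch (hypothetical helper): the vectorized "first" class-token selection in
# _extract_class_features_first. A cumulative count of class tokens assigns each class token
# its slot k, a one-hot (K, seq) selection matrix picks it (or the token right after it when
# embed_class_token is False), and einsum gathers the embeddings in one step. The token id 100
# standing in for config.class_token_index and the toy sequence are assumptions.
def _example_first_class_token_selection(embed_class_token=False, max_classes=3):
    import torch

    class_token_index = 100
    input_ids = torch.tensor([[100, 7, 8, 100, 9, 5, 6]])  # two labels, then text
    token_embeds = torch.arange(7, dtype=torch.float32).view(1, 7, 1)  # embedding = position

    class_mask = input_ids == class_token_index
    class_cum = class_mask.long().cumsum(dim=-1)
    k_range = torch.arange(max_classes).view(1, -1, 1)
    select = class_mask.unsqueeze(1) & ((class_cum.unsqueeze(1) - 1) == k_range)  # (B, K, seq)
    if not embed_class_token:
        shifted = torch.zeros_like(select)
        shifted[:, :, 1:] = select[:, :, :-1]  # move each selection one token to the right
        select = shifted
    classes = torch.einsum("bks,bsd->bkd", select.to(token_embeds.dtype), token_embeds)
    num_classes = class_mask.sum(dim=-1, keepdim=True)
    classes_mask = (torch.arange(max_classes).unsqueeze(0) < num_classes).long()
    return classes.squeeze(-1).tolist(), classes_mask.tolist()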


class GLiClassUniEncoder(GLiClassBaseModel):
    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
        super().__init__(config)
        if config.encoder_config is None:
            if config.encoder_model_name is None:
                raise ValueError("You need to specify encoder model name to use it as a backbone.")
            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)

        config_name = config.encoder_config.__class__.__name__

        model_kwargs = {}
        if config_name in DECODER_MODEL_MAPPING:
            if not IS_LLM2VEC:
                raise MissedPackageException(
                    f"The llm2vec package must be installed to use this decoder model: {config_name}"
                )
            else:
                print("Loading decoder model using LLM2Vec...")
                ModelClass = DECODER_MODEL_MAPPING[config_name]
            decoder = True
        elif config_name in {"T5Config", "MT5Config", "UMT5Config"}:
            decoder = False
            turbot5_type = os.environ.get("TURBOT5_ATTN_TYPE", "")
            if turbot5_type and IS_TURBOT5:
                ModelClass = FlashT5EncoderModel
                model_kwargs = {"attention_type": turbot5_type}
            elif config_name == "UMT5Config":
                ModelClass = UMT5EncoderModel
            else:
                ModelClass = T5EncoderModel
        elif config_name in {"DebertaV2Config"}:
            decoder = False
            if os.environ.get("USE_FLASHDEBERTA", "") and IS_FLASHDEBERTA:
                print("Using FlashDeberta backend.")
                ModelClass = FlashDebertaV2Model
            else:
                ModelClass = DebertaV2Model

        else:
            decoder = False
            ModelClass = AutoModel

        if from_pretrained:
            self.encoder_model = ModelClass.from_pretrained(config.encoder_model_name, **model_kwargs)
        elif decoder:
            self.encoder_model = ModelClass(config.encoder_config)
        elif config_name in {"T5Config", "MT5Config", "UMT5Config", "DebertaV2Config"}:
            self.encoder_model = ModelClass._from_config(config.encoder_config)
        else:
            self.encoder_model = ModelClass.from_config(config.encoder_config)

        if config.vocab_size is not None and hasattr(self.encoder_model, "resize_token_embeddings"):
            current_vocab = self.encoder_model.config.vocab_size
            if current_vocab != config.vocab_size:
                self.encoder_model.resize_token_embeddings(config.vocab_size)

        adapter_config_file = (
            Path(config.encoder_model_name) / "adapter_config.json" if config.encoder_model_name else None
        )

        if adapter_config_file is not None and adapter_config_file.exists():
            if not IS_PEFT:
                warnings.warn(
                    "Adapter configs were detected, if you want to apply them you need to install peft package.",
                    stacklevel=2,
                )
            else:
                adapter_config = LoraConfig.from_pretrained(config.encoder_model_name)
                self.encoder_model = get_peft_model(self.encoder_model, adapter_config)

        if config.use_segment_embeddings:
            self.segment_embeddings = nn.Embedding(3, config.encoder_config.hidden_size)
            nn.init.normal_(self.segment_embeddings.weight, mean=0.0, std=config.initializer_range)

    def _create_segment_ids(self, input_ids):
        batch_size, _seq_length = input_ids.shape
        segment_ids = torch.zeros_like(input_ids)  # Default: segment 0 (labels)

        # Find the first example token in each row (argmax on an int mask returns the first True)
        example_token_mask = input_ids == self.config.example_token_index
        example_token_indices = example_token_mask.int().argmax(dim=-1)
        has_example = example_token_mask.any(dim=-1)

        # Find the first text token in each row
        text_token_mask = input_ids == self.config.text_token_index
        text_token_indices = text_token_mask.int().argmax(dim=-1)

        for batch_idx in range(batch_size):
            text_start = text_token_indices[batch_idx].item()

            # If examples exist, assign segment 1 to example section
            if has_example[batch_idx]:
                example_start = example_token_indices[batch_idx].item()
                segment_ids[batch_idx, text_start:example_start] = 1
                segment_ids[batch_idx, example_start:] = 2
            else:
                segment_ids[batch_idx, text_start:] = 1

        return segment_ids

    def process_encoder_output(self, input_ids, attention_mask, encoder_layer, labels=None, max_num_classes=None):
        classes_embedding, classes_embedding_mask, text_token_embeddings, text_mask = self._extract_class_features(
            encoder_layer, input_ids, attention_mask, max_num_classes
        )
        if self.config.use_lstm:
            text_token_embeddings = self.lstm(text_token_embeddings, text_mask)

        pooled_output = self.pooler(text_token_embeddings)
        pooled_output = self.text_projector(pooled_output)
        pooled_output = self.dropout(pooled_output)
        if self.config.normalize_features:
            pooled_output = pooled_output / (pooled_output.norm(p=2, dim=-1, keepdim=True) + self.epsilon)

        classes_embedding = self.classes_projector(classes_embedding)
        if self.config.normalize_features:
            classes_embedding = classes_embedding / (classes_embedding.norm(p=2, dim=-1, keepdim=True) + self.epsilon)

        logits = self.scorer(pooled_output, classes_embedding, text_mask=text_mask)

        if self.config.normalize_features:
            logits = logits * self.logit_scale.to(classes_embedding.device)

        loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)
        return (logits, loss, pooled_output, classes_embedding)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_text_embeddings: bool | None = None,
        output_class_embeddings: bool | None = None,
        return_dict: bool | None = None,
        max_num_classes: int | None = None,
        **kwargs,
    ) -> Tuple | GLiClassOutput:
        r"""
        Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.squeeze_layers or self.config.layer_wise:
            output_hidden_states = True
            return_dict = True

        if self.config.use_segment_embeddings:
            embedding_layer = self.encoder_model.get_input_embeddings()
            token_embeds = embedding_layer(input_ids)

            segment_ids = self._create_segment_ids(input_ids)
            segment_embeds = self.segment_embeddings(segment_ids)

            inputs_embeds = token_embeds + segment_embeds

            outputs = self.encoder_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )
        else:
            outputs = self.encoder_model(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )

        if self.config.layer_wise and labels is not None:
            hidden_states = outputs.hidden_states
            loss = 0
            for encoder_layer in hidden_states:
                logits, layer_loss, pooled_output, classes_embedding = self.process_encoder_output(
                    input_ids, attention_mask, encoder_layer, labels, max_num_classes
                )
                loss += layer_loss
        else:
            if self.config.encoder_layer_id == -1:
                if self.config.squeeze_layers:
                    encoder_layer = self.layer_wise_attention(outputs.hidden_states)
                else:
                    encoder_layer = outputs[0]
            else:
                encoder_layer = outputs.hidden_states[self.config.encoder_layer_id]
            logits, loss, pooled_output, classes_embedding = self.process_encoder_output(
                input_ids, attention_mask, encoder_layer, labels, max_num_classes
            )

        if not return_dict:
            output = (logits, *outputs[1:])
            return ((loss, *output)) if loss is not None else output

        return GLiClassOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            text_embeddings=pooled_output if output_text_embeddings else None,
            class_embeddings=classes_embedding if output_class_embeddings else None,
        )
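

# Illustrative sketch (hypothetical helper, not part of the GLiClass API): a vectorized
# alternative to the per-row loop in `GLiClassUniEncoder._create_segment_ids`. It assumes
# the same layout as that method: segment 0 for label tokens, segment 1 from the first
# text token onward, and segment 2 from the first example token onward. Rows without an
# example token never switch to segment 2, so the result matches the loop version.
import torch


def _sketch_vectorized_segment_ids(input_ids, text_token_index, example_token_index):
    seq_length = input_ids.shape[1]
    positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)  # [1, seq]

    text_mask = input_ids == text_token_index
    example_mask = input_ids == example_token_index
    # argmax on an int mask returns the first True position (0 if there is none)
    text_start = text_mask.int().argmax(dim=-1, keepdim=True)  # [batch, 1]
    example_start = example_mask.int().argmax(dim=-1, keepdim=True)  # [batch, 1]
    # Push the example boundary past the end for rows that have no example token
    has_example = example_mask.any(dim=-1, keepdim=True)
    example_start = torch.where(has_example, example_start, torch.full_like(example_start, seq_length))

    segment_ids = torch.zeros_like(input_ids)  # segment 0: labels
    segment_ids = torch.where(positions >= text_start, torch.ones_like(segment_ids), segment_ids)
    segment_ids = torch.where(positions >= example_start, torch.full_like(segment_ids, 2), segment_ids)
    return segment_ids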


class GLiClassEncoderDecoder(GLiClassBaseModel):
    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
        super().__init__(config)
        if config.encoder_config is None:
            if config.encoder_model_name is None:
                raise ValueError("You need to specify encoder model name to use it as a backbone.")
            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)

        if not config.encoder_config.is_encoder_decoder:
            raise ValueError("You need to choose encoder-decoder model as a backbone.")

        if from_pretrained:
            self.encoder_decoder_model = AutoModel.from_pretrained(config.encoder_model_name)
        else:
            self.encoder_decoder_model = AutoModel.from_config(config.encoder_config)

    @staticmethod
    def _make_bidirectional_4d_mask(attention_mask_2d, dtype):
        """Convert a 2D padding mask into a 4D bidirectional attention mask.

        When a 4D mask is passed to the decoder, the model uses it as-is
        without applying its default causal pattern, enabling bidirectional
        self-attention in the decoder.

        Args:
            attention_mask_2d: (batch_size, seq_length) with 1 for real tokens, 0 for padding.
            dtype: The dtype of the model (needed for the min-value fill).

        Returns:
            4D mask of shape (batch_size, 1, seq_length, seq_length).
            Values are 0.0 for attended positions and a large negative value for masked positions.
        """
        batch_size, seq_length = attention_mask_2d.shape
        # (batch_size, 1, 1, seq_length) - masks out padding columns
        padding_mask = (1.0 - attention_mask_2d.to(dtype))[:, None, None, :] * torch.finfo(dtype).min
        return padding_mask.expand(batch_size, 1, seq_length, seq_length)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        class_input_ids: torch.Tensor | None = None,
        class_attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_text_embeddings: bool | None = None,
        output_class_embeddings: bool | None = None,
        return_dict: bool | None = True,
        **kwargs,
    ) -> Tuple | GLiClassOutput:
        r"""
        Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Build a 4D bidirectional mask for the decoder so it attends to
        # all non-padding positions instead of using causal masking.
        decoder_4d_mask = None
        if class_attention_mask is not None:
            model_dtype = next(self.encoder_decoder_model.parameters()).dtype
            decoder_4d_mask = self._make_bidirectional_4d_mask(class_attention_mask, model_dtype)

        outputs = self.encoder_decoder_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=class_input_ids,
            decoder_attention_mask=decoder_4d_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        text_token_embeddings = outputs.encoder_last_hidden_state
        decoder_token_embeddings = outputs.last_hidden_state
        classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(
            decoder_token_embeddings, class_input_ids, class_attention_mask
        )

        if self.config.use_lstm:
            text_token_embeddings = self.lstm(text_token_embeddings, attention_mask)

        pooled_output = self.pooler(text_token_embeddings)
        pooled_output = self.text_projector(pooled_output)
        pooled_output = self.dropout(pooled_output)
        if self.config.normalize_features:
            pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)

        classes_embedding = self.classes_projector(classes_embedding)
        if self.config.normalize_features:
            classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)

        logits = self.scorer(pooled_output, classes_embedding)

        if self.config.normalize_features:
            logits = logits * self.logit_scale.to(classes_embedding.device)

        loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)

        if not return_dict:
            output = (logits, *outputs[1:])
            return ((loss, *output)) if loss is not None else output

        return GLiClassOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.decoder_hidden_states,
            attentions=outputs.decoder_attentions,
            text_embeddings=pooled_output if output_text_embeddings else None,
            class_embeddings=classes_embedding if output_class_embeddings else None,
        )
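

# Illustrative sketch (hypothetical helper): shows why the additive 4D mask built by
# `GLiClassEncoderDecoder._make_bidirectional_4d_mask` gives bidirectional attention.
# Every query row receives the same padding-column mask and no causal triangle, so after
# softmax each token can attend to all real tokens before and after it, while padding
# columns get zero weight. Random scores stand in for real query-key products.
import torch


def _sketch_bidirectional_attention_weights(attention_mask_2d, dtype=torch.float32):
    batch_size, seq_length = attention_mask_2d.shape
    # Same construction as `_make_bidirectional_4d_mask`: 0.0 for real tokens and the
    # dtype minimum for padding columns, broadcast over every query row
    padding_mask = (1.0 - attention_mask_2d.to(dtype))[:, None, None, :] * torch.finfo(dtype).min
    mask_4d = padding_mask.expand(batch_size, 1, seq_length, seq_length)
    scores = torch.randn(batch_size, 1, seq_length, seq_length, dtype=dtype)
    return torch.softmax(scores + mask_4d, dim=-1)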


class GLiClassEncoderDecoderCLS(GLiClassBaseModel):
    """Encoder-decoder architecture where labels go to the encoder and text goes to the decoder.

    Class features are extracted from encoder output using _extract_class_features().
    Text features are extracted from the last non-padding token of the decoder output.
    """

    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
        super().__init__(config)
        if config.encoder_config is None:
            if config.encoder_model_name is None:
                raise ValueError("You need to specify encoder model name to use it as a backbone.")
            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)

        if not config.encoder_config.is_encoder_decoder:
            raise ValueError("You need to choose encoder-decoder model as a backbone.")

        if from_pretrained:
            self.encoder_decoder_model = AutoModel.from_pretrained(config.encoder_model_name)
        else:
            self.encoder_decoder_model = AutoModel.from_config(config.encoder_config)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        class_input_ids: torch.Tensor | None = None,
        class_attention_mask: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        output_text_embeddings: bool | None = None,
        output_class_embeddings: bool | None = None,
        return_dict: bool | None = True,
        **kwargs,
    ) -> Tuple | GLiClassOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Labels → encoder, Text → decoder
        outputs = self.encoder_decoder_model(
            input_ids=class_input_ids,
            attention_mask=class_attention_mask,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        # Class features from encoder output
        encoder_token_embeddings = outputs.encoder_last_hidden_state
        classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(
            encoder_token_embeddings, class_input_ids, class_attention_mask
        )

        # Text features from the decoder's last non-padding token.
        # This index (number of real tokens - 1) is only correct for right-padded inputs.
        decoder_output = outputs.last_hidden_state
        batch_size = decoder_output.shape[0]
        last_non_pad_idx = attention_mask.sum(dim=1) - 1
        pooled_output = decoder_output[torch.arange(batch_size, device=decoder_output.device), last_non_pad_idx]

        pooled_output = self.text_projector(pooled_output)
        pooled_output = self.dropout(pooled_output)
        if self.config.normalize_features:
            pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)

        classes_embedding = self.classes_projector(classes_embedding)
        if self.config.normalize_features:
            classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)

        logits = self.scorer(pooled_output, classes_embedding)

        if self.config.normalize_features:
            logits = logits * self.logit_scale.to(classes_embedding.device)

        loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)

        if not return_dict:
            output = (logits, *outputs[1:])
            return ((loss, *output)) if loss is not None else output

        return GLiClassOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.decoder_hidden_states,
            attentions=outputs.decoder_attentions,
            text_embeddings=pooled_output if output_text_embeddings else None,
            class_embeddings=classes_embedding if output_class_embeddings else None,
        )
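

# Illustrative sketch (hypothetical helper): `GLiClassEncoderDecoderCLS.forward` takes the
# text representation at index `attention_mask.sum(dim=1) - 1`, which is the last real
# token only when sequences are right-padded. This variant also handles left padding by
# taking the position of the last 1 in the mask directly.
import torch


def _sketch_last_token_pool(hidden_states, attention_mask):
    # hidden_states: [batch, seq, hidden]; attention_mask: [batch, seq] with 1 = real token
    seq_length = attention_mask.shape[1]
    positions = torch.arange(seq_length, device=attention_mask.device)
    # Zero out padding positions; the largest remaining position is the last real token
    last_idx = (positions * attention_mask.long()).max(dim=1).values
    batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    return hidden_states[batch_idx, last_idx]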


class GLiClassBiEncoder(GLiClassBaseModel):
    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
        super().__init__(config)
        if config.encoder_config is None:
            if config.encoder_model_name is None:
                raise ValueError("You need to specify encoder model name to use it as a backbone.")
            config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)

        if config.label_model_config is None:
            if config.label_model_name is None:
                raise ValueError("You need to specify label model name to use it as a backbone.")
            config.label_model_config = AutoConfig.from_pretrained(config.label_model_name)

        def initialize_encoder(configs, model_name, from_pretrained):
            if from_pretrained:
                return AutoModel.from_pretrained(model_name)
            else:
                return AutoModel.from_config(configs)

        self.encoder_model = initialize_encoder(config.encoder_config, config.encoder_model_name, from_pretrained)
        self.label_encoder = initialize_encoder(config.label_model_config, config.label_model_name, from_pretrained)
        self.biencoder_projector = BiEncoderProjector(config)

    def pool_outputs(self, encoder_outputs):
        text_embeddings = self.pooler(encoder_outputs[0])
        text_embeddings = self.text_projector(text_embeddings)
        text_embeddings = self.dropout(text_embeddings)
        if self.config.normalize_features:
            text_embeddings = nn.functional.normalize(text_embeddings, p=2, dim=-1, eps=self.epsilon)
        return text_embeddings

    def encode_text(self, input_ids, attention_mask):
        outputs = self.encoder_model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))
        text_embeddings = self.pool_outputs(outputs)
        return text_embeddings

    def encode_classes(self, class_input_ids, class_attention_mask, labels_mask=None):
        batch_size = class_input_ids.shape[0]
        num_classes = class_input_ids.shape[1]
        if labels_mask is not None:
            batch_indices, indices = torch.where(labels_mask == 1)
            selected_input_ids = class_input_ids[batch_indices, indices]
            selected_attention_mask = class_attention_mask[batch_indices, indices]

            outputs = self.label_encoder(selected_input_ids, attention_mask=selected_attention_mask)
            class_embeddings_filtered = self.pooler(outputs[0])

            class_embeddings = torch.zeros(
                batch_size,
                num_classes,
                class_embeddings_filtered.shape[-1],
                dtype=class_embeddings_filtered.dtype,
                device=class_embeddings_filtered.device,
            )

            class_embeddings[batch_indices, indices] = class_embeddings_filtered
        else:
            class_input_ids = class_input_ids.view(-1, class_input_ids.shape[-1])
            class_attention_mask = class_attention_mask.view(-1, class_input_ids.shape[-1])
            outputs = self.label_encoder(class_input_ids, attention_mask=class_attention_mask)
            class_embeddings = self.pooler(outputs[0])
            class_embeddings = class_embeddings.reshape(batch_size, num_classes, -1)
        class_embeddings = self.biencoder_projector(class_embeddings)
        class_embeddings = self.classes_projector(class_embeddings)
        if self.config.normalize_features:
            class_embeddings = nn.functional.normalize(class_embeddings, p=2, dim=-1, eps=self.epsilon)
        return class_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        class_input_ids: torch.Tensor | None = None,
        class_attention_mask: torch.Tensor | None = None,
        labels_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_text_embeddings: bool | None = None,
        output_class_embeddings: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> Tuple | GLiClassOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_embeddings = self.encode_text(input_ids, attention_mask)
        class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)
        logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)

        if labels_mask is not None:
            logits = torch.where(labels_mask == 0, -1e3, logits)

        loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)

        if not return_dict:
            output = (logits,)
            return ((loss, *output)) if loss is not None else output

        return GLiClassOutput(
            loss=loss,
            logits=logits,
            text_embeddings=text_embeddings if output_text_embeddings else None,
            class_embeddings=class_embeddings if output_class_embeddings else None,
        )
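

# Illustrative sketch (hypothetical helper): the bi-encoder scores every text against every
# candidate label, multiplies by `logit_scale`, and overwrites the logits of padded label
# slots (labels_mask == 0) with -1e3 so they receive ~0 probability. This assumes a plain
# dot-product scorer over L2-normalized features (scaled cosine similarity); the real
# `self.scorer` is configurable and may differ.
import torch
import torch.nn.functional as F


def _sketch_masked_cosine_logits(text_embeddings, class_embeddings, labels_mask, logit_scale=10.0):
    # text_embeddings: [batch, hidden]; class_embeddings: [batch, num_classes, hidden]
    text = F.normalize(text_embeddings, p=2, dim=-1)
    classes = F.normalize(class_embeddings, p=2, dim=-1)
    logits = torch.einsum("bh,bch->bc", text, classes) * logit_scale
    return torch.where(labels_mask == 0, torch.full_like(logits, -1e3), logits)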


class GLiClassBiEncoderFused(GLiClassBiEncoder):
    def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
        super().__init__(config, from_pretrained)

    def encode_text(self, input_ids, attention_mask, class_embeddings, labels_mask):
        embedding_layer = self.encoder_model.get_input_embeddings()
        inputs_embeds = embedding_layer(input_ids)

        class_token_mask = input_ids == self.config.class_token_index
        batch_indices, class_token_indices = torch.where(class_token_mask)

        labels_batch_indices, labels_indices = torch.where(labels_mask == 1)

        selected_class_embeddings = class_embeddings[labels_batch_indices, labels_indices]

        inputs_embeds[batch_indices, class_token_indices] = selected_class_embeddings
        encoder_outputs = self.encoder_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask.squeeze(1))

        post_class_embeddings = torch.zeros_like(class_embeddings)
        post_class_embeddings[labels_batch_indices, labels_indices] = encoder_outputs[0][
            batch_indices, class_token_indices
        ]
        return encoder_outputs, post_class_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        class_input_ids: torch.Tensor | None = None,
        class_attention_mask: torch.Tensor | None = None,
        labels_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_text_embeddings: bool | None = None,
        output_class_embeddings: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> Tuple | GLiClassOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        raw_class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)

        encoder_outputs, class_embeddings = self.encode_text(
            input_ids, attention_mask, raw_class_embeddings, labels_mask
        )

        text_embeddings = self.pool_outputs(encoder_outputs)

        logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)

        if labels_mask is not None:
            logits = torch.where(labels_mask == 0, -1e3, logits)

        loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)

        if not return_dict:
            output = (logits,)
            return ((loss, *output)) if loss is not None else output

        return GLiClassOutput(
            loss=loss,
            logits=logits,
            text_embeddings=text_embeddings if output_text_embeddings else None,
            class_embeddings=class_embeddings if output_class_embeddings else None,
        )
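

# Illustrative sketch (hypothetical helper): `GLiClassBiEncoderFused.encode_text` overwrites
# the input embedding at each class-token position with a label embedding, pairing the k-th
# class token of a row with the k-th valid label (labels_mask == 1) of that row. The pairing
# works because torch.where returns indices in row-major order, so both index lists walk
# rows first and then positions left to right. It assumes every row has exactly as many
# class tokens as valid labels.
import torch


def _sketch_inject_label_embeddings(inputs_embeds, input_ids, class_token_index, class_embeddings, labels_mask):
    # inputs_embeds: [batch, seq, hidden]; class_embeddings: [batch, num_classes, hidden]
    batch_idx, token_pos = torch.where(input_ids == class_token_index)
    label_batch_idx, label_idx = torch.where(labels_mask == 1)
    fused = inputs_embeds.clone()  # avoid modifying the caller's tensor in place
    fused[batch_idx, token_pos] = class_embeddings[label_batch_idx, label_idx]
    return fused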


class GLiClassModel(GLiClassPreTrainedModel):
    def __init__(self, config, from_pretrained=False):
        super().__init__(config)
        if config.architecture_type == "uni-encoder":
            self.model = GLiClassUniEncoder(config, from_pretrained)
        elif config.architecture_type == "bi-encoder":
            self.model = GLiClassBiEncoder(config, from_pretrained)
        elif config.architecture_type == "bi-encoder-fused":
            self.model = GLiClassBiEncoderFused(config, from_pretrained)
        elif config.architecture_type == "encoder-decoder":
            self.model = GLiClassEncoderDecoder(config, from_pretrained)
        elif config.architecture_type == "encoder-decoder-cls":
            self.model = GLiClassEncoderDecoderCLS(config, from_pretrained)
        else:
            raise ValueError(f"Unsupported architecture_type: {config.architecture_type}")
        self.post_init()

    def get_input_embeddings(self):
        if self.config.architecture_type in {"uni-encoder"}:
            return self.model.encoder_model.get_input_embeddings()
        elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
            return self.model.encoder_decoder_model.get_input_embeddings()
        else:
            raise NotImplementedError("Getting input embeddings is not implemented for bi-encoder architecture")

    def set_input_embeddings(self, value):
        if self.config.architecture_type in {"uni-encoder"}:
            self.model.encoder_model.set_input_embeddings(value)
        elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
            self.model.encoder_decoder_model.set_input_embeddings(value)
        elif self.config.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
            self.model.encoder_model.set_input_embeddings(value)
        else:
            raise NotImplementedError("Setting input embeddings is not implemented for this architecture type")

    def tie_weights(self, recompute_mapping=True, missing_keys=None):
        """
        Tie model weights for architectures that share parameters.

        This method handles:
        - Version compatibility between transformers v4 and v5
        - Different GLiClass architecture types
        - Special handling for T5/MT5 models in transformers v5+ where encoder.embed_tokens
          may be incorrectly initialized instead of being tied to shared.weight

        Args:
            recompute_mapping: Whether to recompute weight mapping (transformers v5+)
            missing_keys: Keys that are missing from checkpoint (transformers v5+)
        """
        # Get encoder model based on architecture type
        encoder_model = None
        if self.config.architecture_type in {"uni-encoder"}:
            encoder_model = self.model.encoder_model
        elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
            encoder_model = self.model.encoder_decoder_model
        elif self.config.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
            encoder_model = self.model.encoder_model
        else:
            raise NotImplementedError("Tie weights is not implemented for this architecture type")

        # Call base tie_weights with version-appropriate parameters
        if version.parse(transformers.__version__) >= version.parse("5.0.0"):
            result = encoder_model.tie_weights(recompute_mapping=recompute_mapping, missing_keys=missing_keys)
        else:
            result = encoder_model.tie_weights()

        # Fix for T5/MT5/UMT5 models in transformers v5+
        # In v5, if encoder.embed_tokens.weight is missing from checkpoint, it gets randomly
        # initialized instead of being tied to shared.weight. We explicitly ensure proper tying.
        if (
            encoder_model is not None
            and hasattr(encoder_model, "shared")
            and hasattr(encoder_model, "encoder")
            and hasattr(encoder_model.encoder, "embed_tokens")
        ):
            shared_weight = encoder_model.shared.weight
            embed_weight = encoder_model.encoder.embed_tokens.weight

            # Only tie if they're not already the same tensor
            if shared_weight is not embed_weight:
                encoder_model.encoder.embed_tokens.weight = shared_weight
                if version.parse(transformers.__version__) >= version.parse("5.0.0"):
                    logger.info(
                        "Applied transformers v5 compatibility fix: tied encoder.embed_tokens.weight "
                        "to shared.weight for T5-based model"
                    )

        return result

    def resize_token_embeddings(self, new_num_tokens: int | None = None, pad_to_multiple_of=None) -> nn.Embedding:
        if self.config.architecture_type in {"uni-encoder"}:
            model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
            model_embeds = self.model.encoder_decoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        elif self.config.architecture_type in {"bi-encoder-fused"}:
            model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        else:
            raise NotImplementedError("Resizing is not implemented for bi-encoder architecture")
        self.config.encoder_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def forward(self, *args, **kwargs):
        outputs = self.model(*args, **kwargs)
        return outputs
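

# Illustrative sketch (hypothetical toy modules, not GLiClass or transformers classes): the
# T5 fix in `GLiClassModel.tie_weights` ties weights by assigning the same nn.Parameter
# object to `encoder.embed_tokens.weight` and `shared.weight`. After the assignment both
# names refer to one tensor, so an update through either name is visible through the other
# and `parameters()` counts it only once.
import torch
from torch import nn


class _SketchEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)


class _SketchSeq2Seq(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, hidden_size)
        self.encoder = _SketchEncoder(vocab_size, hidden_size)


def _sketch_tie_embeddings(model):
    # Only re-tie when the two names point at different tensors
    if model.encoder.embed_tokens.weight is not model.shared.weight:
        model.encoder.embed_tokens.weight = model.shared.weight
    return model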


================================================
FILE: gliclass/ops.py
================================================
import torch
import torch.nn.functional as F

# ─── Attention (padded) ───────────────────────────────────────────────────────


def attn_padded(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    key_padding_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
) -> torch.Tensor:
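    """
    Padded attention via F.scaled_dot_product_attention.
    PyTorch picks the SDPA backend at runtime. On CUDA, FlashAttention kernels are only
    eligible when no mask is passed, so calls with key_padding_mask typically run on the
    memory-efficient or math backend.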
    Args:
        q:                [batch, nq, nheads, head_dim]
        k:                [batch, nk, nheads, head_dim]
        v:                [batch, nk, nheads, head_dim]
        key_padding_mask: [batch, nk] bool, True = real token
    Returns:
        [batch, nq, nheads, head_dim]
    """
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    attn_mask = None
    if key_padding_mask is not None:
        attn_mask = key_padding_mask[:, None, None, :].bool()

    out = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
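        # Dropout follows autograd state rather than module.training, so callers should
        # pass dropout_p=0.0 when running evaluation with gradients enabled.
        dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,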
    )
    return out.transpose(1, 2)  # [batch, nq, nheads, head_dim]
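

# Illustrative sketch (hypothetical helper): an explicit softmax-attention reference for
# `attn_padded`, useful for sanity-checking the SDPA path. It uses the same layout
# ([batch, seq, nheads, head_dim]), the same mask convention (True = real key token) and
# the default 1/sqrt(head_dim) scaling of F.scaled_dot_product_attention. No dropout.
import math

import torch


def _sketch_attn_padded_reference(q, k, v, key_padding_mask=None):
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [batch, nheads, seq, head_dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if key_padding_mask is not None:
        # Padding keys get -inf, so they receive exactly zero weight after softmax
        scores = scores.masked_fill(~key_padding_mask[:, None, None, :].bool(), float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return (weights @ v).transpose(1, 2)  # -> [batch, nq, nheads, head_dim]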


================================================
FILE: gliclass/pipeline.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from .model import GLiClassModel, GLiClassBiEncoder
from .utils import retrieval_augmented_text


def flatten_hierarchical_labels(
    labels: List[str] | Dict[str, Any], prefix: str = "", separator: str = "."
) -> List[str]:
    """
    Flatten hierarchical labels into dot notation.

    Supports arbitrary nesting depth. Examples:

    Input: {"sentiment": ["positive", "negative", "neutral"], "topic": ["product", "service", "shipping"]}
    Output: ["sentiment.positive", "sentiment.negative", "sentiment.neutral",
             "topic.product", "topic.service", "topic.shipping"]

    Input: {
        "category": {
            "electronics": ["phone", "laptop"],
            "clothing": ["shirt", "pants"]
        }
    }
    Output: [
        "category.electronics.phone",
        "category.electronics.laptop",
        "category.clothing.shirt",
        "category.clothing.pants"
    ]

    Input: ["label1", "label2"]  # Already flat
    Output: ["label1", "label2"]

    Args:
        labels: Either a list of string labels or a hierarchical dict
        prefix: Current prefix for recursion (internal use)
        separator: Separator to use between hierarchy levels (default: ".")

    Returns:
        List of flattened label strings with dot notation
    """
    if isinstance(labels, list):
        if prefix:
            return [f"{prefix}{separator}{label}" for label in labels]
        return labels

    elif isinstance(labels, dict):
        flattened = []
        for key, value in labels.items():
            new_prefix = f"{prefix}{separator}{key}" if prefix else key
            flattened.extend(flatten_hierarchical_labels(value, new_prefix, separator))
        return flattened

    elif isinstance(labels, str):
        if prefix:
            return [f"{prefix}{separator}{labels}"]
        return [labels]

    else:
        raise ValueError(f"Unsupported label type: {type(labels)}. Expected list, dict, or str.")


def build_hierarchical_output(
    predictions: List[Dict[str, float]],
    original_labels: List[str] | Dict[str, Any],
    separator: str = ".",
    all_scores: Dict[str, float] | None = None,
) -> Dict[str, float] | Dict[str, Any]:
    """
    Build hierarchical output structure matching the input labels structure.

    Args:
        predictions: List of prediction dicts with 'label' and 'score'
        original_labels: Original hierarchical labels structure
        separator: Separator used in flattened labels
        all_scores: Optional dict of all label scores (for complete output)

    Returns:
        Hierarchical structure with scores matching the input format

    Example:
        Input predictions: [
            {'label': 'sentiment.positive', 'score': 0.85},
            {'label': 'topic.product', 'score': 0.72}
        ]
        Input original_labels: {
            "sentiment": ["positive", "negative", "neutral"],
            "topic": ["product", "service", "shipping"]
        }
        Output: {
            "sentiment": {"positive": 0.85, "negative": 0.0, "neutral": 0.0},
            "topic": {"product": 0.72, "service": 0.0, "shipping": 0.0}
        }
    """
    score_lookup = {pred["label"]: pred["score"] for pred in predictions}

    if all_scores:
        for k, v in all_scores.items():
            if k not in score_lookup:
                score_lookup[k] = v

    def _build_recursive(structure: List[str] | Dict[str, Any], prefix: str = "") -> Dict[str, float] | Dict[str, Any]:
        if isinstance(structure, list):
            result = {}
            for label in structure:
                full_label = f"{prefix}{separator}{label}" if prefix else label
                result[label] = score_lookup.get(full_label, 0.0)
            return result

        elif isinstance(structure, dict):
            result = {}
            for key, value in structure.items():
                new_prefix = f"{prefix}{separator}{key}" if prefix else key
                result[key] = _build_recursive(value, new_prefix)
            return result

        elif isinstance(structure, str):
            full_label = f"{prefix}{separator}{structure}" if prefix else structure
            return {structure: score_lookup.get(full_label, 0.0)}

        return {}

    if isinstance(original_labels, list):
        return {label: score_lookup.get(label, 0.0) for label in original_labels}

    return _build_recursive(original_labels)


def format_examples_prompt(
    examples: List[Dict[str, Any]], example_token: str = "<<EXAMPLE>>", sep_token: str = "<<SEP>>"
) -> str:
    r"""
    Format few-shot examples into a prompt string using <<EXAMPLE>> token.

    Format matches training: <<EXAMPLE>>text \nLabels:\n label1, label2
    with a single <<SEP>> after all examples.

    Args:
        examples: List of example dicts with 'text' and 'labels'/'true_labels' keys
        example_token: Token to mark examples (default: "<<EXAMPLE>>")
        sep_token: Separator token after all examples (default: "<<SEP>>")

    Returns:
        Formatted examples string
    """
    if not examples:
        return ""

    formatted_parts = []
    for example in examples:
        text = example.get("text", "")
        labels = example.get("labels", example.get("true_labels", []))

        if isinstance(labels, list):
            labels_str = ", ".join(labels)
        else:
            labels_str = str(labels)

        # Match training format: " \nLabels:\n " instead of "\nLabels: "
        formatted_parts.append(f"{example_token}{text} \nLabels:\n {labels_str}")

    # Add single SEP token after all examples (matching training)
    formatted_parts.append(sep_token)

    return "".join(formatted_parts)


class BaseZeroShotClassificationPipeline(ABC):
    def __init__(
        self,
        model,
        tokenizer,
        max_classes=25,
        max_length=1024,
        classification_type="multi-label",
        device="cuda:0",
        progress_bar=True,
        label_separator: str = ".",
    ):
        self.model = model
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.max_classes = max_classes
        self.classification_type = classification_type
        self.max_length = max_length
        self.progress_bar = progress_bar
        self.label_separator = label_separator
        self._max_labels_alloc = getattr(model.config, "max_labels_alloc", "dynamic")

        self.example_token = "<<EXAMPLE>>"
        self.label_token = "<<LABEL>>"
        self.sep_token = "<<SEP>>"

        if not isinstance(device, torch.device):
            if torch.cuda.is_available() and "cuda" in device:
                self.device = torch.device(device)
            else:
                self.device = torch.device("cpu")
        else:
            self.device = device

        if self.model.device != self.device:
            self.model.to(self.device)

        # Ensure model is in evaluation mode for inference
        self.model.eval()

    def _normalize_classification_type(self, classification_type: str | None) -> str:
        if classification_type is None:
            return self.classification_type

        normalized = classification_type.strip().lower()
        if normalized in {"single", "single-label", "single_label"}:
            return "single-label"
        if normalized in {"multi", "multi-label", "multi_label"}:
            return "multi-label"
        raise ValueError("Unsupported classification type: choose 'single-label' or 'multi-label'")

    def _normalize_texts(self, texts: str | List[str]) -> List[str]:
        if isinstance(texts, str):
            return [texts]
        return texts

    def _normalize_thresholds(self, threshold: float | List[float], num_texts: int) -> List[float]:
        if isinstance(threshold, list):
            if len(threshold) != num_texts:
                raise ValueError("Length of threshold list must match number of texts.")
            return threshold
        return [threshold] * num_texts

    def _normalize_classification_types(
        self,
        classification_type: str | List[str] | None,
        num_texts: int,
    ) -> List[str]:
        if isinstance(classification_type, list):
            if len(classification_type) != num_texts:
                raise ValueError("Length of classification_type list must match number of texts.")
            return [self._normalize_classification_type(item) for item in classification_type]

        normalized = self._normalize_classification_type(classification_type)
        return [normalized] * num_texts

    def _process_labels(
        self, labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]]
    ) -> List[str] | List[List[str]]:
        """Process labels to handle hierarchical structures."""
        if not labels:
            return labels

        if isinstance(labels, dict):
            return flatten_hierarchical_labels(labels, separator=self.label_separator)

        if isinstance(labels, list):
            if len(labels) == 0:
                return labels

            first_elem = labels[0]

            if isinstance(first_elem, str):
                return labels

            if isinstance(first_elem, dict):
                return [flatten_hierarchical_labels(lbl, separator=self.label_separator) for lbl in labels]

            if isinstance(first_elem, list):
                if first_elem and isinstance(first_elem[0], dict):
                    return [flatten_hierarchical_labels(lbl, separator=self.label_separator) for lbl in labels]
                return labels

        return labels

    def _format_examples_for_input(self, examples: List[Dict[str, Any]] | None = None) -> str:
        """Format few-shot examples using <<EXAMPLE>> and <<SEP>> tokens."""
        if not examples:
            return ""
        examples = [example for example in examples if example is not None]
        if not examples:
            return ""
        return format_examples_prompt(examples, example_token=self.example_token, sep_token=self.sep_token)

    def _examples_are_per_text(self, examples) -> bool:
        """Detect whether examples are provided per text rather than shared."""
        if not isinstance(examples, list) or len(examples) == 0:
            return False
        if all(isinstance(example, dict) for example in examples):
            return False
        return all(example is None or isinstance(example, list) for example in examples)

    def _get_text_examples(self, examples, index: int):
        """Get examples for a single text from shared or per-text input."""
        if not examples:
            return None
        if self._examples_are_per_text(examples):
            return examples[index] if index < len(examples) else None
        return examples

    def _format_prompt(self, prompt: str | List[str] | None = None, index: int = 0) -> str:
        """Format the task description prompt."""
        if prompt is None:
            return ""

        if isinstance(prompt, str):
            return prompt

        if isinstance(prompt, list):
            if index < len(prompt):
                return prompt[index]
            return prompt[0] if prompt else ""

        return ""

    def _resolve_max_num_classes(self, batch_labels, same_labels: bool):
        if self._max_labels_alloc == "dynamic":
            return len(batch_labels) if same_labels else max(len(labels) for labels in batch_labels)
        if isinstance(self._max_labels_alloc, int):
            return self._max_labels_alloc
        return None  # 'fixed': model uses config.max_num_classes

    @abstractmethod
    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
        pass

    def _get_batch_examples(self, examples, start_idx, batch_size):
        """Get examples for current batch."""
        if not examples:
            return None
        if self._examples_are_per_text(examples):
            return examples[start_idx : start_idx + batch_size]
        return examples

    def _get_batch_prompt(self, prompt, start_idx, batch_size):
        """Get prompt for current batch."""
        if not prompt:
            return None
        if isinstance(prompt, list):
            return prompt[start_idx : start_idx + batch_size]
        return prompt

    @torch.no_grad()
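    def get_embeddings(self, texts, labels, batch_size=8, examples=None, prompt=None):
        """
        Run the model and return one dict per text with NumPy arrays under
        'logits', 'text_embedding' and 'class_embeddings'.
        """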
        if isinstance(texts, str):
            texts = [texts]

        labels = self._process_labels(labels)

        if isinstance(labels[0], str):
            same_labels = True
        else:
            same_labels = False

        results = []

        iterable = range(0, len(texts), batch_size)
        if self.progress_bar:
            iterable = tqdm(iterable)

        for idx in iterable:
            batch_texts = texts[idx : idx + batch_size]
            # Per-text labels must be sliced so they stay aligned with the current batch of texts.
            batch_labels = labels if same_labels else labels[idx : idx + batch_size]
            batch_examples = self._get_batch_examples(examples, idx, len(batch_texts))
            batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))

            tokenized_inputs = self.prepare_inputs(
                batch_texts, batch_labels, same_labels, examples=batch_examples, prompt=batch_prompt
            )
            max_num_classes = self._resolve_max_num_classes(batch_labels, same_labels)
            model_output = self.model(
                **tokenized_inputs,
                max_num_classes=max_num_classes,
                output_text_embeddings=True,
                output_class_embeddings=True,
            )
            logits = model_output.logits
            text_embeddings = model_output.text_embeddings
            class_embeddings = model_output.class_embeddings
            batch_size_actual = logits.shape[0]

            for i in range(batch_size_actual):
                result = {
                    "logits": logits[i].cpu().numpy(),
                    "text_embedding": text_embeddings[i].cpu().numpy(),
                    "class_embeddings": class_embeddings[i].cpu().numpy(),
                }
                results.append(result)

        return results

    @torch.no_grad()
    def __call__(
        self,
        texts: str | List[str],
        labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]],
        threshold: float | List[float] = 0.5,
        batch_size: int = 8,
        classification_type: str | List[str] | None = None,
        rac_examples: List | None = None,
        examples: List[Dict[str, Any]] | None = None,
        prompt: str | List[str] | None = None,
        return_hierarchical: bool = False,
    ):
        """
        Perform zero-shot classification.

        Args:
            texts: Single text or list of texts to classify
            labels: Labels in various formats (flat list or hierarchical dict)
            threshold: Classification threshold for multi-label, either one
                value for all texts or one value per text
            batch_size: Batch size for processing
            classification_type: Override classification mode globally or per text.
                If None, uses the pipeline's configured classification_type
            rac_examples: Retrieval augmented examples (legacy)
            examples: Few-shot examples with 'text' and 'labels'/'true_labels' keys
            prompt: Task description - string (same for all) or list (per-text)
            return_hierarchical: If True, return hierarchical structure with all scores

        Returns:
            A list with one entry per text: a list of {"label", "score"} dicts, or, if
            return_hierarchical is True, a dict of scores mirroring the input label structure.
        """
        original_labels = labels

        texts = self._normalize_texts(texts)
        thresholds = self._normalize_thresholds(threshold, len(texts))
        classification_types = self._normalize_classification_types(classification_type, len(texts))

        if rac_examples:
            if len(texts) == 1 and not isinstance(rac_examples[0], list):
                texts = [retrieval_augmented_text(texts[0], rac_examples)]
            else:
                texts = [retrieval_augmented_text(text, ex) for text, ex in zip(texts, rac_examples)]

        processed_labels = self._process_labels(labels)

        if isinstance(processed_labels[0], str):
            same_labels = True
        else:
            same_labels = False

        results = []
        all_scores_list = []

        iterable = range(0, len(texts), batch_size)
        if self.progress_bar:
            iterable = tqdm(iterable)

        for idx in iterable:
            batch_texts = texts[idx : idx + batch_size]
            if not same_labels:
                batch_labels = processed_labels[idx : idx + batch_size]
            else:
                batch_labels = processed_labels

            batch_examples = self._get_batch_examples(examples, idx, len(batch_texts))
            batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))

            tokenized_inputs = self.prepare_inputs(
                batch_texts, batch_labels, same_labels, examples=batch_examples, prompt=batch_prompt
            )
            max_num_classes = self._resolve_max_num_classes(batch_labels, same_labels)
            model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)
            logits = model_output.logits
            probs = torch.sigmoid(logits)

            for i in range(len(batch_texts)):
                global_idx = idx + i
                item_classification_type = classification_types[global_idx]
                item_threshold = thresholds[global_idx]

                if same_labels:
                    curr_labels = batch_labels
                else:
                    curr_labels = batch_labels[i]

                if item_classification_type == "single-label":
                    score = torch.softmax(logits[i][: len(curr_labels)], dim=-1)

                    if return_hierarchical:
                        all_scores = {curr_labels[j]: score[j].item() for j in range(len(curr_labels))}
                        all_scores_list.append(all_scores)

                    pred_label = curr_labels[torch.argmax(score).item()]
                    results.append([{"label": pred_label, "score": score.max().item()}])
                elif item_classification_type == "multi-label":
                    text_results = []

                    if return_hierarchical:
                        all_scores = {curr_labels[j]: probs[i][j].item() for j in range(len(curr_labels))}
                        all_scores_list.append(all_scores)

                    for j, prob in enumerate(probs[i][: len(curr_labels)]):
                        score = prob.item()
                        if score >= item_threshold:
                            text_results.append({"label": curr_labels[j], "score": score})
                    results.append(text_results)
                else:
                    raise ValueError("Unsupported classification type: choose 'single-label' or 'multi-label'")

        if return_hierarchical:
            hierarchical_results = []
            for i, (result, all_scores) in enumerate(zip(results, all_scores_list)):
                if same_labels:
                    orig_lbl = original_labels
                else:
                    orig_lbl = original_labels[i] if i < len(original_labels) else original_labels

                hierarchical_results.append(
                    build_hierarchical_output(result, orig_lbl, self.label_separator, all_scores)
                )
            return hierarchical_results

        return results


class UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
    def __init__(
        self,
        model,
        tokenizer,
        max_classes=25,
        max_length=1024,
        classification_type="multi-label",
        device="cuda:0",
        progress_bar=True,
        label_separator: str = ".",
    ):
        super().__init__(
            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
        )

    def prepare_input(self, text, labels, examples=None, prompt=None):
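        """
        Build one uni-encoder input string matching the training format in data_processing.py.

        With config.prompt_first=True the order is Labels → SEP → Prompt → Text → Examples;
        otherwise it is Text → Labels → SEP → Prompt → Examples.
        """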
        input_parts = []

        # 1. Add labels
        for label in labels:
            label_tag = f"{self.label_token}{label}"
            input_parts.append(label_tag)
        input_parts.append(self.sep_token)

        # 2. Add task description prompt
        if prompt:
            input_parts.append(prompt)

        # 3. Format examples to go after text
        examples_str = ""
        if examples:
            examples_str = self._format_examples_for_input(examples)

        if self.model.config.prompt_first:
            return "".join(input_parts) + text + examples_str
        else:
            return text + "".join(input_parts) + examples_str

    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
        inputs = []

        if same_labels:
            for i, text in enumerate(texts):
                text_examples = self._get_text_examples(examples, i)
                text_prompt = self._format_prompt(prompt, i)
                inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))
        else:
            for i, (text, labels_) in enumerate(zip(texts, labels)):
                text_examples = self._get_text_examples(examples, i)
                text_prompt = self._format_prompt(prompt, i)
                inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))

        tokenized_inputs = self.tokenizer(
            inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
        ).to(self.device)

        return tokenized_inputs


class EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
    def __init__(
        self,
        model,
        tokenizer,
        max_classes=25,
        max_length=1024,
        classification_type="multi-label",
        device="cuda:0",
        progress_bar=True,
        label_separator: str = ".",
    ):
        super().__init__(
            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
        )

    def prepare_labels_prompt(self, labels, prompt=None):
        """Match training format: Labels → SEP → Prompt."""
        input_parts = []

        for label in labels:
            # label_tag = f"{label}{self.label_token}"
            label_tag = f"{self.label_token}{label}"
            input_parts.append(label_tag)
        input_parts.append(self.sep_token)

        if prompt:
            input_parts.append(prompt)

        return "".join(input_parts)

    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
        prompts = []
        processed_texts = []

        if same_labels:
            for i, text in enumerate(texts):
                text_examples = self._get_text_examples(examples, i)
                text_prompt = self._format_prompt(prompt, i)
                prompts.append(self.prepare_labels_prompt(labels, text_prompt))
                examples_str = self._format_examples_for_input(text_examples) if text_examples else ""
                processed_texts.append(text + examples_str)
        else:
            for i, labels_ in enumerate(labels):
                text_examples = self._get_text_examples(examples, i)
                text_prompt = self._format_prompt(prompt, i)
                prompts.append(self.prepare_labels_prompt(labels_, text_prompt))
                examples_str = self._format_examples_for_input(text_examples) if text_examples else ""
                processed_texts.append(texts[i] + examples_str)

        tokenized_inputs = self.tokenizer(
            processed_texts, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
        ).to(self.device)

        tokenized_classes = self.tokenizer(
            prompts, max_length=self.max_length, truncation=True, padding="longest", return_tensors="pt"
        ).to(self.device)

        tokenized_inputs["class_input_ids"] = tokenized_classes["input_ids"]
        tokenized_inputs["class_attention_mask"] = tokenized_classes["attention_mask"]

        return tokenized_inputs


class BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
    def __init__(
        self,
        model,
        tokenizer,
        max_classes=25,
        max_length=1024,
        classification_type="multi-label",
        device="cuda:0",
        progress_bar=True,
        label_separator: str = ".",
    ):
        super().__init__(
            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
        )
        self.labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)

    def prepare_input(self, text, labels, examples=None, prompt=None):
        input_parts = []

        if prompt:
            input_parts.append(prompt)
            input_parts.append(" ")

        for _label in labels:
            input_parts.append(self.label_token)
        input_parts.append(self.sep_token)

        examples_str = ""
        if examples:
            examples_str = self._format_examples_for_input(examples)

        if self.model.config.prompt_first:
            return "".join(input_parts) + text + examples_str
        else:
            return text + "".join(input_parts) + examples_str

    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
        if self.model.config.architecture_type == "bi-encoder-fused":
            inputs = []
            if same_labels:
                for i, text in enumerate(texts):
                    text_examples = self._get_text_examples(examples, i)
                    text_prompt = self._format_prompt(prompt, i)
                    inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))
            else:
                for i, (text, labels_) in enumerate(zip(texts, labels)):
                    text_examples = self._get_text_examples(examples, i)
                    text_prompt = self._format_prompt(prompt, i)
                    inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))
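        else:
            # Non-fused bi-encoders encode the text on its own (optionally prefixed by the prompt);
            # labels are tokenized separately below and few-shot examples are not used.
            inputs = []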
            for i, text in enumerate(texts):
                text_prompt = self._format_prompt(prompt, i)
                if text_prompt:
                    inputs.append(f"{text_prompt} {text}")
                else:
                    inputs.append(text)

        if same_labels:
            tokenized_inputs = self.tokenizer(
                inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
            ).to(self.device)

            tokenized_labels = self.labels_tokenizer(
                labels, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
            ).to(self.device)

            tokenized_inputs["class_input_ids"] = tokenized_labels["input_ids"].expand(len(texts), -1, -1)
            tokenized_inputs["class_attention_mask"] = tokenized_labels["attention_mask"].expand(len(texts), -1, -1)

            labels_mask = [[1 for _ in range(len(labels))] for _ in range(len(texts))]
            tokenized_inputs["labels_mask"] = torch.tensor(labels_mask).to(self.device)
        else:
            tokenized_inputs = self.tokenizer(
                inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
            ).to(self.device)

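            class_input_ids = []
            class_attention_mask = []
            max_labels = max(len(labels_set) for labels_set in labels)

            for labels_set in labels:
                # Pad ragged label lists to max_labels by repeating the last label so the per-text
                # tensors can be stacked; padded slots are zeroed in labels_mask and dropped by __call__.
                padded_set = list(labels_set) + [labels_set[-1]] * (max_labels - len(labels_set))
                # padding="max_length" gives every text's labels the same token length for stacking.
                tokenized_labels = self.labels_tokenizer(
                    padded_set, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt"
                ).to(self.device)
                class_input_ids.append(tokenized_labels["input_ids"])
                class_attention_mask.append(tokenized_labels["attention_mask"])

            tokenized_inputs["class_input_ids"] = torch.stack(class_input_ids)
            tokenized_inputs["class_attention_mask"] = torch.stack(class_attention_mask)

            labels_mask = [[1] * len(labels_set) + [0] * (max_labels - len(labels_set)) for labels_set in labels]
            tokenized_inputs["labels_mask"] = torch.tensor(labels_mask).to(self.device)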
        return tokenized_inputs


class ZeroShotClassificationPipeline:
    """
    Main pipeline class for zero-shot classification with GLiClass models.

    Supports:
    - Hierarchical labels with dot notation (e.g., {"sentiment": ["positive", "negative"]})
    - Few-shot examples with <<EXAMPLE>> token
    - Task description prompts
    - Hierarchical output format matching input structure

    Example usage:

    ```python
    from gliclass import ZeroShotClassificationPipeline

    pipeline = ZeroShotClassificationPipeline(model, tokenizer)

    # === Hierarchical Labels for Review Classification ===
    hierarchical_labels = {
        "sentiment": ["positive", "negative", "neutral"],
        "topic": ["product", "service", "shipping"],
    }

    # Basic classification (returns one list of predictions per input text)
    results = pipeline("The product quality is amazing but delivery was slow", hierarchical_labels)
    # Illustrative output:
    # [[
    #     {'label': 'sentiment.positive', 'score': 0.89},
    #     {'label': 'topic.product', 'score': 0.92},
    #     {'label': 'topic.shipping', 'score': 0.76}
    # ]]

    # === With Task Description Prompt ===
    results = pipeline(
        "The product quality is amazing but delivery was slow",
        hierarchical_labels,
        prompt="Classify this customer review by sentiment and topic:",
    )

    # === With Few-Shot Examples (uses <<EXAMPLE>> token) ===
    examples = [
        {"text": "Love this item, great quality!", "labels": ["sentiment.positive", "topic.product"]},
        {"text": "Customer support was unhelpful and rude", "labels": ["sentiment.negative", "topic.service"]},
        {"text": "Package arrived damaged after 2 weeks", "labels": ["sentiment.negative", "topic.shipping"]},
    ]

    results = pipeline(
        "Fast delivery and the item works perfectly!",
        hierarchical_labels,
        examples=examples,
        prompt="Classify customer feedback:",
    )

    # === Hierarchical Output (matches input structure) ===
    results = pipeline(
        "The product quality is amazing but delivery was slow", hierarchical_labels, return_hierarchical=True
    )
    # Returns one nested dict per input text (illustrative scores):
    # [
    #     {
    #         "sentiment": {
    #             "positive": 0.89,
    #             "negative": 0.05,
    #             "neutral": 0.12
    #         },
    #         "topic": {
    #             "product": 0.92,
    #             "service": 0.15,
    #             "shipping": 0.76
    #         }
    #     }
    # ]

    # === Per-Text Prompts ===
    results = pipeline(
        ["Electronics review text", "Clothing review text"],
        hierarchical_labels,
        prompt=["Analyze this electronics review:", "Analyze this clothing review:"],
    )
    ```
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_classes: int = 25,
        max_length: int = 1024,
        classification_type: str = "multi-label",
        device: str = "cuda:0",
        progress_bar: bool = True,
        label_separator: str = ".",
    ):
        """
        Initialize the classification pipeline.

        Args:
            model: GLiClass model or path to model
            tokenizer: Tokenizer or path to tokenizer
            max_classes: Maximum number of classes to process
            max_length: Maximum sequence length
            classification_type: 'single-label' or 'multi-label'
            device: Device to run inference on; falls back to CPU when CUDA is unavailable
                or a non-CUDA device string is given
            progress_bar: Whether to show progress bar
            label_separator: Separator for hierarchical label notation (default: ".")
        """
        if isinstance(model, str):
            model = GLiClassModel.from_pretrained(model)

        self.label_separator = label_separator

        if model.config.architecture_type == "uni-encoder":
            self.pipe = UniEncoderZeroShotClassificationPipeline(
                model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
            )
        elif model.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
            self.pipe = EncoderDecoderZeroShotClassificationPipeline(
                model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
            )
        elif model.config.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
            self.pipe = BiEncoderZeroShotClassificationPipeline(
                model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
            )
        else:
            raise NotImplementedError(f"Architecture type '{model.config.architecture_type}' is not implemented")

    def flatten_labels(self, labels: List[str] | Dict[str, Any]) -> List[str]:
        """
        Flatten hierarchical labels to dot notation.

        Example:
            >>> pipeline.flatten_labels(
            ...     {"sentiment": ["positive", "negative", "neutral"], "topic": ["product", "service", "shipping"]}
            ... )
            ["sentiment.positive", "sentiment.negative", "sentiment.neutral",
             "topic.product", "topic.service", "topic.shipping"]
        """
        return flatten_hierarchical_labels(labels, separator=self.label_separator)

    def get_embeddings(self, *args, **kwargs):
        """Get embeddings for texts and labels."""
        return self.pipe.get_embeddings(*args, **kwargs)

    def __call__(
        self,
        texts: str | List[str],
        labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]],
        threshold: float | List[float] = 0.5,
        batch_size: int = 8,
        classification_type: str | List[str] | None = None,
        rac_examples: List | None = None,
        examples: List[Dict[str, Any]] | None = None,
        prompt: str | List[str] | None = None,
        return_hierarchical: bool = False,
    ):
        """
        Perform zero-shot classification.

        Args:
            texts: Single text or list of texts to classify
            labels: Labels - flat list or hierarchical dict
                Examples:
                - ["positive", "negative"] - flat labels
                - {"sentiment": ["positive", "negative"], "topic": ["product", "service"]}
            threshold: Classification threshold for multi-label, either one
                value for all texts or one value per text
            batch_size: Batch size for processing
            classification_type: Override classification mode globally or per text.
                If None, uses the pipeline's configured classification_type
            rac_examples: Retrieval augmented examples (legacy)
            examples: Few-shot examples, each with 'text' and 'labels' keys
            prompt: Task description - string or list of strings (per-text)
            return_hierarchical: If True, return structure matching input labels

        Returns:
            List of predictions (flat) or hierarchical dicts with all scores
        """
        return self.pipe(
            texts,
            labels,
            threshold=threshold,
            batch_size=batch_size,
            classification_type=classification_type,
            rac_examples=rac_examples,
            examples=examples,
            prompt=prompt,
            return_hierarchical=return_hierarchical,
        )


class ZeroShotClassificationWithChunkingPipeline(BaseZeroShotClassificationPipeline):
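    """
    Pipeline with long-text support: splits each text into overlapping character
    chunks and the labels into groups of `labels_chunk_size`, then keeps each
    label's maximum score across chunks.
    """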

    def __init__(
        self,
        model,
        tokenizer,
        max_classes: int = 25,
        max_length: int = 1024,
        classification_type: str = "multi-label",
        device: str = "cuda:0",
        progress_bar: bool = True,
        text_chunk_size: int = 8192,
        text_chunk_overlap: int = 256,
        labels_chunk_size: int = 8,
        label_separator: str = ".",
    ):
        if isinstance(model, str):
            model = GLiClassModel.from_pretrained(model)
        super().__init__(
            model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
        )

        self.text_chunk_size = text_chunk_size
        self.text_chunk_overlap = text_chunk_overlap
        self.labels_chunk_size = labels_chunk_size

    def chunk_text(self, text, chunk_size=None, overlap=None):
        """Split text into character-level chunks that overlap by `overlap` characters."""
        if chunk_size is None:
            chunk_size = self.text_chunk_size
        if overlap is None:
            overlap = self.text_chunk_overlap

        if len(text) <= chunk_size:
            return [text]

        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            chunk = text[start:end]
            chunks.append(chunk)

            if end >= len(text):
                break

            start = end - overlap

        return chunks

    def prepare_input(self, text, labels, examples=None, prompt=None):
        """
        Prepare input matching the training format from data_processing.py.

        With `prompt_first`: Labels → SEP → Prompt → Text → Examples.
        Otherwise: Text → Labels → SEP → Prompt → Examples.
        """
        input_parts = []

        # 1. Add labels
        for label in labels:
            label_tag = f"{self.label_token}{label}"
            input_parts.append(label_tag)
        input_parts.append(self.sep_token)

        # 2. Add task description prompt
        if prompt:
            input_parts.append(prompt)

        # 3. Format examples to go after text
        examples_str = ""
        if examples:
            examples_str = self._format_examples_for_input(examples)

        if self.model.config.prompt_first:
            return "".join(input_parts) + text + examples_str
        else:
            return text + "".join(input_parts) + examples_str

    def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
        inputs = []

        if same_labels:
            for i, text in enumerate(texts):
                text_examples = self._get_text_examples(examples, i)
                text_prompt = self._format_prompt(prompt, i)
                inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))
        else:
            for i, (text, labels_) in enumerate(zip(texts, labels)):
                text_examples = self._get_text_examples(examples, i)
                text_prompt = self._format_prompt(prompt, i)
                inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))

        tokenized_inputs = self.tokenizer(
            inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
        ).to(self.device)
        return tokenized_inputs

    def aggregate_chunk_scores(self, chunk_scores: List[Dict[str, float]], labels: List[str]) -> Dict[str, float]:
        """Aggregate scores across text chunks using max pooling."""
        aggregated = dict.fromkeys(labels, 0.0)

        for scores in chunk_scores:
            for label, score in scores.items():
                aggregated[label] = max(aggregated[label], score)

        return aggregated

    @torch.no_grad()
    def process_single_text(self, text, labels, threshold=0.5, examples=None, prompt=None):
        """Score one long text by running every text chunk against every label group."""
        text_chunks = self.chunk_text(text)

        all_chunk_scores = []

        for text_chunk in text_chunks:
            chunk_logits = []
            all_labels = []

            for labels_idx in range(0, len(labels), self.labels_chunk_size):
                curr_labels = labels[labels_idx : labels_idx + self.labels_chunk_size]
                all_labels.extend(curr_labels)

                tokenized_inputs = self.prepare_inputs(
                    [text_chunk], curr_labels, same_labels=True, examples=examples, prompt=prompt
                )
                max_num_classes = self._resolve_max_num_classes(curr_labels, same_labels=True)
                model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)
                logits = model_output.logits

                chunk_logits.extend(logits[0][: len(curr_labels)].tolist())

            text_logits = torch.tensor(chunk_logits)

            if self.classification_type == "single-label":
                scores = torch.softmax(text_logits, dim=-1)
            else:
                scores = torch.sigmoid(text_logits)

            chunk_score_dict = {label: scores[i].item() for i, label in enumerate(all_labels)}
            all_chunk_scores.append(chunk_score_dict)

        aggregated_scores = self.aggregate_chunk_scores(all_chunk_scores, labels)

        if self.classification_type == "single-label":
            total = sum(aggregated_scores.values())
            if total > 0:
                aggregated_scores = {k: v / total for k, v in aggregated_scores.i
SYMBOL INDEX (409 symbols across 25 files)

FILE: demo.py
  function parse_labels_input (line 204) | def parse_labels_input(labels_input: str) -> Union[List[str], Dict[str, ...
  function parse_examples_input (line 226) | def parse_examples_input(examples_input: str) -> Optional[List[Dict[str,...
  function format_output (line 257) | def format_output(
  function format_as_json (line 277) | def format_as_json(results: Union[List[Dict], Dict], hierarchical: bool ...
  function format_hierarchical_dict (line 297) | def format_hierarchical_dict(d: Dict, indent: int = 0) -> str:
  function classification (line 313) | def classification(
  function update_output_visibility (line 821) | def update_output_visibility(hierarchical: bool, fmt: str):
  function classify_wrapper (line 842) | def classify_wrapper(text, labels, threshold, multi_label, prompt, examp...

FILE: gliclass/config.py
  class GLiClassModelConfig (line 19) | class GLiClassModelConfig(PretrainedConfig):
    method __init__ (line 23) | def __init__(

FILE: gliclass/data_processing.py
  class AugmentationConfig (line 11) | class AugmentationConfig:
  class DataAugmenter (line 26) | class DataAugmenter:
    method __init__ (line 27) | def __init__(self, config, examples, labels, label2description=None):
    method remove_labels (line 34) | def remove_labels(self, true_labels, all_labels):
    method add_random_labels (line 42) | def add_random_labels(self, all_labels):
    method add_random_text (line 51) | def add_random_text(self, text, all_labels):
    method add_random_synonyms (line 66) | def add_random_synonyms(self, all_labels):
    method add_random_descriptions (line 86) | def add_random_descriptions(self, item):
    method add_random_examples (line 111) | def add_random_examples(self, item):
    method augment (line 145) | def augment(self, item):
  class GLiClassDataset (line 184) | class GLiClassDataset(Dataset):
    method __init__ (line 185) | def __init__(
    method get_diversity (line 222) | def get_diversity(self):
    method collect_dataset_labels (line 225) | def collect_dataset_labels(self):
    method prepare_labels (line 231) | def prepare_labels(self, example, label2idx, problem_type):
    method prepare_prompt (line 243) | def prepare_prompt(self, item, label_token_first=True):
    method format_examples (line 256) | def format_examples(self, item):
    method tokenize (line 270) | def tokenize(self, texts):
    method tokenize_labels (line 274) | def tokenize_labels(self, labels):
    method tokenize_and_prepare_labels_for_uniencoder (line 278) | def tokenize_and_prepare_labels_for_uniencoder(self, example):
    method tokenize_and_prepare_labels_for_encoder_decoder (line 295) | def tokenize_and_prepare_labels_for_encoder_decoder(self, example):
    method tokenize_and_prepare_labels_for_biencoder (line 312) | def tokenize_and_prepare_labels_for_biencoder(self, example):
    method __len__ (line 346) | def __len__(self):
    method __getitem__ (line 349) | def __getitem__(self, idx):
  function pad_2d_tensor (line 365) | def pad_2d_tensor(key_data):
  class DataCollatorWithPadding (line 395) | class DataCollatorWithPadding:
    method __init__ (line 396) | def __init__(self, device="cuda:0", config=None):
    method _resolve_max_num_classes (line 400) | def _resolve_max_num_classes(self, batch):
    method __call__ (line 415) | def __call__(self, batch):

FILE: gliclass/layers.py
  class LstmSeq2SeqEncoder (line 23) | class LstmSeq2SeqEncoder(nn.Module):
    method __init__ (line 24) | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0,...
    method forward (line 35) | def forward(self, x, mask, hidden=None):
  class FeaturesProjector (line 49) | class FeaturesProjector(nn.Module):
    method __init__ (line 50) | def __init__(self, config: GLiClassModelConfig):
    method forward (line 58) | def forward(self, features):
  class BiEncoderProjector (line 66) | class BiEncoderProjector(nn.Module):
    method __init__ (line 67) | def __init__(self, config: GLiClassModelConfig):
    method forward (line 74) | def forward(self, features):
  class DropoutContext (line 82) | class DropoutContext:
    method __init__ (line 83) | def __init__(self):
  function get_mask (line 91) | def get_mask(input, local_context):
  class XDropout (line 110) | class XDropout(torch.autograd.Function):
    method forward (line 114) | def forward(ctx, input, local_ctx):
    method backward (line 124) | def backward(ctx, grad_output):
    method symbolic (line 132) | def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: floa...
  class StableDropout (line 150) | class StableDropout(nn.Module):
    method __init__ (line 158) | def __init__(self, drop_prob):
    method forward (line 164) | def forward(self, x):
    method clear_context (line 175) | def clear_context(self):
    method init_context (line 179) | def init_context(self, reuse_mask=True, scale=1):
    method get_context (line 187) | def get_context(self):
  class SelfAttentionBlock (line 199) | class SelfAttentionBlock(nn.Module):
    method __init__ (line 200) | def __init__(self, d_model, num_heads, dropout=0.1):
    method forward (line 206) | def forward(self, x, mask=None):
  class CrossAttentionBlock (line 211) | class CrossAttentionBlock(nn.Module):
    method __init__ (line 212) | def __init__(self, d_model, num_heads, dropout=0.1):
    method forward (line 218) | def forward(self, query, key, value, mask=None):
  class Fuser (line 223) | class Fuser(nn.Module):
    method __init__ (line 224) | def __init__(self, d_model, num_heads, num_layers, dropout=0.1):
    method forward (line 237) | def forward(self, query, key, query_mask=None, key_mask=None):
  class LayerwiseAttention (line 254) | class LayerwiseAttention(nn.Module):
    method __init__ (line 255) | def __init__(self, num_layers, hidden_size, output_size=None):
    method forward (line 271) | def forward(self, encoder_outputs):

FILE: gliclass/loss_functions.py
  function sequence_contrastive_loss (line 5) | def sequence_contrastive_loss(embeddings, mask):
  function focal_loss_with_logits (line 31) | def focal_loss_with_logits(

FILE: gliclass/model.py
  class GLiClassOutput (line 60) | class GLiClassOutput(SequenceClassifierOutput):
  class GLiClassPreTrainedModel (line 65) | class GLiClassPreTrainedModel(PreTrainedModel):
    method _initialize_weights (line 72) | def _initialize_weights(self, module, is_remote_code: bool = False):
    method _init_weights (line 88) | def _init_weights(self, module):
  class GLiClassBaseModel (line 117) | class GLiClassBaseModel(nn.Module):  # ):
    method __init__ (line 118) | def __init__(self, config: GLiClassModelConfig, device="cpu", **kwargs):
    method _extract_class_features (line 162) | def _extract_class_features(self, token_embeds, input_ids, attention_m...
    method _extract_class_features_first (line 224) | def _extract_class_features_first(
    method _extract_class_features_averaged (line 255) | def _extract_class_features_averaged(
    method get_loss (line 296) | def get_loss(self, logits, labels, classes_embedding=None, classes_emb...
  class GLiClassUniEncoder (line 347) | class GLiClassUniEncoder(GLiClassBaseModel):
    method __init__ (line 348) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
    method _create_segment_ids (line 419) | def _create_segment_ids(self, input_ids):
    method process_encoder_output (line 444) | def process_encoder_output(self, input_ids, attention_mask, encoder_la...
    method forward (line 469) | def forward(
  class GLiClassEncoderDecoder (line 556) | class GLiClassEncoderDecoder(GLiClassBaseModel):
    method __init__ (line 557) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
    method _make_bidirectional_4d_mask (line 573) | def _make_bidirectional_4d_mask(attention_mask_2d, dtype):
    method forward (line 593) | def forward(
  class GLiClassEncoderDecoderCLS (line 674) | class GLiClassEncoderDecoderCLS(GLiClassBaseModel):
    method __init__ (line 681) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
    method forward (line 696) | def forward(
  class GLiClassBiEncoder (line 767) | class GLiClassBiEncoder(GLiClassBaseModel):
    method __init__ (line 768) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
    method pool_outputs (line 790) | def pool_outputs(self, encoder_outputs):
    method encode_text (line 798) | def encode_text(self, input_ids, attention_mask):
    method encode_classes (line 803) | def encode_classes(self, class_input_ids, class_attention_mask, labels...
    method forward (line 835) | def forward(
  class GLiClassBiEncoderFused (line 871) | class GLiClassBiEncoderFused(GLiClassBiEncoder):
    method __init__ (line 872) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
    method encode_text (line 875) | def encode_text(self, input_ids, attention_mask, class_embeddings, lab...
    method forward (line 895) | def forward(
  class GLiClassModel (line 937) | class GLiClassModel(GLiClassPreTrainedModel):
    method __init__ (line 938) | def __init__(self, config, from_pretrained=False):
    method get_input_embeddings (line 952) | def get_input_embeddings(self):
    method set_input_embeddings (line 960) | def set_input_embeddings(self, value):
    method tie_weights (line 971) | def tie_weights(self, recompute_mapping=True, missing_keys=None):
    method resize_token_embeddings (line 1025) | def resize_token_embeddings(self, new_num_tokens: int | None = None, p...
    method forward (line 1039) | def forward(self, *args, **kwargs):

FILE: gliclass/ops.py
  function attn_padded (line 7) | def attn_padded(

FILE: gliclass/pipeline.py
  function flatten_hierarchical_labels (line 12) | def flatten_hierarchical_labels(
  function build_hierarchical_output (line 69) | def build_hierarchical_output(
  function format_examples_prompt (line 135) | def format_examples_prompt(
  class BaseZeroShotClassificationPipeline (line 174) | class BaseZeroShotClassificationPipeline(ABC):
    method __init__ (line 175) | def __init__(
    method _normalize_classification_type (line 216) | def _normalize_classification_type(self, classification_type: str | No...
    method _normalize_texts (line 227) | def _normalize_texts(self, texts: str | List[str]) -> List[str]:
    method _normalize_thresholds (line 232) | def _normalize_thresholds(self, threshold: float | List[float], num_te...
    method _normalize_classification_types (line 239) | def _normalize_classification_types(
    method _process_labels (line 252) | def _process_labels(
    method _format_examples_for_input (line 281) | def _format_examples_for_input(self, examples: List[Dict[str, Any]] | ...
    method _examples_are_per_text (line 290) | def _examples_are_per_text(self, examples) -> bool:
    method _get_text_examples (line 298) | def _get_text_examples(self, examples, index: int):
    method _format_prompt (line 306) | def _format_prompt(self, prompt: str | List[str] | None = None, index:...
    method _resolve_max_num_classes (line 321) | def _resolve_max_num_classes(self, batch_labels, same_labels: bool):
    method prepare_inputs (line 329) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
    method _get_batch_examples (line 332) | def _get_batch_examples(self, examples, start_idx, batch_size):
    method _get_batch_prompt (line 340) | def _get_batch_prompt(self, prompt, start_idx, batch_size):
    method get_embeddings (line 349) | def get_embeddings(self, texts, labels, batch_size=8, examples=None, p...
    method __call__ (line 397) | def __call__(
  class UniEncoderZeroShotClassificationPipeline (line 522) | class UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificatio...
    method __init__ (line 523) | def __init__(
    method prepare_input (line 538) | def prepare_input(self, text, labels, examples=None, prompt=None):
    method prepare_inputs (line 565) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
  class EncoderDecoderZeroShotClassificationPipeline (line 586) | class EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassific...
    method __init__ (line 587) | def __init__(
    method prepare_labels_prompt (line 602) | def prepare_labels_prompt(self, labels, prompt=None):
    method prepare_inputs (line 617) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
  class BiEncoderZeroShotClassificationPipeline (line 650) | class BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassification...
    method __init__ (line 651) | def __init__(
    method prepare_input (line 667) | def prepare_input(self, text, labels, examples=None, prompt=None):
    method prepare_inputs (line 687) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
  class ZeroShotClassificationPipeline (line 746) | class ZeroShotClassificationPipeline:
    method __init__ (line 825) | def __init__(
    method flatten_labels (line 869) | def flatten_labels(self, labels: List[str] | Dict[str, Any]) -> List[s...
    method get_embeddings (line 882) | def get_embeddings(self, *args, **kwargs):
    method __call__ (line 886) | def __call__(
  class ZeroShotClassificationWithChunkingPipeline (line 933) | class ZeroShotClassificationWithChunkingPipeline(BaseZeroShotClassificat...
    method __init__ (line 936) | def __init__(
    method chunk_text (line 960) | def chunk_text(self, text, chunk_size=None, overlap=None):
    method prepare_input (line 984) | def prepare_input(self, text, labels, examples=None, prompt=None):
    method prepare_inputs (line 1011) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
    method aggregate_chunk_scores (line 1030) | def aggregate_chunk_scores(self, chunk_scores: List[Dict[str, float]],...
    method process_single_text (line 1041) | def process_single_text(self, text, labels, threshold=0.5, examples=No...
    method __call__ (line 1095) | def __call__(
  function parse_hierarchical_prediction (line 1221) | def parse_hierarchical_prediction(prediction: str, separator: str = ".")...
  function group_predictions_by_hierarchy (line 1230) | def group_predictions_by_hierarchy(
  function get_best_per_category (line 1250) | def get_best_per_category(predictions: List[Dict[str, Any]], separator: ...
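
The `flatten_hierarchical_labels` and `build_hierarchical_output` entries above are listed by name only, but the `ZeroShotClassificationPipeline` docstrings in the preview show the intended round trip: dict labels flatten to dot notation (`"sentiment.positive"`), and with `return_hierarchical=True` the flat scores come back nested under the original keys. A minimal sketch of that round trip, assuming a one-level `{category: [labels]}` dict and the default `"."` separator; `flatten_sketch` and `regroup_sketch` are hypothetical helpers, not the library's functions:

```python
from typing import Dict, List


def flatten_sketch(labels: Dict[str, List[str]], separator: str = ".") -> List[str]:
    """Turn {"sentiment": ["positive"]} into ["sentiment.positive"] (hypothetical helper)."""
    return [f"{category}{separator}{label}" for category, names in labels.items() for label in names]


def regroup_sketch(scores: Dict[str, float], separator: str = ".") -> Dict[str, Dict[str, float]]:
    """Nest flat dot-notation scores back under their top-level category (hypothetical helper)."""
    nested: Dict[str, Dict[str, float]] = {}
    for flat_label, score in scores.items():
        # Split only on the first separator so leaf names may themselves contain it.
        category, leaf = flat_label.split(separator, 1)
        nested.setdefault(category, {})[leaf] = score
    return nested


labels = {"sentiment": ["positive", "negative"], "topic": ["product", "shipping"]}
flat = flatten_sketch(labels)
# flat == ["sentiment.positive", "sentiment.negative", "topic.product", "topic.shipping"]
nested = regroup_sketch(dict(zip(flat, [0.89, 0.05, 0.92, 0.76])))
# nested == {"sentiment": {"positive": 0.89, "negative": 0.05},
#            "topic": {"product": 0.92, "shipping": 0.76}}
```

Deeper nesting (dicts of dicts) would need a recursive version; the one-level case matches the docstring examples.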

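`chunk_text`, `aggregate_chunk_scores`, and `process_single_text` (shown in the preview above) operate on characters, not tokens: a text longer than `text_chunk_size` characters is cut into windows that overlap by `text_chunk_overlap` characters, each window is scored, and each label keeps its highest score across windows, with a final renormalization in single-label mode. A framework-free sketch of that flow, where the hypothetical `score_fn` stands in for the model call and the per-window label batching is omitted:

```python
from typing import Callable, Dict, List


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Character-level sliding window, mirroring the pipeline's chunk_text."""
    if len(text) <= chunk_size:
        return [text]
    chunks: List[str] = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            break
        # Step back by `overlap` so consecutive windows share context.
        start = end - overlap
    return chunks


def classify_long_text(
    text: str,
    labels: List[str],
    score_fn: Callable[[str, List[str]], Dict[str, float]],  # hypothetical model stand-in
    chunk_size: int = 8192,
    overlap: int = 256,
    single_label: bool = False,
) -> Dict[str, float]:
    aggregated = dict.fromkeys(labels, 0.0)
    for chunk in chunk_text(text, chunk_size, overlap):
        for label, score in score_fn(chunk, labels).items():
            # Max pooling across chunks: keep each label's strongest evidence.
            aggregated[label] = max(aggregated[label], score)
    if single_label:
        # Max-pooled softmax scores no longer sum to 1, so renormalize.
        total = sum(aggregated.values())
        if total > 0:
            aggregated = {k: v / total for k, v in aggregated.items()}
    return aggregated


def toy_scorer(chunk: str, labels: List[str]) -> Dict[str, float]:
    # Toy stand-in: a label scores high if its string appears in the chunk.
    return {label: (1.0 if label in chunk else 0.1) for label in labels}


print(chunk_text("abcdefghij", chunk_size=4, overlap=1))  # ['abcd', 'defg', 'ghij']
print(classify_long_text("abcdefghij", ["a", "j", "z"], toy_scorer, chunk_size=4, overlap=1))
# {'a': 1.0, 'j': 1.0, 'z': 0.1}
```

As in the pipeline code, the window only advances while `overlap < chunk_size`; with a larger overlap, `start` never moves forward and the loop does not terminate.
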
FILE: gliclass/poolings.py
  class GlobalMaxPooling1D (line 5) | class GlobalMaxPooling1D(nn.Module):
    method forward (line 8) | def forward(self, x: torch.Tensor):
  class FirstTokenPooling1D (line 12) | class FirstTokenPooling1D(nn.Module):
    method forward (line 15) | def forward(self, x: torch.Tensor):
  class LastTokenPooling1D (line 19) | class LastTokenPooling1D(nn.Module):
    method forward (line 22) | def forward(self, x: torch.Tensor):
  class GlobalAvgPooling1D (line 26) | class GlobalAvgPooling1D(nn.Module):
    method forward (line 29) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
  class GlobalSumPooling1D (line 38) | class GlobalSumPooling1D(nn.Module):
    method forward (line 41) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
  class GlobalRMSPooling1D (line 47) | class GlobalRMSPooling1D(nn.Module):
    method forward (line 50) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
  class GlobalAbsMaxPooling1D (line 59) | class GlobalAbsMaxPooling1D(nn.Module):
    method forward (line 62) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
  class GlobalAbsAvgPooling1D (line 69) | class GlobalAbsAvgPooling1D(nn.Module):
    method forward (line 72) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
  class PassPooling1D (line 81) | class PassPooling1D(nn.Module):
    method forward (line 84) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...

FILE: gliclass/scorers.py
  class ScorerWeightedDot (line 7) | class ScorerWeightedDot(nn.Module):
    method __init__ (line 8) | def __init__(self, hidden_size, dropout=0.1, **kwargs):
    method forward (line 21) | def forward(self, text_rep, label_rep, **kwargs):
  class ScorerDot (line 42) | class ScorerDot(nn.Module):
    method __init__ (line 43) | def __init__(self, *args, **kwargs):
    method forward (line 47) | def forward(self, text_rep, label_rep, **kwargs):
  class MLPScorer (line 53) | class MLPScorer(nn.Module):
    method __init__ (line 54) | def __init__(self, hidden_size, mlp_hidden_size=256, **kwargs):
    method forward (line 69) | def forward(self, text_rep, label_rep, **kwargs):
  class HopfieldScorer (line 81) | class HopfieldScorer(nn.Module):
    method __init__ (line 82) | def __init__(self, hidden_size, mlp_hidden_size=256, beta=4, num_itera...
    method forward (line 101) | def forward(self, text_rep, label_rep, **kwargs):
  class CrossAttnScorer (line 128) | class CrossAttnScorer(nn.Module):
    method __init__ (line 129) | def __init__(self, hidden_size, num_heads=16, attn_dropout=0.1, scorer...
    method forward (line 154) | def forward(self, text_rep, label_rep, text_mask=None, **kwargs):

FILE: gliclass/serve/__main__.py
  function main (line 21) | def main():

FILE: gliclass/serve/client.py
  class GLiClassClient (line 6) | class GLiClassClient:
    method __init__ (line 9) | def __init__(self, url: str = "http://localhost:8000/gliclass"):
    method __call__ (line 17) | def __call__(
    method classify (line 55) | def classify(
    method health_check (line 80) | def health_check(self) -> bool:

FILE: gliclass/serve/config.py
  class GLiClassServeConfig (line 10) | class GLiClassServeConfig:
    method __post_init__ (line 63) | def __post_init__(self):
    method to_env_vars (line 68) | def to_env_vars(self) -> dict:
    method from_yaml (line 76) | def from_yaml(cls, config_path: str | Path) -> "GLiClassServeConfig":
    method to_yaml (line 90) | def to_yaml(self, config_path: str | Path) -> None:
    method update (line 101) | def update(self, **kwargs) -> "GLiClassServeConfig":

FILE: gliclass/serve/memory.py
  function _power_of_two_seq_lens (line 21) | def _power_of_two_seq_lens(max_seq_len: int, min_seq_len: int = 64) -> L...
  class GLiClassMemoryEstimator (line 32) | class GLiClassMemoryEstimator:
    method __init__ (line 35) | def __init__(
    method measure_cuda_context (line 51) | def measure_cuda_context(self) -> None:
    method measure_model_memory (line 61) | def measure_model_memory(self) -> None:
    method available_memory (line 73) | def available_memory(self) -> int:
    method calibrate (line 80) | def calibrate(
    method _measure_peak (line 106) | def _measure_peak(
    method _lookup_seq_len (line 124) | def _lookup_seq_len(self, seq_len: int) -> int:
    method per_sample_at (line 133) | def per_sample_at(self, seq_len: int) -> int:
    method batch_size_fn (line 138) | def batch_size_fn(

FILE: gliclass/serve/server.py
  class GLiClassServer (line 20) | class GLiClassServer:
    method __init__ (line 23) | def __init__(self, config: GLiClassServeConfig):
    method _precompile (line 92) | def _precompile(self) -> None:
    method _calibrate_memory (line 118) | def _calibrate_memory(self) -> None:
    method batch_size_fn (line 129) | def batch_size_fn(self, seq_len: int | None = None) -> int:
    method observed_seq_len (line 149) | def observed_seq_len(
    method _filter_labels (line 176) | def _filter_labels(self, labels: list[str]) -> list[str]:
    method _run_batch_internal (line 183) | def _run_batch_internal(
    method predict (line 220) | def predict(
  function _build_deployment (line 249) | def _build_deployment(config: GLiClassServeConfig):
  function serve_gliclass (line 359) | def serve_gliclass(
  function shutdown (line 396) | def shutdown() -> None:
  class GLiClassFactory (line 400) | class GLiClassFactory:
    method __init__ (line 420) | def __init__(
    method handle (line 441) | def handle(self):
    method predict (line 445) | def predict(
    method predict_async (line 472) | async def predict_async(
    method shutdown (line 501) | def shutdown(self) -> None:
    method __enter__ (line 516) | def __enter__(self):
    method __exit__ (line 519) | def __exit__(self, exc_type, exc_val, exc_tb):
    method __del__ (line 523) | def __del__(self):

FILE: gliclass/training.py
  class EWC (line 32) | class EWC:
    method __init__ (line 35) | def __init__(
    method _compute_fisher (line 82) | def _compute_fisher(self, dataset: Dataset) -> Dict[str, torch.Tensor]:
    method _normalize_fisher (line 187) | def _normalize_fisher(self):
    method ewc_loss (line 199) | def ewc_loss(self, batch_size: int | None = None) -> torch.Tensor:
    method get_importance_scores (line 226) | def get_importance_scores(self) -> Dict[str, float]:
    method update_lambda (line 237) | def update_lambda(self, new_lambda: float):
    method consolidate (line 245) | def consolidate(self, dataset: Dataset, alpha: float = 0.5):
  class TrainingArguments (line 273) | class TrainingArguments(transformers.TrainingArguments):
  class Trainer (line 294) | class Trainer(transformers.Trainer):
    method __init__ (line 297) | def __init__(self, ewc: EWC | None = None, prev_dataset=None, *args, *...
    method _maybe_initialize_ewc (line 316) | def _maybe_initialize_ewc(self):
    method compute_loss (line 352) | def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    method train (line 381) | def train(self, *args, **kwargs):
    method training_step (line 387) | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
    method prediction_step (line 437) | def prediction_step(
    method create_optimizer (line 498) | def create_optimizer(self):
  class RLTrainerConfig (line 577) | class RLTrainerConfig(TrainingArguments):
  class RLTrainer (line 607) | class RLTrainer(Trainer):
    method __init__ (line 608) | def __init__(
    method _init_metrics (line 625) | def _init_metrics(self):
    method compute_rewards (line 634) | def compute_rewards(
    method get_reference_scores (line 646) | def get_reference_scores(self, input_texts, labels_text):
    method compute_loss (line 671) | def compute_loss(
    method _inner_training_loop (line 757) | def _inner_training_loop(self, *args, **kwargs):
    method log_metrics (line 845) | def log_metrics(self):
    method _save_checkpoint (line 857) | def _save_checkpoint(self, model, step=None):
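The `EWC` class in `gliclass/training.py` (with `_compute_fisher`, `ewc_loss`, and `consolidate`) points to Elastic Weight Consolidation for continual fine-tuning. The snippet below is a minimal sketch of the textbook diagonal EWC penalty, (λ/2) Σᵢ Fᵢ (θᵢ - θ*ᵢ)². It is not the repository's code. The helper name `ewc_penalty`, the dict-of-arrays layout, and the use of NumPy in place of torch tensors are all assumptions made for illustration.

```python
import numpy as np


def ewc_penalty(params, anchor_params, fisher, lam):
    """Textbook diagonal EWC penalty: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2.

    params, anchor_params, fisher: dicts mapping parameter names to arrays of equal shape.
    Hypothetical helper for illustration; not the gliclass EWC.ewc_loss API.
    """
    total = 0.0
    for name, theta in params.items():
        diff = theta - anchor_params[name]                 # drift from the consolidated weights
        total += float(np.sum(fisher[name] * diff ** 2))   # Fisher-weighted squared drift
    return 0.5 * lam * total


params = {"w": np.array([1.0, 2.0])}
anchor = {"w": np.array([0.0, 2.0])}
fisher = {"w": np.array([4.0, 1.0])}
print(ewc_penalty(params, anchor, fisher, lam=1.0))  # 2.0
```

In a training loop this term would be added to the task loss, so parameters with a large Fisher value (high estimated importance for earlier data) are pulled back toward their consolidated values more strongly.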

FILE: gliclass/utils.py
  function is_module_available (line 4) | def is_module_available(module_name):
  class MissedPackageException (line 21) | class MissedPackageException(Exception):
  function retrieval_augmented_text (line 27) | def retrieval_augmented_text(text: str, examples: list) -> str:
  function default_f1_reward (line 64) | def default_f1_reward(
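`default_f1_reward` in `gliclass/utils.py` and the reward functions in `train_rl.py` work with per-label predictions, targets, and a `valid_mask`. The snippet below is a hedged sketch of a per-sample F1 reward computed only over valid label slots. The name `f1_reward`, the 0/1 array inputs, and the choice to return 0 when a row has nothing to score are assumptions, not the library's signature.

```python
import numpy as np


def f1_reward(actions, targets, valid_mask):
    """Per-sample F1 between binary predictions and targets, ignoring padded label slots.

    actions, targets, valid_mask: float arrays of shape (batch, num_labels) with 0/1 entries.
    Returns an array of shape (batch,). Hypothetical sketch, not gliclass.utils.default_f1_reward.
    """
    actions = actions * valid_mask  # predictions on padded slots never count
    targets = targets * valid_mask  # neither do padded targets
    tp = (actions * targets).sum(axis=1)
    fp = (actions * (1 - targets)).sum(axis=1)
    fn = ((1 - actions) * targets).sum(axis=1)
    denom = 2 * tp + fp + fn
    # F1 = 2TP / (2TP + FP + FN); rows with an empty denominator get a reward of 0
    return np.where(denom > 0, 2 * tp / np.maximum(denom, 1e-12), 0.0)


actions = np.array([[1, 0, 1], [0, 1, 0]], dtype=float)
targets = np.array([[1, 0, 1], [1, 0, 0]], dtype=float)
print(f1_reward(actions, targets, np.ones_like(actions)))  # [1. 0.]
```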

FILE: test_gliclass.py
  class TestModel (line 14) | class TestModel:
    method __init__ (line 16) | def __init__(self, model, token):
    method load_model (line 28) | def load_model(self):
    method prepare_dataset (line 34) | def prepare_dataset(self, dataset, classes=None, text_column='text', l...
    method prepare_nomapping (line 55) | def prepare_nomapping(self, dataset, classes=None, text_column='text',...
    method get_gliclass_predictions (line 79) | def get_gliclass_predictions(self, test_texts, classes, batch_size=8):
    method evaluate (line 84) | def evaluate(self, predicts, true_labels):
    method process (line 90) | def process(self):

FILE: tests/test_data_processing.py
  class TestPad2DTensor (line 9) | class TestPad2DTensor:
    method sample_tensors (line 13) | def sample_tensors(self):
    method test_pads_to_maximum_dimensions (line 21) | def test_pads_to_maximum_dimensions(self, sample_tensors):
    method test_preserves_original_values (line 28) | def test_preserves_original_values(self, sample_tensors):
    method test_pads_with_zeros (line 48) | def test_pads_with_zeros(self, sample_tensors):
    method test_single_tensor (line 58) | def test_single_tensor(self):
    method test_uniform_size_tensors (line 67) | def test_uniform_size_tensors(self):
    method test_empty_tensor_handling (line 81) | def test_empty_tensor_handling(self):
    method test_preserves_dtype (line 94) | def test_preserves_dtype(self):
    method test_varying_row_counts (line 105) | def test_varying_row_counts(self):
    method test_varying_column_counts (line 118) | def test_varying_column_counts(self):
    method test_batch_consistency (line 131) | def test_batch_consistency(self):
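The `TestPad2DTensor` cases describe the expected behaviour of the 2-D padding helper in `gliclass/data_processing.py`: each tensor is zero-padded up to the largest row and column counts, original values and dtype are preserved, and the results form a consistent batch. A NumPy sketch of that behaviour follows. `pad_2d_batch` is a hypothetical name, and returning a single stacked 3-D array is an assumption.

```python
import numpy as np


def pad_2d_batch(arrays):
    """Zero-pad a list of 2-D arrays to the largest row and column counts, then stack.

    Returns an array of shape (len(arrays), max_rows, max_cols) with the dtype of the first input.
    Hypothetical NumPy sketch of the behaviour the tests describe.
    """
    max_rows = max(a.shape[0] for a in arrays)
    max_cols = max(a.shape[1] for a in arrays)
    out = np.zeros((len(arrays), max_rows, max_cols), dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        out[i, : a.shape[0], : a.shape[1]] = a  # copy original values into the top-left corner
    return out


batch = pad_2d_batch([np.ones((2, 3), dtype=np.int64), np.full((1, 4), 7, dtype=np.int64)])
print(batch.shape)  # (2, 2, 4)
```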

FILE: tests/test_loss_functions.py
  class TestSequenceContrastiveLoss (line 9) | class TestSequenceContrastiveLoss:
    method sample_embeddings (line 13) | def sample_embeddings(self):
    method sample_mask (line 21) | def sample_mask(self):
    method test_returns_scalar_loss (line 25) | def test_returns_scalar_loss(self, sample_embeddings, sample_mask):
    method test_loss_is_positive (line 32) | def test_loss_is_positive(self, sample_embeddings, sample_mask):
    method test_identical_sequences_low_loss (line 38) | def test_identical_sequences_low_loss(self):
    method test_handles_masked_positions (line 48) | def test_handles_masked_positions(self):
    method test_gradient_flows_through_loss (line 58) | def test_gradient_flows_through_loss(self, sample_embeddings, sample_m...
  class TestFocalLossWithLogits (line 68) | class TestFocalLossWithLogits:
    method sample_logits (line 72) | def sample_logits(self):
    method sample_targets (line 77) | def sample_targets(self):
    method test_returns_tensor_with_reduction_none (line 81) | def test_returns_tensor_with_reduction_none(self, sample_logits, sampl...
    method test_returns_scalar_with_reduction_mean (line 88) | def test_returns_scalar_with_reduction_mean(self, sample_logits, sampl...
    method test_loss_is_positive (line 94) | def test_loss_is_positive(self, sample_logits, sample_targets):
    method test_perfect_predictions_low_loss (line 100) | def test_perfect_predictions_low_loss(self):
    method test_wrong_predictions_high_loss (line 110) | def test_wrong_predictions_high_loss(self):
    method test_alpha_parameter_effect (line 120) | def test_alpha_parameter_effect(self, sample_logits, sample_targets):
    method test_gamma_parameter_effect (line 128) | def test_gamma_parameter_effect(self, sample_logits, sample_targets):
    method test_reduction_sum (line 137) | def test_reduction_sum(self, sample_logits, sample_targets):
    method test_reduction_none (line 143) | def test_reduction_none(self, sample_logits, sample_targets):
    method test_handles_extreme_logits (line 149) | def test_handles_extreme_logits(self):
    method test_gradient_flows_through_loss (line 159) | def test_gradient_flows_through_loss(self, sample_logits, sample_targe...
    method test_all_zeros_targets (line 168) | def test_all_zeros_targets(self):
    method test_all_ones_targets (line 178) | def test_all_ones_targets(self):
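`TestFocalLossWithLogits` exercises `alpha`, `gamma`, and the `none`/`mean`/`sum` reductions. For reference, here is a minimal NumPy sketch of the standard binary focal loss on logits, FL = -α_t (1 - p_t)^γ log p_t (Lin et al., 2017). The defaults `alpha=0.25, gamma=2.0` and the convention that a negative `alpha` disables class weighting follow torchvision's `sigmoid_focal_loss`. They are assumptions, not necessarily the defaults used in `gliclass/loss_functions.py`.

```python
import numpy as np


def focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    """Binary focal loss computed from raw logits in a numerically stable way (sketch)."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    # Stable binary cross-entropy with logits: max(x, 0) - x*y + log(1 + exp(-|x|))
    ce = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    p_t = np.exp(-ce)                      # ce = -log(p_t), so this is the true-class probability
    loss = ce * (1 - p_t) ** gamma         # down-weight easy, well-classified examples
    if alpha >= 0:                         # negative alpha disables class weighting
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss                            # reduction == "none": elementwise losses
```

Because `ce = -log p_t` exactly, recovering `p_t = exp(-ce)` avoids a separate sigmoid and stays finite for extreme logits. With `gamma=0` and `alpha=-1` the function reduces to plain binary cross-entropy.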

FILE: tests/test_poolings.py
  class TestGlobalMaxPooling1D (line 19) | class TestGlobalMaxPooling1D:
    method pooling_layer (line 23) | def pooling_layer(self):
    method sample_input (line 28) | def sample_input(self):
    method test_returns_max_across_sequence (line 32) | def test_returns_max_across_sequence(self, pooling_layer, sample_input):
    method test_output_shape (line 39) | def test_output_shape(self, pooling_layer, sample_input):
  class TestGlobalAvgPooling1D (line 46) | class TestGlobalAvgPooling1D:
    method pooling_layer (line 50) | def pooling_layer(self):
    method test_returns_average_across_sequence (line 54) | def test_returns_average_across_sequence(self, pooling_layer):
    method test_handles_attention_mask (line 63) | def test_handles_attention_mask(self, pooling_layer):
    method test_output_shape (line 73) | def test_output_shape(self, pooling_layer):
  class TestGlobalSumPooling1D (line 82) | class TestGlobalSumPooling1D:
    method pooling_layer (line 86) | def pooling_layer(self):
    method test_returns_sum_across_sequence (line 90) | def test_returns_sum_across_sequence(self, pooling_layer):
    method test_handles_attention_mask (line 99) | def test_handles_attention_mask(self, pooling_layer):
  class TestFirstTokenPooling1D (line 110) | class TestFirstTokenPooling1D:
    method pooling_layer (line 114) | def pooling_layer(self):
    method test_returns_first_token (line 118) | def test_returns_first_token(self, pooling_layer):
    method test_works_with_batch (line 127) | def test_works_with_batch(self, pooling_layer):
    method test_output_shape (line 136) | def test_output_shape(self, pooling_layer):
  class TestLastTokenPooling1D (line 145) | class TestLastTokenPooling1D:
    method pooling_layer (line 149) | def pooling_layer(self):
    method test_returns_last_token (line 153) | def test_returns_last_token(self, pooling_layer):
    method test_works_with_batch (line 162) | def test_works_with_batch(self, pooling_layer):
    method test_output_shape (line 171) | def test_output_shape(self, pooling_layer):
  class TestGlobalRMSPooling1D (line 180) | class TestGlobalRMSPooling1D:
    method pooling_layer (line 184) | def pooling_layer(self):
    method test_returns_rms_across_sequence (line 188) | def test_returns_rms_across_sequence(self, pooling_layer):
    method test_handles_attention_mask (line 197) | def test_handles_attention_mask(self, pooling_layer):
    method test_output_shape (line 207) | def test_output_shape(self, pooling_layer):
  class TestGlobalAbsMaxPooling1D (line 216) | class TestGlobalAbsMaxPooling1D:
    method pooling_layer (line 220) | def pooling_layer(self):
    method test_returns_abs_max_across_sequence (line 224) | def test_returns_abs_max_across_sequence(self, pooling_layer):
    method test_handles_attention_mask (line 233) | def test_handles_attention_mask(self, pooling_layer):
    method test_output_shape (line 243) | def test_output_shape(self, pooling_layer):
  class TestGlobalAbsAvgPooling1D (line 252) | class TestGlobalAbsAvgPooling1D:
    method pooling_layer (line 256) | def pooling_layer(self):
    method test_returns_abs_avg_across_sequence (line 260) | def test_returns_abs_avg_across_sequence(self, pooling_layer):
    method test_handles_attention_mask (line 269) | def test_handles_attention_mask(self, pooling_layer):
    method test_output_shape (line 279) | def test_output_shape(self, pooling_layer):
  class TestPassPooling1D (line 288) | class TestPassPooling1D:
    method pooling_layer (line 292) | def pooling_layer(self):
    method test_returns_input_unchanged (line 296) | def test_returns_input_unchanged(self, pooling_layer):
    method test_ignores_attention_mask (line 304) | def test_ignores_attention_mask(self, pooling_layer):
    method test_maintains_shape (line 313) | def test_maintains_shape(self, pooling_layer):
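Several pooling tests (`test_handles_attention_mask`) check that padded positions are excluded from the pooled vector. The sketch below shows masked mean pooling over the sequence axis, the idea behind `GlobalAvgPooling1D` when a mask is supplied. The function name, the `(batch, seq_len, hidden)` layout, and the convention that 1 marks a real token are assumptions, since the gliclass signatures are not reproduced in this index.

```python
import numpy as np


def masked_mean_pool(x, attention_mask):
    """Average token embeddings over the sequence axis, counting only unmasked tokens.

    x: (batch, seq_len, hidden); attention_mask: (batch, seq_len) with 1 for real tokens.
    Returns (batch, hidden).
    """
    mask = attention_mask[..., None].astype(x.dtype)   # (batch, seq_len, 1)
    summed = (x * mask).sum(axis=1)                     # sum over valid tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)      # avoid division by zero for empty rows
    return summed / counts


x = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
print(masked_mean_pool(x, np.array([[1, 1, 0]])))  # [[2. 3.]]
```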

FILE: tests/test_scorers.py
  class TestScorerWeightedDot (line 15) | class TestScorerWeightedDot:
    method scorer (line 17) | def scorer(self):
    method test_forward_pass (line 20) | def test_forward_pass(self, scorer):
    method test_gradient_flow (line 29) | def test_gradient_flow(self, scorer):
  class TestScorerDot (line 41) | class TestScorerDot:
    method scorer (line 43) | def scorer(self):
    method test_forward_pass (line 46) | def test_forward_pass(self, scorer):
    method test_gradient_flow (line 55) | def test_gradient_flow(self, scorer):
  class TestMLPScorer (line 67) | class TestMLPScorer:
    method scorer (line 69) | def scorer(self):
    method test_forward_pass (line 72) | def test_forward_pass(self, scorer):
    method test_different_batch_sizes (line 81) | def test_different_batch_sizes(self, scorer):
    method test_gradient_flow (line 90) | def test_gradient_flow(self, scorer):
  class TestHopfieldScorer (line 102) | class TestHopfieldScorer:
    method scorer (line 104) | def scorer(self):
    method test_forward_pass (line 107) | def test_forward_pass(self, scorer):
    method test_multiple_iterations (line 116) | def test_multiple_iterations(self):
    method test_gradient_flow (line 125) | def test_gradient_flow(self, scorer):
  class TestCrossAttnScorer (line 137) | class TestCrossAttnScorer:
    method scorer (line 139) | def scorer(self):
    method test_forward_pass_with_text_mask (line 142) | def test_forward_pass_with_text_mask(self, scorer):
    method test_forward_pass_without_text_mask (line 153) | def test_forward_pass_without_text_mask(self, scorer):
    method test_different_seq_lengths (line 162) | def test_different_seq_lengths(self, scorer):
    method test_gradient_flow (line 171) | def test_gradient_flow(self, scorer):
    method test_eval_mode (line 182) | def test_eval_mode(self, scorer):
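`tests/test_scorers.py` covers several text-to-label scorers, the simplest being `ScorerDot`. As an illustrative sketch only, a plain dot-product scorer that compares one pooled text embedding against each of its label embeddings can be written with `np.einsum`. The shapes `(batch, hidden)` and `(batch, num_labels, hidden)` and the name `dot_scores` are assumptions, not the signature of `ScorerDot.forward`.

```python
import numpy as np


def dot_scores(text_emb, label_emb):
    """Dot-product score between each text and each of its labels.

    text_emb: (batch, hidden); label_emb: (batch, num_labels, hidden) -> (batch, num_labels).
    """
    return np.einsum("bd,bcd->bc", text_emb, label_emb)


text = np.array([[1.0, 0.0]])
labels = np.array([[[2.0, 5.0], [0.0, 3.0]]])
print(dot_scores(text, labels))  # [[2. 0.]]
```

For multi-label classification these raw scores would typically be passed through a sigmoid to obtain per-label probabilities.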

FILE: tests/test_utils.py
  class TestIsModuleAvailable (line 9) | class TestIsModuleAvailable:
    method test_detects_installed_module (line 12) | def test_detects_installed_module(self):
    method test_detects_missing_module (line 17) | def test_detects_missing_module(self):
    method test_handles_submodules (line 21) | def test_handles_submodules(self):
  class TestRetrievalAugmentedText (line 26) | class TestRetrievalAugmentedText:
    method test_with_structured_examples (line 29) | def test_with_structured_examples(self):
    method test_empty_examples_returns_original_text (line 42) | def test_empty_examples_returns_original_text(self):
    method test_includes_true_label_markers (line 51) | def test_includes_true_label_markers(self):
  class TestDefaultF1Reward (line 62) | class TestDefaultF1Reward:
    method sample_inputs (line 66) | def sample_inputs(self):
    method test_returns_tensor (line 77) | def test_returns_tensor(self, sample_inputs):
    method test_output_shape (line 83) | def test_output_shape(self, sample_inputs):
    method test_perfect_predictions (line 89) | def test_perfect_predictions(self):
    method test_zero_f1_for_wrong_predictions (line 100) | def test_zero_f1_for_wrong_predictions(self):
    method test_handles_valid_mask (line 111) | def test_handles_valid_mask(self):

FILE: train.py
  class CustomTrainer (line 20) | class CustomTrainer(Trainer):
    method __init__ (line 23) | def __init__(self, *args, use_weighted_sampling=False, **kwargs):
    method _get_train_sampler (line 27) | def _get_train_sampler(self, train_dataset) -> torch.utils.data.Sampler:
  function compute_metrics (line 38) | def compute_metrics(p, problem_type='multi_label_classification'):
  function load_dataset (line 78) | def load_dataset(data_path: str) -> list:
  function main (line 92) | def main(args):

FILE: train_rl.py
  function accuracy_reward (line 20) | def accuracy_reward(probs, actions, targets, valid_mask):
  function recall_reward (line 27) | def recall_reward(
  function compute_metrics (line 43) | def compute_metrics(p):
  function main (line 71) | def main(args):
About this extraction

This page contains the full source code of the Knowledgator/GLiClass GitHub repository, extracted and formatted as plain text. The extraction includes 35 files (364.2 KB), approximately 83.2k tokens, and a symbol index with 409 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.