Showing preview only (379K chars total). Download the full file or copy to clipboard to get everything.
Repository: Knowledgator/GLiClass
Branch: main
Commit: 6ccf83fe5130
Files: 35
Total size: 364.2 KB
Directory structure:
gitextract_2n3b6llp/
├── .github/
│ └── workflows/
│ ├── release.yaml
│ └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── gliclass/
│ ├── __init__.py
│ ├── config.py
│ ├── data_processing.py
│ ├── layers.py
│ ├── loss_functions.py
│ ├── model.py
│ ├── ops.py
│ ├── pipeline.py
│ ├── poolings.py
│ ├── scorers.py
│ ├── serve/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── memory.py
│ │ └── server.py
│ ├── training.py
│ └── utils.py
├── notebooks/
│ └── finetuning.ipynb
├── pyproject.toml
├── serve_configs/
│ └── serve_config.yaml
├── test_gliclass.py
├── tests/
│ ├── test_data_processing.py
│ ├── test_loss_functions.py
│ ├── test_poolings.py
│ ├── test_scorers.py
│ └── test_utils.py
├── train.py
└── train_rl.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/release.yaml
================================================
name: Release GLiClass to PyPI
on:
push:
tags:
- 'v*' # Trigger on version tags (e.g., v1.0.0, v2.1.3)
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
build:
name: Build distribution 📦
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.x"
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v5
with:
name: python-package-distributions
path: dist/
publish-to-pypi:
name: >-
Publish Python 🐍 distribution 📦 to PyPI
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
needs:
- build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/project/gliclass/ # Replace <package-name> with your PyPI project name
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
fetch-depth: 0 # Fetch all history to check branches
- name: Verify tag is on main branch
run: |
if ! git branch -r --contains ${{ github.ref_name }} | grep -q 'origin/main'; then
echo "Error: Tag ${{ github.ref_name }} is not on the main branch"
exit 1
fi
echo "✓ Tag ${{ github.ref_name }} is on main branch"
- name: Download all the dists
uses: actions/download-artifact@v6
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
================================================
FILE: .github/workflows/tests.yml
================================================
name: Tests
on:
push:
branches:
- main
pull_request:
branches:
- main
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
test:
name: pytest (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('pyproject.toml') }}
restore-keys: |
${{ runner.os }}-py${{ matrix.python-version }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install pytest pytest-asyncio
- name: Run pytest
run: pytest -v --tb=short
lint:
name: ruff
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install ruff
run: pip install ruff
- name: ruff check
run: ruff check gliclass
- name: ruff format --check
run: ruff format --check gliclass
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
#custom
models/
wandb/
gradio_cached_examples/
test.ipynb
demo1.py
.gradio/
uv.lock
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# ⭐ GLiClass: Generalist and Lightweight Model for Sequence Classification
**GLiClass** is an efficient, zero-shot sequence classification model inspired by the [GLiNER](https://github.com/urchade/GLiNER/tree/main) framework. It achieves comparable performance to traditional cross-encoder models while being significantly more computationally efficient, offering classification results approximately **10 times faster** by performing classification in a single forward pass.
<p align="center">
<a href="https://medium.com/@knowledgrator/pushing-zero-shot-classification-to-the-limit-696a2403032f">📄 Blog</a>
<span> • </span>
<a href="https://discord.gg/dkyeAgs9DG">📢 Discord</a>
<span> • </span>
<a href="https://huggingface.co/spaces/knowledgator/GLiClass_SandBox">📺 Demo</a>
<span> • </span>
<a href="https://huggingface.co/models?sort=trending&search=gliclass">🤗 Available models</a>
<span> • </span>
<a href="https://colab.research.google.com/github/Knowledgator/GLiClass/blob/main/finetuning.ipynb">
<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />
</a>
</p>
### 🚀 Quick Start
Install GLiClass easily using pip:
```bash
pip install gliclass
```
#### Install from Source
Clone and install directly from GitHub:
```bash
git clone https://github.com/Knowledgator/GLiClass
cd GLiClass
python -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
pip install -r requirements.txt
pip install .
```
Verify your installation:
```python
import gliclass
print(gliclass.__version__)
```
### 🧑💻 Usage Example
```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")
pipeline = ZeroShotClassificationPipeline(
model, tokenizer, classification_type='multi-label', device='cuda:0'
)
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
print(f"{result['label']} => {result['score']:.3f}")
```
### 🔥 New Features
#### Hierarchical Labels
GLiClass now supports hierarchical label structures using dot notation:
```python
hierarchical_labels = {
"sentiment": ["positive", "negative", "neutral"],
"topic": ["product", "service", "shipping"]
}
text = "The product quality is amazing but delivery was slow"
results = pipeline(text, hierarchical_labels, threshold=0.5)[0]
for result in results:
print(f"{result['label']} => {result['score']:.3f}")
# Output:
# sentiment.positive => 0.892
# topic.product => 0.921
# topic.shipping => 0.763
```
Get hierarchical output matching your input structure:
```python
results = pipeline(text, hierarchical_labels, return_hierarchical=True)[0]
print(results)
# Output:
# {
# "sentiment": {"positive": 0.892, "negative": 0.051, "neutral": 0.124},
# "topic": {"product": 0.921, "service": 0.153, "shipping": 0.763}
# }
```
#### Few-Shot Examples
Improve classification accuracy with in-context examples using the `<<EXAMPLE>>` token:
```python
examples = [
{
"text": "Love this item, great quality!",
"labels": ["positive", "product"]
},
{
"text": "Customer support was unhelpful",
"labels": ["negative", "service"]
}
]
text = "Fast delivery and the item works perfectly!"
labels = ["positive", "negative", "product", "service", "shipping"]
results = pipeline(text, labels, examples=examples, threshold=0.5)[0]
for result in results:
print(f"{result['label']} => {result['score']:.3f}")
```
#### Task Description Prompts
Add custom prompts to guide the classification task:
```python
text = "The battery life on this phone is incredible"
labels = ["positive", "negative", "neutral"]
results = pipeline(
text,
labels,
prompt="Classify the sentiment of this product review:",
threshold=0.5
)[0]
```
Use per-text prompts for batch processing:
```python
texts = ["Review about electronics", "Review about clothing"]
prompts = [
"Analyze this electronics review:",
"Analyze this clothing review:"
]
results = pipeline(texts, labels, prompt=prompts)
```
#### Long Document Classification
Process long documents with automatic text chunking:
```python
from gliclass import ZeroShotClassificationWithChunkingPipeline
chunking_pipeline = ZeroShotClassificationWithChunkingPipeline(
model,
tokenizer,
text_chunk_size=8192,
text_chunk_overlap=256,
labels_chunk_size=8
)
long_document = "..." # Very long text
labels = ["category1", "category2", "category3"]
results = chunking_pipeline(long_document, labels, threshold=0.5)
```
### 🌟 Retrieval-Augmented Classification (RAC)
With new models trained with retrieval-agumented classification, such as [this model](https://huggingface.co/knowledgator/gliclass-base-v2.0-rac-init) you can specify examples to improve classification accuracy:
```python
example = {
"text": "A new machine learning platform automates complex data workflows but faces integration issues.",
"all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
"true_labels": ["AI", "integration", "automation"]
}
text = "The new AI-powered tool streamlines data analysis but has limited integration capabilities."
labels = ["AI", "automation", "data_analysis", "usability", "integration"]
results = pipeline(text, labels, threshold=0.1, rac_examples=[example])[0]
for predict in results:
print(f"{predict['label']} => {predict['score']:.3f}")
```
### 🚀 Production Serving
Deploy GLiClass with Ray Serve for production workloads with dynamic batching and memory-aware processing.
#### Installation
```bash
pip install gliclass[serve]
```
#### Quick Start
```bash
# Default model
python -m gliclass.serve
# Specify model and port
python -m gliclass.serve --model knowledgator/gliclass-edge-v3.0 --port 8000
# With config file
python -m gliclass.serve --config serve_configs/serve_config.yaml
```
#### Python Client
```python
from gliclass.serve import GLiClassClient
client = GLiClassClient(url="http://localhost:8000/gliclass")
result = client.classify(
text="This is a great product!",
labels=["positive", "negative", "neutral"],
threshold=0.3,
)
print(result) # [{"label": "positive", "score": 0.95}, ...]
```
#### HTTP API
The HTTP endpoint processes one text per request.
```bash
curl -X POST http://localhost:8000/gliclass \
-H "Content-Type: application/json" \
-d '{
"texts": "This is a great product!",
"labels": ["positive", "negative", "neutral"],
"threshold": 0.3
}'
# Response: [{"label": "positive", "score": 0.95}, ...]
```
**Note:** For batch processing multiple texts, use the `ZeroShotClassificationPipeline` directly instead of the serving API.
See `serve_configs/serve_config.yaml` for full configuration options.
### 🎯 Key Use Cases
- **Sentiment Analysis:** Rapidly classify texts as positive, negative, or neutral.
- **Document Classification:** Efficiently organize and categorize large document collections.
- **Search Results Re-ranking:** Improve relevance and precision by reranking search outputs.
- **News Categorization:** Automatically tag and organize news articles into predefined categories.
- **Fact Checking:** Quickly validate and categorize statements based on factual accuracy.
### 🛠️ How to Train
Prepare your training data as follows:
```json
[
{"text": "Sample text.", "all_labels": ["sports", "science", "business"], "true_labels": ["sports"]},
...
]
```
Optionally, specify confidence scores explicitly:
```json
[
{"text": "Sample text.", "all_labels": ["sports", "science"], "true_labels": {"sports": 0.9}},
...
]
```
Please, refer to the `train.py` script to set up your training from scratch or fine-tune existing models.
### ⚙️ Advanced Configuration
#### Architecture Types
GLiClass supports multiple architecture types:
- **uni-encoder**: Single encoder for both text and labels (default, most efficient)
- **bi-encoder**: Separate encoders for text and labels
- **bi-encoder-fused**: Bi-encoder with label embeddings fused into text encoding
- **encoder-decoder**: Encoder-decoder architecture for sequence-to-sequence tasks
```python
from gliclass import GLiClassBiEncoder
# Load a bi-encoder model
model = GLiClassBiEncoder.from_pretrained("knowledgator/gliclass-biencoder-v1.0")
```
#### Pooling Strategies
Configure how token embeddings are pooled:
- `first`: First token (CLS token)
- `avg`: Average pooling
- `max`: Max pooling
- `last`: Last token
- `sum`: Sum pooling
- `rms`: Root mean square pooling
- `abs_max`: Max of absolute values
- `abs_avg`: Average of absolute values
```python
from gliclass import GLiClassModelConfig
config = GLiClassModelConfig(
pooling_strategy='avg',
class_token_pooling='average' # or 'first'
)
```
#### Scoring Mechanisms
Choose different scoring mechanisms for classification:
- `simple`: Dot product (fastest)
- `weighted-dot`: Weighted dot product with learned projections
- `mlp`: Multi-layer perceptron scorer
- `hopfield`: Hopfield network-based scorer
```python
config = GLiClassModelConfig(
scorer_type='mlp'
)
```
---
### Flash Attention Backends
GLiClass supports optional flash attention backends for faster inference.
#### Install
```bash
pip install flashdeberta # DeBERTa v2
pip install turbot5 # T5 / mT5
```
---
#### FlashDeBERTa (DeBERTa v2)
Enable via environment variable:
```bash
export USE_FLASHDEBERTA=1
```
If `flashdeberta` is installed, DeBERTa v2 models will use `FlashDebertaV2Model`.
Otherwise, GLiClass falls back to `DebertaV2Model`.
---
#### TurboT5 (T5 / mT5)
Enable via environment variable:
```bash
export TURBOT5_ATTN_TYPE=triton-basic
```
If `turbot5` is installed, T5 / mT5 models will use `FlashT5EncoderModel`.
Otherwise, GLiClass falls back to `T5EncoderModel`.
Notes:
* Flash backends are **optional**
* Enabled automatically when available
* No code changes required
Want it even tighter (single block), or is this the sweet spot?
## 📚 Citations
If you find GLiClass useful in your research or project, please cite our papers:
```bibtex
@misc{stepanov2025gliclassgeneralistlightweightmodel,
title={GLiClass: Generalist Lightweight Model for Sequence Classification Tasks},
author={Ihor Stepanov and Mykhailo Shtopko and Dmytro Vodianytskyi and Oleksandr Lukashov and Alexander Yavorskyi and Mykyta Yaroshenko},
year={2025},
eprint={2508.07662},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2508.07662},
}
```
================================================
FILE: demo.py
================================================
"""
GLiClass Enhanced Demo with Advanced Features
Features:
- Task description prompts
- Hierarchical label inputs (JSON format)
- Few-shot examples
- Hierarchical output structure
- Label descriptions
"""
import json
from typing import Dict, List, Any, Union, Optional
import gradio as gr
import torch
from transformers import AutoTokenizer
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
# Initialize model and pipeline
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_path = "knowledgator/gliclass-small-v1.0"
model = GLiClassModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
pipeline = ZeroShotClassificationPipeline(
model, tokenizer,
classification_type='multi-label',
device=device
)
# ============== Example Texts ==============
TEXT_PRODUCT_REVIEW = """
I recently purchased the Sony WH-1000XM4 Wireless Noise-Canceling Headphones from Amazon and I must say, I'm thoroughly impressed. The package arrived in New York within 2 days, thanks to Amazon Prime's expedited shipping.
The headphones themselves are remarkable. The noise-canceling feature works like a charm in the bustling city environment, and the 30-hour battery life means I don't have to charge them every day. Connecting them to my Samsung Galaxy S21 was a breeze, and the sound quality is second to none.
I also appreciated the customer service from Amazon when I had a question about the warranty. They responded within an hour and provided all the information I needed.
However, the headphones did not come with a hard case, which was listed in the product description. I contacted Amazon, and they offered a 10% discount on my next purchase as an apology.
Overall, I'd give these headphones a 4.5/5 rating and highly recommend them to anyone looking for top-notch quality in both product and service.
"""
TEXT_TECH_COMPANIES = """
Apple Inc. is an American multinational technology company headquartered in Cupertino, California. Apple is the world's largest technology company by revenue, with US$394.3 billion in 2022 revenue. As of March 2023, Apple is the world's biggest company by market capitalization.
Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975 to develop and sell BASIC interpreters for the Altair 8800. During his career at Microsoft, Gates held the positions of chairman, chief executive officer, president and chief software architect.
Apple was founded as Apple Computer Company on April 1, 1976, by Steve Wozniak, Steve Jobs (1955–2011) and Ronald Wayne to develop and sell Wozniak's Apple I personal computer.
"""
TEXT_SCIENTIFIC = """
Several studies have reported its pharmacological activities, including anti-inflammatory, antimicrobial, and antitumoral effects.
The effect of E-anethole was studied in the osteosarcoma MG-63 cell line, and the antiproliferative activity was evaluated by an MTT assay.
It showed a GI50 value of 60.25 μM with apoptosis induction through the mitochondrial-mediated pathway. Additionally, it induced cell cycle arrest at the G0/G1 phase, up-regulated the expression of p53, caspase-3, and caspase-9, and down-regulated Bcl-xL expression.
"""
TEXT_RESTAURANT_REVIEW = """
We visited La Maison last Friday for our anniversary dinner. The ambiance was absolutely stunning - dim lighting, soft jazz music, and elegant table settings. Our waiter, Marcus, was incredibly attentive without being intrusive.
For appetizers, we had the truffle bruschetta and the soup of the day. Both were divine! The main courses - filet mignon for me and lobster risotto for my wife - were cooked to perfection.
The only downside was the wait time for our desserts, which took about 25 minutes. However, the chocolate soufflé was worth the wait!
Price was on the higher side ($180 for two), but the quality justified the cost. Will definitely return!
"""
TEXT_NEWS_POLITICS = """
The Senate passed a landmark bipartisan infrastructure bill late Thursday night, allocating $1.2 trillion for roads, bridges, broadband internet, and clean energy initiatives. The vote was 69-30, with 19 Republican senators joining all Democrats in support.
President Biden called the passage "a historic investment in America's future" and urged the House to act quickly. However, progressive Democrats have signaled they won't vote for the infrastructure bill unless it's paired with a larger social spending package.
Senate Minority Leader criticized portions of the bill related to climate spending, calling them "unnecessary green new deal provisions," while environmental groups praised the clean energy investments as "a step in the right direction, but not nearly enough."
"""
TEXT_SPORTS = """
In a thrilling overtime finish, the Lakers defeated the Celtics 118-112 in Game 7 of the NBA Finals. LeBron James delivered a historic performance with 42 points, 16 rebounds, and 10 assists, securing his fifth championship ring and fourth Finals MVP award.
The game was tied at 102 with 30 seconds remaining in regulation when Marcus Smart hit a contested three-pointer. However, James answered with a driving layup at the buzzer to force overtime.
In the extra period, the Lakers outscored Boston 16-10, with Anthony Davis contributing two crucial blocks in the final minute. "This is what you dream about as a kid," James said in the post-game interview. "Playing against the Celtics, Game 7, everything on the line."
"""
TEXT_MOVIE_REVIEW = """
Christopher Nolan's "Oppenheimer" is a masterwork of biographical cinema that demands to be seen on the largest screen possible. Cillian Murphy delivers a career-defining performance as J. Robert Oppenheimer, capturing both the brilliance and moral anguish of the father of the atomic bomb.
The film's nonlinear structure, weaving between the Manhattan Project, the 1954 security hearing, and the 1959 Lewis Strauss confirmation hearing, could have been confusing. Instead, Nolan crafts a compelling narrative that builds to a devastating emotional climax.
At three hours, some viewers may find the pacing challenging, particularly in the courtroom sequences. However, the technical achievements - Ludwig Göransson's haunting score, Hoyte van Hoytema's IMAX cinematography - make this an unmissable theatrical experience. Rating: 9/10
"""
TEXT_TECH_STARTUP = """
San Francisco-based AI startup Anthropic announced today it has raised $450 million in Series C funding, valuing the company at $5 billion. The round was led by Spark Capital, with participation from Google and existing investors.
Founded in 2021 by former OpenAI researchers Dario and Daniela Amodei, Anthropic has positioned itself as a leader in AI safety research. The company's Claude assistant has gained significant market share in the enterprise segment.
"This funding will accelerate our research into interpretable and steerable AI systems," said CEO Dario Amodei. "We believe safety and capability go hand in hand." The company plans to double its research team and expand internationally, with offices planned in London and Tokyo.
"""
TEXT_HEALTH_WELLNESS = """
A new study published in the Journal of the American Medical Association suggests that intermittent fasting may offer significant benefits beyond weight loss. Researchers followed 500 participants over two years and found improvements in cardiovascular health markers, insulin sensitivity, and cognitive function.
Participants who followed a 16:8 fasting protocol (eating within an 8-hour window) showed a 15% reduction in LDL cholesterol and a 20% improvement in fasting glucose levels compared to the control group.
However, experts caution that intermittent fasting isn't suitable for everyone. "Pregnant women, people with a history of eating disorders, and those with certain medical conditions should consult their doctor first," said Dr. Sarah Chen, the study's lead author. "It's not a magic solution, but for many people, it can be a sustainable approach to improving metabolic health."
"""
TEXT_TRAVEL = """
Hidden among the limestone karsts of Ha Long Bay, Cat Ba Island offers travelers an authentic Vietnamese experience away from the tourist crowds. We spent five days exploring this gem and discovered why it's becoming a favorite among backpackers and adventure seekers.
The island's national park features challenging hikes through tropical rainforest, with the trek to the peak of Ngu Lam offering panoramic views of the bay. We also kayaked through hidden lagoons and explored caves that few tourists ever see.
Accommodation ranges from basic hostels ($8/night) to comfortable eco-resorts ($60/night). The seafood is incredibly fresh - we had the best grilled squid of our lives at a family-run restaurant in Cat Ba Town for just $5. Pro tip: rent a motorbike to explore the quieter beaches on the island's east side.
"""
TEXT_COOKING_RECIPE = """
This Thai green curry comes together in just 30 minutes and tastes better than takeout. The secret is making your own curry paste - it takes an extra 10 minutes but the flavor difference is remarkable.
For the paste, blend together: 10 green chilies, 4 garlic cloves, 2 shallots, 1 stalk lemongrass, 1 inch galangal, handful of cilantro stems, 1 tsp cumin, 1 tsp coriander, zest of 1 lime, and 2 tbsp fish sauce.
Heat coconut oil in a wok, fry the paste for 2 minutes until fragrant. Add chicken (or tofu), cook until browned. Pour in coconut milk, add bamboo shoots, Thai eggplant, and basil. Simmer for 15 minutes. Season with palm sugar and more fish sauce to taste.
Serve over jasmine rice with extra chilies on the side. This recipe serves 4 and can be made ahead - the flavors actually improve overnight.
"""
TEXT_FINANCIAL_ADVICE = """
With inflation running at 4.2% and the Fed signaling more rate hikes, many investors are wondering how to position their portfolios. Here's what our analysis suggests for Q4 2024.
Fixed income is finally attractive again. With 10-year Treasury yields above 4.5%, bonds offer meaningful real returns for the first time in years. We recommend increasing allocation to investment-grade corporate bonds and TIPS for inflation protection.
For equities, we're cautiously optimistic on value stocks, particularly in the energy and financial sectors. Tech valuations remain stretched despite recent pullbacks. International developed markets, especially Japan and Europe, offer better risk-reward at current levels.
Remember: past performance doesn't guarantee future results. This is general information, not personalized advice. Consult a financial advisor before making investment decisions.
"""
TEXT_ENVIRONMENTAL = """
The Great Barrier Reef experienced its sixth mass bleaching event in a decade this summer, with aerial surveys showing 91% of reefs affected. Scientists warn that without dramatic action on climate change, the world's largest coral ecosystem may not survive beyond 2050.
"We're witnessing the collapse of one of Earth's most biodiverse ecosystems in real time," said Dr. Terry Hughes of James Cook University. Water temperatures reached 2°C above the February average, causing corals to expel the symbiotic algae that give them color and nutrients.
Some researchers are experimenting with heat-resistant coral varieties and cloud-brightening technology to shade reefs. However, most scientists agree these are stopgap measures. "The only real solution is rapid decarbonization," Hughes said. "Everything else is just buying time."
"""
TEXT_EDUCATION = """
The debate over standardized testing in American schools has intensified following a new report showing significant post-pandemic learning gaps. The National Assessment of Educational Progress found that fourth-grade math scores dropped to levels not seen since 2005.
Proponents of testing argue that standardized assessments are essential for identifying struggling students and holding schools accountable. "Without data, we're flying blind," said Education Secretary Miguel Cardona. "Tests help us direct resources where they're needed most."
Critics counter that high-stakes testing narrows the curriculum and increases student stress without improving outcomes. "We're testing kids more than ever, but educational outcomes aren't improving," said education researcher Dr. Pasi Sahlberg. "Countries like Finland, which use minimal standardized testing, consistently outperform the US."
"""
TEXT_FASHION = """
Milan Fashion Week wrapped up yesterday with several surprising trends that will likely dominate fall/winter 2025. After years of quiet luxury and minimalism, designers are embracing bold maximalism - think dramatic volumes, clashing prints, and unapologetic color.
Prada's collection featured oversized coats with exaggerated shoulders paired with flowing silk pants, while Gucci returned to its pattern-mixing roots under new creative direction. Versace went full baroque with gold-embroidered gowns that would feel at home in a Renaissance painting.
Sustainability remained a talking point, with Stella McCartney showcasing a collection made entirely from recycled ocean plastic. However, critics noted that the industry still has far to go. "One sustainable collection doesn't offset the environmental impact of fast fashion," noted fashion journalist Vanessa Friedman. "The industry needs systemic change, not just good PR."
"""
TEXT_LEGAL_CASE = """
The Supreme Court agreed Monday to hear a case that could reshape the boundaries of free speech on social media platforms. The case, NetChoice v. Paxton, challenges Texas and Florida laws that prohibit large social media companies from removing certain political content.
Tech companies argue that the First Amendment protects their right to moderate content as they see fit, similar to how newspapers decide what to publish. "Forcing platforms to host speech they find objectionable is compelled speech, which the Constitution forbids," said NetChoice counsel Paul Clement.
Texas and Florida counter that social media platforms function as common carriers or public utilities and should be subject to similar non-discrimination requirements. "These companies have become the modern public square," said Texas Attorney General Ken Paxton. "They shouldn't be able to silence voices based on political viewpoint."
"""
TEXT_GAMING = """
After three years in development hell, "Hollow Eclipse" has finally launched - and it's everything fans hoped for. This action RPG from indie studio Moonlight Games delivers a haunting 40-hour adventure that rivals titles from studios with ten times the budget.
The combat system strikes a perfect balance between accessibility and depth. Basic attacks and dodges are simple to execute, but mastering the "shadow merge" mechanic - which lets you temporarily possess enemies - adds layers of strategy. Boss fights are challenging without feeling unfair, though the final boss may take even experienced players dozens of attempts.
Where the game truly shines is its atmosphere. The decaying gothic city of Velmoor is rendered in stunning hand-drawn art, and the ambient soundtrack creates constant unease. The story tackles themes of grief and memory with surprising emotional maturity. Minor technical issues (occasional frame drops, one softlock) can't diminish this achievement. Score: 9.5/10
"""
TEXT_REAL_ESTATE = """
The housing market is sending mixed signals as we enter 2025. Existing home sales fell for the third consecutive month, down 4.1% in November, yet prices continue to climb in most metropolitan areas. The median home price hit $416,000, up 3.8% year-over-year.
Low inventory remains the central issue. Many homeowners are reluctant to sell because they've locked in sub-3% mortgage rates and don't want to trade up to today's 7% rates. This "lock-in effect" has created a severe shortage of listings, particularly in the starter home category.
"We're seeing bidding wars even in this high-rate environment because there's simply nothing to buy," said economist Lawrence Yun. First-time buyers are particularly squeezed, with affordability at its worst level since 1984. Some markets, including Austin and Phoenix, are showing price corrections, but coastal cities remain stubbornly expensive.
"""
TEXT_MENTAL_HEALTH = """
Workplace burnout has reached epidemic proportions, with a new Gallup survey finding that 76% of employees experience burnout at least sometimes. But recognizing burnout isn't always straightforward - it often manifests differently than simple exhaustion.
The three hallmarks of burnout are: emotional exhaustion (feeling drained and unable to cope), depersonalization (becoming cynical and detached from work), and reduced personal accomplishment (feeling ineffective regardless of actual performance).
Recovery requires more than a vacation. "You can't just rest your way out of burnout," says psychologist Dr. Christina Maslach, who pioneered burnout research. "You need to address the root causes - usually workload, lack of control, insufficient recognition, or values conflicts." Strategies include setting firm boundaries, delegating tasks, and having honest conversations with managers about sustainable workloads. In severe cases, professional support from a therapist can help.
"""
TEXT_ASTRONOMY = """
NASA's James Webb Space Telescope has detected what may be signs of biological activity in the atmosphere of K2-18b, an exoplanet 120 light-years away. The discovery has electrified the scientific community, though researchers caution against jumping to conclusions.
The telescope's spectrometers identified dimethyl sulfide (DMS), a molecule produced almost exclusively by living organisms on Earth. Webb also confirmed the presence of methane and carbon dioxide, consistent with a water-rich atmosphere.
"This is tantalizing, but not definitive proof of life," said lead researcher Dr. Nikku Madhusudhan. "DMS could potentially be produced by unknown geological processes. We need more observations." K2-18b is a "Hycean" world - a planet with a hydrogen-rich atmosphere and potentially a liquid water ocean beneath. If confirmed, this would be humanity's first detection of a potential biosignature beyond our solar system.
"""
def parse_labels_input(labels_input: str) -> Union[List[str], Dict[str, Any]]:
"""
Parse labels input - supports both comma-separated and JSON hierarchical format.
Examples:
- "positive, negative, neutral" -> ["positive", "negative", "neutral"]
- '{"sentiment": ["positive", "negative"], "topic": ["food", "service"]}' -> dict
"""
labels_input = labels_input.strip()
# Try parsing as JSON first (for hierarchical labels)
if labels_input.startswith('{'):
try:
return json.loads(labels_input)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format for hierarchical labels: {e}")
# Otherwise, treat as comma-separated flat labels
labels = [label.strip() for label in labels_input.split(',') if label.strip()]
return labels
def parse_examples_input(examples_input: str) -> Optional[List[Dict[str, Any]]]:
"""
Parse few-shot examples input (JSON format).
Expected format:
[
{"text": "Example text 1", "labels": ["label1", "label2"]},
{"text": "Example text 2", "labels": ["label3"]}
]
"""
if not examples_input or not examples_input.strip():
return None
try:
examples = json.loads(examples_input.strip())
if not isinstance(examples, list):
raise ValueError("Examples must be a JSON array")
for i, ex in enumerate(examples):
if not isinstance(ex, dict):
raise ValueError(f"Example {i+1} must be a JSON object")
if 'text' not in ex:
raise ValueError(f"Example {i+1} missing 'text' field")
if 'labels' not in ex and 'true_labels' not in ex:
raise ValueError(f"Example {i+1} missing 'labels' field")
return examples
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format for examples: {e}")
def format_output(
results: Union[List[Dict], Dict],
hierarchical: bool = False,
output_format: str = "visual"
) -> Union[Dict[str, float], str]:
"""Format classification output for Gradio display."""
if output_format == "json":
return format_as_json(results, hierarchical)
if hierarchical and isinstance(results, dict):
# Format hierarchical output as readable string
return format_hierarchical_dict(results)
if isinstance(results, list):
return {result['label']: float(result['score']) for result in results}
return results
def format_as_json(results: Union[List[Dict], Dict], hierarchical: bool = False) -> str:
"""Format results as pretty-printed JSON string."""
if hierarchical and isinstance(results, dict):
# Already in hierarchical dict format
return json.dumps(results, indent=2, ensure_ascii=False)
if isinstance(results, list):
# Convert list of predictions to structured format
output = {
"predictions": [
{"label": r["label"], "score": round(r["score"], 4)}
for r in results
],
"scores": {r["label"]: round(r["score"], 4) for r in results}
}
return json.dumps(output, indent=2, ensure_ascii=False)
return json.dumps(results, indent=2, ensure_ascii=False)
def format_hierarchical_dict(d: Dict, indent: int = 0) -> str:
"""Format hierarchical dict for display with visual score bars."""
lines = []
prefix = " " * indent
for key, value in d.items():
if isinstance(value, dict):
lines.append(f"{prefix}**{key}**:")
lines.append(format_hierarchical_dict(value, indent + 1))
else:
score_bar = "█" * int(value * 20) + "░" * (20 - int(value * 20))
lines.append(f"{prefix}{key}: {score_bar} {value:.3f}")
return "\n".join(lines)
def classification(
text: str,
labels_input: str,
threshold: float,
multi_label: bool,
prompt: str,
examples_input: str,
hierarchical_output: bool,
output_format: str = "visual"
) -> Union[Dict[str, float], str]:
"""
Perform classification with all advanced features.
"""
try:
# Parse labels (flat or hierarchical)
labels = parse_labels_input(labels_input)
# Parse few-shot examples
examples = parse_examples_input(examples_input) if examples_input else None
# Set classification type
pipeline.pipe.classification_type = 'multi-label' if multi_label else 'single-label'
# Prepare prompt
task_prompt = prompt.strip() if prompt and prompt.strip() else None
# Run classification
results = pipeline(
text,
labels,
threshold=threshold,
examples=examples,
prompt=task_prompt,
return_hierarchical=hierarchical_output
)[0] # Single text, get first result
# Format output based on selected format
if output_format == "json":
return format_as_json(results, hierarchical_output)
elif hierarchical_output:
return format_hierarchical_dict(results)
else:
return {result['label']: float(result['score']) for result in results}
except Exception as e:
return f"Error: {str(e)}"
# ============== Example Configurations ==============
EXAMPLES = [
# Example 1: Basic flat labels with prompt
[
TEXT_PRODUCT_REVIEW,
"product review, electronics, positive feedback, negative feedback, customer service, shipping",
0.5,
True,
"Classify this customer review by topic and sentiment:",
"",
False,
"visual"
],
# Example 2: Hierarchical labels for restaurant review
[
TEXT_RESTAURANT_REVIEW,
'''{
"sentiment": ["positive", "negative", "mixed"],
"aspects": ["food quality", "service", "ambiance", "price", "wait time"],
"recommendation": ["would recommend", "would not recommend"]
}''',
0.4,
True,
"Analyze this restaurant review:",
"",
True,
"visual"
],
# Example 3: News article with few-shot examples
[
TEXT_NEWS_POLITICS,
"politics, business, technology, sports, entertainment, science, health",
0.5,
True,
"Classify this news article by category:",
'''[
{"text": "The Federal Reserve raised interest rates by 0.25% today, citing persistent inflation concerns.", "labels": ["politics", "business"]},
{"text": "Scientists discover high new high-temperature superconductor material that works at room temperature.", "labels": ["science", "technology"]}
]''',
False,
"visual"
],
# Example 4: Scientific classification with hierarchical output
[
TEXT_SCIENTIFIC,
'''{
"domain": ["biology", "chemistry", "medicine", "physics"],
"research_type": ["experimental", "theoretical", "review"],
"application": ["therapeutic", "diagnostic", "basic research"]
}''',
0.3,
True,
"Classify this scientific abstract:",
"",
True,
"visual"
],
# Example 5: Sports article - single label
[
TEXT_SPORTS,
"basketball, football, soccer, tennis, baseball, hockey, golf",
0.5,
False,
"What sport is this article about?",
"",
False,
"visual"
],
# Example 6: Movie review with detailed sentiment (JSON output)
[
TEXT_MOVIE_REVIEW,
'''{
"overall_sentiment": ["positive", "negative", "mixed"],
"aspects_praised": ["acting", "direction", "cinematography", "music", "story", "pacing"],
"aspects_criticized": ["acting", "direction", "cinematography", "music", "story", "pacing"],
"recommendation": ["must watch", "worth watching", "skip it"]
}''',
0.35,
True,
"Analyze this movie review in detail:",
"",
True,
"json"
],
# Example 7: Tech startup news
[
TEXT_TECH_STARTUP,
"funding announcement, product launch, acquisition, IPO, partnership, hiring, layoffs, legal",
0.4,
True,
"What type of tech news is this?",
"",
False,
"visual"
],
# Example 8: Health article with hierarchical categories
[
TEXT_HEALTH_WELLNESS,
'''{
"topic": ["nutrition", "exercise", "mental health", "sleep", "medical research"],
"content_type": ["research findings", "practical advice", "expert opinion", "warning"],
"audience": ["general public", "healthcare professionals", "patients"]
}''',
0.4,
True,
"Categorize this health article:",
"",
True,
"visual"
],
# Example 9: Travel content (JSON output)
[
TEXT_TRAVEL,
"destination guide, hotel review, restaurant review, adventure travel, budget travel, luxury travel, travel tips",
0.4,
True,
"What type of travel content is this?",
"",
False,
"json"
],
# Example 10: Recipe classification
[
TEXT_COOKING_RECIPE,
'''{
"cuisine": ["Thai", "Italian", "Mexican", "Indian", "Chinese", "Japanese", "French", "American"],
"difficulty": ["easy", "medium", "hard"],
"meal_type": ["breakfast", "lunch", "dinner", "dessert", "snack"],
"dietary": ["vegetarian friendly", "vegan friendly", "gluten free", "dairy free", "contains meat"]
}''',
0.35,
True,
"Classify this recipe:",
"",
True,
"visual"
],
# Example 11: Financial content with examples
[
TEXT_FINANCIAL_ADVICE,
"investment advice, market analysis, personal finance, retirement planning, tax advice, economic news",
0.4,
True,
"Categorize this financial content:",
'''[
{"text": "Here are 5 ways to maximize your 401k contributions before year end.", "labels": ["personal finance", "retirement planning", "tax advice"]},
{"text": "The S&P 500 rose 2% today following strong jobs report.", "labels": ["market analysis", "economic news"]}
]''',
False,
"visual"
],
# Example 12: Environmental news (JSON output)
[
TEXT_ENVIRONMENTAL,
'''{
"topic": ["climate change", "biodiversity", "pollution", "conservation", "renewable energy"],
"tone": ["alarming", "hopeful", "neutral", "urgent"],
"focus": ["problem description", "solutions", "policy", "research findings"]
}''',
0.35,
True,
"Analyze this environmental article:",
"",
True,
"json"
],
# Example 13: Education debate
[
TEXT_EDUCATION,
"education policy, standardized testing, curriculum, teacher issues, student welfare, technology in education, higher education",
0.4,
True,
"What education topics does this article cover?",
"",
False,
"visual"
],
# Example 14: Fashion news with hierarchy
[
TEXT_FASHION,
'''{
"content_type": ["trend report", "designer profile", "collection review", "industry news", "sustainability"],
"season": ["spring/summer", "fall/winter"],
"market_segment": ["luxury", "fast fashion", "sustainable fashion", "streetwear"]
}''',
0.4,
True,
"Classify this fashion article:",
"",
True,
"visual"
],
# Example 15: Legal case (JSON output)
[
TEXT_LEGAL_CASE,
"constitutional law, criminal law, civil rights, corporate law, intellectual property, free speech, privacy",
0.4,
True,
"What areas of law does this case involve?",
"",
False,
"json"
],
# Example 16: Gaming review with detailed analysis
[
TEXT_GAMING,
'''{
"genre": ["action", "RPG", "adventure", "puzzle", "strategy", "simulation", "sports"],
"platform_feel": ["indie", "AAA", "mid-tier"],
"strengths": ["gameplay", "story", "graphics", "music", "replayability"],
"weaknesses": ["bugs", "difficulty", "length", "graphics", "story"],
"recommendation": ["must play", "worth playing", "wait for sale", "skip"]
}''',
0.35,
True,
"Analyze this game review:",
"",
True,
"visual"
],
# Example 17: Real estate market analysis
[
TEXT_REAL_ESTATE,
"market analysis, buying advice, selling advice, investment, rental market, mortgage rates, housing policy",
0.4,
True,
"What real estate topics are covered?",
"",
False,
"visual"
],
# Example 18: Mental health with few-shot (JSON output)
[
TEXT_MENTAL_HEALTH,
'''{
"topic": ["burnout", "anxiety", "depression", "stress management", "work-life balance"],
"content_type": ["educational", "self-help advice", "research summary", "personal story"],
"actionability": ["provides concrete steps", "general awareness", "seeks professional help"]
}''',
0.35,
True,
"Categorize this mental health content:",
'''[
{"text": "Feeling overwhelmed? Try the 5-4-3-2-1 grounding technique: notice 5 things you see, 4 you hear...", "labels": ["topic.anxiety", "topic.stress management", "content_type.self-help advice", "actionability.provides concrete steps"]},
{"text": "A new study links social media use exceeding 3 hours daily with increased rates of depression in teens.", "labels": ["topic.depression", "content_type.research summary", "actionability.general awareness"]}
]''',
True,
"json"
],
# Example 19: Astronomy discovery
[
TEXT_ASTRONOMY,
"exoplanets, astrobiology, cosmology, solar system, space exploration, telescopes, astrophysics",
0.4,
True,
"What astronomy topics are discussed?",
"",
False,
"visual"
],
# Example 20: Tech companies - single label
[
TEXT_TECH_COMPANIES,
"company profile, product announcement, financial report, industry analysis, biography, opinion piece",
0.5,
False,
"What is the primary type of this article?",
"",
False,
"visual"
],
]
# ============== Gradio Interface ==============
with gr.Blocks(
title="GLiClass Advanced Demo",
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="slate",
)
) as demo:
gr.Markdown("""
# 🏷️ GLiClass Advanced Zero-Shot Classification
Enhanced demo featuring **prompts**, **hierarchical labels**, **few-shot examples**, and **structured outputs**.
""")
with gr.Accordion("📖 How to Use This Demo", open=False):
gr.Markdown("""
## Features Overview
### 1. Task Description Prompts
Add a natural language description of the classification task to guide the model.
**Example:** `"Classify this customer review by sentiment and topic:"`
---
### 2. Hierarchical Labels (JSON Format)
Structure your labels in categories for organized classification:
```json
{
"sentiment": ["positive", "negative", "neutral"],
"topic": ["product", "service", "shipping"],
"urgency": ["high", "medium", "low"]
}
```
Or use simple comma-separated labels: `positive, negative, neutral`
---
### 3. Few-Shot Examples
Provide examples to guide the model's understanding:
```json
[
{"text": "Great product, love it!", "labels": ["positive", "product"]},
{"text": "Shipping was delayed by 2 weeks", "labels": ["negative", "shipping"]}
]
```
---
### 4. Hierarchical Output
When enabled with hierarchical labels, returns structured scores matching your input format.
""")
with gr.Accordion("💻 Code Example", open=False):
gr.Code(
'''from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1")
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
# Basic usage
text = "The product quality is amazing but delivery was slow"
labels = ["positive", "negative", "product", "shipping"]
results = pipeline(text, labels, threshold=0.5)[0]
# With hierarchical labels
hierarchical_labels = {
"sentiment": ["positive", "negative", "neutral"],
"topic": ["product", "service", "shipping"]
}
results = pipeline(
text,
hierarchical_labels,
prompt="Classify this review:",
return_hierarchical=True
)[0]
# With few-shot examples
examples = [
{"text": "Love this item!", "labels": ["sentiment.positive", "topic.product"]},
{"text": "Terrible customer support", "labels": ["sentiment.negative", "topic.service"]}
]
results = pipeline(
text,
hierarchical_labels,
examples=examples,
prompt="Classify customer feedback:"
)[0]
''',
language="python",
)
with gr.Row():
with gr.Column(scale=2):
input_text = gr.Textbox(
value=EXAMPLES[0][0],
label="📝 Text Input",
placeholder="Enter the text you want to classify...",
lines=8
)
prompt_input = gr.Textbox(
value=EXAMPLES[0][4],
label="💡 Task Description Prompt (Optional)",
placeholder="E.g., 'Classify this customer review by sentiment and topic:'",
lines=1
)
with gr.Column(scale=1):
labels_input = gr.Textbox(
value=EXAMPLES[0][1],
label="🏷️ Labels (comma-separated or JSON)",
placeholder='positive, negative\n\nOR\n\n{"category": ["label1", "label2"]}',
lines=6
)
with gr.Row():
threshold = gr.Slider(
0, 1,
value=0.5,
step=0.01,
label="Threshold",
info="Confidence threshold for predictions"
)
with gr.Row():
multi_label = gr.Checkbox(
value=True,
label="Multi-label",
info="Allow multiple labels per text"
)
hierarchical_output = gr.Checkbox(
value=False,
label="Hierarchical Output",
info="Return structured output matching label hierarchy"
)
with gr.Row():
output_format = gr.Radio(
choices=["visual", "json"],
value="visual",
label="Output Format",
info="Visual: charts/bars | JSON: raw data"
)
with gr.Accordion("🎯 Few-Shot Examples (Optional)", open=False):
examples_input = gr.Textbox(
value="",
label="Examples (JSON format)",
placeholder='''[
{"text": "Example text 1", "labels": ["label1", "label2"]},
{"text": "Example text 2", "labels": ["label3"]}
]''',
lines=5
)
gr.Markdown("""
*Provide labeled examples to guide the model. Each example needs a `text` field and a `labels` array.*
""")
submit_btn = gr.Button("🚀 Classify", variant="primary", size="lg")
output = gr.Label(label="📊 Classification Results")
output_text = gr.Textbox(
label="📊 Hierarchical Results",
visible=False,
lines=10
)
output_json = gr.Code(
label="📊 JSON Output",
language="json",
visible=False,
lines=15
)
# Dynamic output visibility based on format and hierarchical toggle
def update_output_visibility(hierarchical: bool, fmt: str):
if fmt == "json":
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
elif hierarchical:
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
else:
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
hierarchical_output.change(
fn=update_output_visibility,
inputs=[hierarchical_output, output_format],
outputs=[output, output_text, output_json]
)
output_format.change(
fn=update_output_visibility,
inputs=[hierarchical_output, output_format],
outputs=[output, output_text, output_json]
)
# Classification function wrapper for different outputs
def classify_wrapper(text, labels, threshold, multi_label, prompt, examples, hierarchical, fmt):
result = classification(text, labels, threshold, multi_label, prompt, examples, hierarchical, fmt)
if fmt == "json":
return None, None, result
elif hierarchical or isinstance(result, str):
return None, result, None
else:
return result, None, None
# Event handlers
submit_btn.click(
fn=classify_wrapper,
inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],
outputs=[output, output_text, output_json]
)
input_text.submit(
fn=classify_wrapper,
inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],
outputs=[output, output_text, output_json]
)
gr.Markdown("### 📚 Example Configurations")
gr.Examples(
examples=EXAMPLES,
inputs=[input_text, labels_input, threshold, multi_label, prompt_input, examples_input, hierarchical_output, output_format],
outputs=[output, output_text, output_json],
fn=classify_wrapper,
cache_examples=False,
examples_per_page=5
)
gr.Markdown("""
---
### 🔧 Tips for Best Results
| Feature | Best Practice |
|---------|---------------|
| **Prompts** | Be specific about the task, e.g., "Classify by sentiment:" vs "Analyze:" |
| **Labels** | Use descriptive labels; "customer service issue" > "service" |
| **Hierarchical** | Group related labels under categories for organized results |
| **Examples** | 2-3 diverse examples often improve accuracy significantly |
| **Threshold** | Start at 0.5, lower for more predictions, raise for higher precision |
""")
if __name__ == "__main__":
demo.queue()
demo.launch(debug=True, share=True)
================================================
FILE: gliclass/__init__.py
================================================
from .model import GLiClassModel, GLiClassBiEncoder, GLiClassUniEncoder, GLiClassEncoderDecoderCLS
from .config import GLiClassModelConfig
from .pipeline import (
ZeroShotClassificationPipeline,
BiEncoderZeroShotClassificationPipeline,
ZeroShotClassificationWithChunkingPipeline,
)
__version__ = "0.1.19"
# Serve module (optional import)
try:
from . import serve
except ImportError:
serve = None
================================================
FILE: gliclass/config.py
================================================
from transformers import AutoConfig
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
from transformers.configuration_utils import PretrainedConfig
from .utils import is_module_available
IS_TURBOT5 = is_module_available("turbot5")
if IS_TURBOT5:
from turbot5.model.config import T5Config
else:
from transformers import T5Config
logger = logging.get_logger(__name__)
class GLiClassModelConfig(PretrainedConfig):
model_type = "GLiClass"
is_composition = True
def __init__(
self,
encoder_config=None,
encoder_model=None,
label_model_config=None,
label_model_name=None,
class_token_index=-1,
text_token_index=-1,
example_token_index=-1,
ignore_index=-100,
hidden_size=None,
projector_hidden_act="gelu",
vocab_size=None,
problem_type="single_label_classification",
max_num_classes=25,
use_lstm=False,
initializer_range=0.03,
scorer_type="simple",
scorer_num_heads=16,
scorer_mlp_hidden_size=1024,
scorer_attn_dropout=0.1,
pooling_strategy="first",
class_token_pooling="first",
focal_loss_alpha=0.5,
focal_loss_gamma=2,
focal_loss_reduction=None,
logit_scale_init_value=2.6592,
normalize_features=False,
extract_text_features=False,
max_labels_alloc: str = "dynamic",
contrastive_loss_coef=0,
architecture_type="uni-encoder",
prompt_first=False,
squeeze_layers=False,
layer_wise=False,
encoder_layer_id=-1,
embed_class_token=True,
dropout=0.1,
use_segment_embeddings=False,
**kwargs,
):
if isinstance(encoder_config, dict):
encoder_config["model_type"] = encoder_config.get("model_type", "deberta-v2")
if encoder_config["model_type"] == "t5":
encoder_config = T5Config(**encoder_config)
elif encoder_config["model_type"] in CONFIG_MAPPING:
encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config)
else:
_name = encoder_model or kwargs.get("encoder_model_name")
if _name:
encoder_config = AutoConfig.from_pretrained(_name, trust_remote_code=True)
else:
encoder_config = PretrainedConfig(**encoder_config)
elif encoder_config is None:
encoder_config = CONFIG_MAPPING["deberta-v2"]()
self.encoder_config = encoder_config
self.encoder_model_name = encoder_model
if label_model_name is not None:
if isinstance(label_model_config, dict):
label_model_config["model_type"] = label_model_config.get("model_type", "deberta-v2")
label_model_config = CONFIG_MAPPING[label_model_config["model_type"]](**label_model_config)
elif label_model_config is None:
label_model_config = CONFIG_MAPPING["deberta-v2"]()
self.label_model_config = label_model_config
else:
self.label_model_config = None
self.label_model_name = label_model_name
if hidden_size is None:
self.hidden_size = self.encoder_config.hidden_size
else:
self.hidden_size = hidden_size
if vocab_size is None:
self.vocab_size = self.encoder_config.vocab_size
else:
self.vocab_size = vocab_size
if class_token_index == -1:
self.class_token_index = self.vocab_size
else:
self.class_token_index = class_token_index
if text_token_index == -1:
self.text_token_index = self.vocab_size + 1
else:
self.text_token_index = text_token_index
if example_token_index == -1:
self.example_token_index = self.vocab_size + 2
else:
self.example_token_index = example_token_index
self.ignore_index = ignore_index
self.projector_hidden_act = projector_hidden_act
self.problem_type = problem_type
self.max_num_classes = max_num_classes
self.initializer_range = initializer_range
self.scorer_type = scorer_type
self.scorer_num_heads = scorer_num_heads
self.scorer_mlp_hidden_size = scorer_mlp_hidden_size
self.scorer_attn_dropout = scorer_attn_dropout
self.pooling_strategy = pooling_strategy
self.class_token_pooling = class_token_pooling
self.use_lstm = use_lstm
self.focal_loss_alpha = focal_loss_alpha
self.focal_loss_gamma = focal_loss_gamma
self.focal_loss_reduction = focal_loss_reduction
self.contrastive_loss_coef = contrastive_loss_coef
self.logit_scale_init_value = logit_scale_init_value
self.normalize_features = normalize_features
self.extract_text_features = extract_text_features
self.max_labels_alloc = max_labels_alloc
self.architecture_type = architecture_type
self.prompt_first = prompt_first
self.squeeze_layers = squeeze_layers
self.layer_wise = layer_wise
self.encoder_layer_id = encoder_layer_id
self.embed_class_token = embed_class_token
self.pad_token_id = self.encoder_config.pad_token_id
self.dropout = dropout
self.use_segment_embeddings = use_segment_embeddings
super().__init__(**kwargs)
================================================
FILE: gliclass/data_processing.py
================================================
import copy
import random
from dataclasses import dataclass
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
@dataclass
class AugmentationConfig:
"""Configuration for data augmentation."""
enabled: bool = True
# Probability for each augmentation type
random_label_removal_prob: float = 0.15
random_label_addition_prob: float = 0.10
random_text_addition_prob: float = 0.05
random_add_description_prob: float = 0.25
random_add_synonyms_prob: float = 0.1
random_add_examples_prob: float = 0.25
max_num_examples: int = 5
class DataAugmenter:
def __init__(self, config, examples, labels, label2description=None):
self.config = config
self.examples = examples
self.labels = sorted(labels)
self.max_examples = self.config.max_num_examples
self.label2description = label2description or {}
def remove_labels(self, true_labels, all_labels):
if len(all_labels) <= 1:
return true_labels, all_labels
k = random.randint(1, len(all_labels))
all_labels = random.sample(all_labels, k=k)
true_labels = [lbl for lbl in true_labels if lbl in all_labels]
return true_labels, all_labels
def add_random_labels(self, all_labels):
if not self.labels:
return all_labels
num_add = len(all_labels) + 1
k = random.randint(1, min(num_add, len(self.labels)))
add_labels = random.sample(self.labels, k=k)
all_labels.extend(add_labels)
return all_labels
def add_random_text(self, text, all_labels):
if not self.examples:
return text
example = random.sample(self.examples, k=1)[0]
curr_labels = example["all_labels"]
joint_labels = set(all_labels) & set(curr_labels)
if len(joint_labels):
return text
else:
if random.randint(0, 1):
text = example["text"] + " " + text
else:
text = text + " " + example["text"]
return text
def add_random_synonyms(self, all_labels):
"""Replace some labels with their synonyms if available."""
if not self.label2description:
return all_labels
augmented_labels = []
for label in all_labels:
if label in self.label2description:
label_info = self.label2description[label]
synonyms = label_info.get("synonyms", [])
if synonyms and random.random() < 0.5:
augmented_labels.append(random.choice(synonyms))
else:
augmented_labels.append(label)
else:
augmented_labels.append(label)
return augmented_labels
def add_random_descriptions(self, item):
"""Add descriptions to labels in the text or metadata."""
if not self.label2description or not item["all_labels"]:
return item
max_labels = min(3, len(item["all_labels"]))
labels_to_describe = random.sample(item["all_labels"], k=random.randint(1, max_labels))
descriptions = []
for label in labels_to_describe:
if label in self.label2description:
label_info = self.label2description[label]
desc_list = label_info.get("descriptions", [])
if desc_list:
descriptions.append(f"{label}: {random.choice(desc_list)}")
if descriptions:
desc_text = " ".join(descriptions)
if random.random() < 0.5:
item["text"] = desc_text + " " + item["text"]
else:
item["text"] = item["text"] + " " + desc_text
return item
def add_random_examples(self, item):
"""Add example texts with similar labels."""
if not item["all_labels"]:
return item
candidate_examples = item.get("examples", [])
item_label_set = set(item["all_labels"])
if not candidate_examples:
for example in self.examples:
example_label_set = set(example["true_labels"])
example_text = example["text"]
overlap = item_label_set & example_label_set
# Only consider examples with at least one overlapping label
if overlap:
candidate_examples.append({"text": example_text, "labels": list(example_label_set)})
if not candidate_examples:
return item
# Sort by overlap and select top examples
random.shuffle(candidate_examples)
top_candidates = candidate_examples[: self.max_examples]
num_examples = random.randint(1, min(2, len(top_candidates)))
selected_examples = random.sample(top_candidates, k=num_examples)
item["examples"] = selected_examples
return item
def augment(self, item):
if not self.config.enabled:
return item
text = copy.deepcopy(item["text"])
true_labels = copy.deepcopy(item["true_labels"])
all_labels = copy.deepcopy(item["all_labels"])
# Create augmented item
aug_item = {"text": text, "true_labels": true_labels, "all_labels": all_labels}
# Copy any additional fields
for key in item:
if key not in aug_item:
aug_item[key] = copy.deepcopy(item[key])
if random.random() < self.config.random_label_removal_prob:
aug_item["true_labels"], aug_item["all_labels"] = self.remove_labels(
aug_item["true_labels"], aug_item["all_labels"]
)
if random.random() < self.config.random_label_addition_prob:
aug_item["all_labels"] = self.add_random_labels(aug_item["all_labels"])
if random.random() < self.config.random_text_addition_prob:
aug_item["text"] = self.add_random_text(aug_item["text"], aug_item["all_labels"])
if random.random() < self.config.random_add_synonyms_prob:
aug_item["all_labels"] = self.add_random_synonyms(aug_item["all_labels"])
if random.random() < self.config.random_add_description_prob:
aug_item = self.add_random_descriptions(aug_item)
if random.random() < self.config.random_add_examples_prob:
aug_item = self.add_random_examples(aug_item)
return aug_item
class GLiClassDataset(Dataset):
def __init__(
self,
examples,
tokenizer,
augment_config,
label2description={},
max_length=512,
problem_type="multi_label_classification",
architecture_type="uni-encoder",
add_description=True,
prompt_first=False,
get_negatives=False,
max_labels=50,
labels_tokenizer=None,
shuffle_labels=True,
):
self.tokenizer = tokenizer
self.labels_tokenizer = labels_tokenizer
self.label2description = label2description
self.augment_config = augment_config
self.max_length = max_length
self._data = examples
self.add_description = add_description
self.problem_type = problem_type
self.architecture_type = architecture_type
self.prompt_first = prompt_first
self.dataset_labels = self.collect_dataset_labels()
self.get_negatives = get_negatives
self.max_labels = max_labels
self.shuffle_labels = shuffle_labels
self.sep_token = "<<SEP>>"
self.label_token = "<<LABEL>>"
self.example_token = "<<EXAMPLE>>"
self.augmenter = DataAugmenter(augment_config, examples, self.dataset_labels, label2description)
print("Total labels: ", len(self.dataset_labels))
def get_diversity(self):
return [item.get("_diversity", {}).get("overall_diversity", 0.5) for item in self.data]
def collect_dataset_labels(self):
dataset_labels = set()
for example in self._data:
dataset_labels.update(set(example["all_labels"]))
return dataset_labels
def prepare_labels(self, example, label2idx, problem_type):
if problem_type == "single_label_classification":
labels = label2idx[example["true_labels"][0]]
elif problem_type == "multi_label_classification":
if isinstance(example["true_labels"], dict):
labels = [example["true_labels"].get(label, 0.0) for label in example["all_labels"]]
else:
labels = [1.0 if label in example["true_labels"] else 0.0 for label in example["all_labels"]]
else:
raise NotImplementedError(f"{problem_type} is not implemented.")
return torch.tensor(labels)
def prepare_prompt(self, item, label_token_first=True):
prompt_texts = []
for label in item["all_labels"]:
if label_token_first:
label_tag = f"{self.label_token}{label!s}"
else:
label_tag = f"{label!s}{self.label_token}"
prompt_texts.append(label_tag)
prompt_texts.append(self.sep_token)
prompt = item.get("prompt", "")
prompt_texts.append(prompt)
return prompt_texts
def format_examples(self, item):
examples = item.get("examples", [])
if not examples:
return ""
examples = random.sample(examples, k=random.randint(1, len(examples)))
parts = []
for example in examples:
parts.append(self.example_token)
parts.append(example.get("text", ""))
parts.append(" \nLabels:\n ")
parts.append(", ".join(example.get("labels", example.get("true_labels", []))))
parts.append(self.sep_token)
return "".join(parts)
def tokenize(self, texts):
tokenized_inputs = self.tokenizer(texts, truncation=True, max_length=self.max_length, padding="longest")
return tokenized_inputs
def tokenize_labels(self, labels):
tokenized_inputs = self.labels_tokenizer(labels, truncation=True, max_length=self.max_length, padding="longest")
return tokenized_inputs
def tokenize_and_prepare_labels_for_uniencoder(self, example):
if self.shuffle_labels:
random.shuffle(example["all_labels"])
input_text = self.prepare_prompt(example)
examples_text = self.format_examples(example)
if self.prompt_first:
input_text = "".join(input_text) + str(example["text"]) + examples_text
else:
input_text = str(example["text"]) + "".join(input_text) + examples_text
label2idx = {label: idx for idx, label in enumerate(example["all_labels"])}
tokenized_inputs = self.tokenize(input_text)
tokenized_inputs["labels"] = self.prepare_labels(example, label2idx, self.problem_type)
tokenized_inputs["labels_text"] = example["all_labels"]
tokenized_inputs["input_texts"] = example["text"]
return tokenized_inputs
def tokenize_and_prepare_labels_for_encoder_decoder(self, example):
if self.shuffle_labels:
random.shuffle(example["all_labels"])
class_texts = self.prepare_prompt(example, label_token_first=True)
class_texts = "".join(class_texts)
examples_text = self.format_examples(example)
label2idx = {label: idx for idx, label in enumerate(example["all_labels"])}
input_text = str(example["text"]) + examples_text
tokenized_inputs = self.tokenize(input_text)
tokenized_classes = self.tokenize(class_texts)
tokenized_inputs["class_input_ids"] = tokenized_classes["input_ids"]
tokenized_inputs["class_attention_mask"] = tokenized_classes["attention_mask"]
tokenized_inputs["labels"] = self.prepare_labels(example, label2idx, self.problem_type)
return tokenized_inputs
def tokenize_and_prepare_labels_for_biencoder(self, example):
if self.shuffle_labels:
random.shuffle(example["all_labels"])
def prepare_prompt(labels):
prompt_texts = []
for _label in labels:
label_tag = "<<LABEL>>"
prompt_texts.append(label_tag)
prompt_texts.append("<<SEP>>")
return "".join(prompt_texts)
input_text = example["text"]
class_texts = example["all_labels"]
if self.architecture_type == "bi-encoder-fused":
prompt = prepare_prompt(class_texts)
if self.prompt_first:
input_text = f"{prompt} {input_text}"
else:
input_text = f"{input_text} {prompt}"
tokenized_inputs = self.tokenize(input_text)
tokenized_classes = self.tokenize_labels(class_texts)
tokenized_inputs["class_input_ids"] = torch.tensor(tokenized_classes["input_ids"])
tokenized_inputs["class_attention_mask"] = torch.tensor(tokenized_classes["attention_mask"])
label2idx = {label: idx for idx, label in enumerate(example["all_labels"])}
tokenized_inputs["labels_mask"] = torch.ones(len(class_texts))
tokenized_inputs["labels"] = self.prepare_labels(example, label2idx, self.problem_type)
return tokenized_inputs
def __len__(self):
return len(self._data)
def __getitem__(self, idx):
example = self._data[idx]
example = self.augmenter.augment(example)
if self.architecture_type == "uni-encoder":
model_inputs = self.tokenize_and_prepare_labels_for_uniencoder(example)
elif self.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
model_inputs = self.tokenize_and_prepare_labels_for_encoder_decoder(example)
elif self.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
model_inputs = self.tokenize_and_prepare_labels_for_biencoder(example)
else:
raise NotImplementedError("This architecture type is not implemented.")
return model_inputs
def pad_2d_tensor(key_data):
"""
Pad a list of 2D tensors to have the same size along both dimensions.
:param key_data: List of 2D tensors to pad.
:return: Tensor of padded tensors stacked along a new batch dimension.
"""
if not key_data:
raise ValueError("The input list 'key_data' should not be empty.")
# Determine the maximum size along both dimensions
max_rows = max(tensor.shape[0] for tensor in key_data)
max_cols = max(tensor.shape[1] for tensor in key_data)
tensors = []
for tensor in key_data:
rows, cols = tensor.shape
row_padding = max_rows - rows
col_padding = max_cols - cols
# Pad the tensor along both dimensions
padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode="constant", value=0)
tensors.append(padded_tensor)
# Stack the tensors into a single tensor along a new batch dimension
padded_tensors = torch.stack(tensors)
return padded_tensors
class DataCollatorWithPadding:
def __init__(self, device="cuda:0", config=None):
self.device = device
self._max_labels_alloc = getattr(config, "max_labels_alloc", "dynamic") if config is not None else "dynamic"
def _resolve_max_num_classes(self, batch):
if self._max_labels_alloc == "dynamic":
first = batch[0]
if "labels_text" in first:
return max(len(item["labels_text"]) for item in batch)
if "labels_mask" in first:
return max(item["labels_mask"].shape[0] for item in batch)
first_labels = first.get("labels")
if isinstance(first_labels, torch.Tensor) and first_labels.dim() >= 1:
return max(item["labels"].shape[0] for item in batch)
return None
if isinstance(self._max_labels_alloc, int):
return self._max_labels_alloc
return None # 'fixed': model uses config.max_num_classes
def __call__(self, batch):
keys = batch[0].keys()
padded_batch = {key: [] for key in keys}
for key in keys:
key_data = [item[key] for item in batch]
if isinstance(key_data[0], torch.Tensor):
if key_data[0].dim() == 1:
padded_batch[key] = pad_sequence(key_data, batch_first=True)
elif key_data[0].dim() == 2:
padded_batch[key] = pad_2d_tensor(key_data)
elif isinstance(key_data[0], list):
data_el = "string"
if len(key_data[0]):
data_el = key_data[0][0]
if isinstance(data_el, str):
padded_batch[key] = key_data
else:
max_length = max(len(seq) for seq in key_data)
padded_batch[key] = torch.tensor([seq + [0] * (max_length - len(seq)) for seq in key_data])
elif type(key_data[0]) in {int, float}:
padded_batch[key] = torch.tensor(key_data)
elif isinstance(key_data[0], str):
padded_batch[key] = key_data
else:
raise TypeError(f"Unsupported data type: {type(key_data[0])}")
padded_batch["max_num_classes"] = self._resolve_max_num_classes(batch)
return padded_batch
================================================
FILE: gliclass/layers.py
================================================
# Copyright 2020 Microsoft and the Hugging Face Inc. team and Knowledgator.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from transformers.activations import ACT2FN
from .config import GLiClassModelConfig
class LstmSeq2SeqEncoder(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, bidirectional=False):
super().__init__()
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional,
batch_first=True,
)
def forward(self, x, mask, hidden=None):
# Packing the input sequence
lengths = mask.sum(dim=1).cpu()
packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
# Passing packed sequence through LSTM
packed_output, hidden = self.lstm(packed_x, hidden)
# Unpacking the output sequence
output, _ = pad_packed_sequence(packed_output, batch_first=True)
return output
class FeaturesProjector(nn.Module):
def __init__(self, config: GLiClassModelConfig):
super().__init__()
self.linear_1 = nn.Linear(config.encoder_config.hidden_size, config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.dropout = nn.Dropout(config.dropout)
self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)
def forward(self, features):
hidden_states = self.linear_1(features)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class BiEncoderProjector(nn.Module):
def __init__(self, config: GLiClassModelConfig):
super().__init__()
self.linear_1 = nn.Linear(config.label_model_config.hidden_size, config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)
def forward(self, features):
hidden_states = self.linear_1(features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
class DropoutContext:
def __init__(self):
self.dropout = 0
self.mask = None
self.scale = 1
self.reuse_mask = True
# Copied from transformers.models.deberta.modeling_deberta.get_mask
def get_mask(input, local_context):
if not isinstance(local_context, DropoutContext):
dropout = local_context
mask = None
else:
dropout = local_context.dropout
dropout *= local_context.scale
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext) and local_context.mask is None:
local_context.mask = mask
return mask, dropout
# Copied from transformers.models.deberta.modeling_deberta.XDropout
class XDropout(torch.autograd.Function):
"""Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
@staticmethod
def forward(ctx, input, local_ctx):
mask, dropout = get_mask(input, local_ctx)
ctx.scale = 1.0 / (1 - dropout)
if dropout > 0:
ctx.save_for_backward(mask)
return input.masked_fill(mask, 0) * ctx.scale
else:
return input
@staticmethod
def backward(ctx, grad_output):
if ctx.scale > 1:
(mask,) = ctx.saved_tensors
return grad_output.masked_fill(mask, 0) * ctx.scale, None
else:
return grad_output, None
@staticmethod
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: float | DropoutContext) -> torch._C.Value:
from torch.onnx import symbolic_opset12
dropout_p = local_ctx
if isinstance(local_ctx, DropoutContext):
dropout_p = local_ctx.dropout
# StableDropout only calls this function when training.
train = True
# TODO: We should check if the opset_version being used to export
# is > 12 here, but there's no good way to do that. As-is, if the
# opset_version < 12, export will fail with a CheckerError.
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
# if opset_version < 12:
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
return symbolic_opset12.dropout(g, input, dropout_p, train)
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
class StableDropout(nn.Module):
"""
Optimized dropout module for stabilizing the training.
Args:
drop_prob (float): the dropout probabilities
"""
def __init__(self, drop_prob):
super().__init__()
self.drop_prob = drop_prob
self.count = 0
self.context_stack = None
def forward(self, x):
"""
Call the module.
Args:
x (`torch.tensor`): The input tensor to apply dropout
"""
if self.training and self.drop_prob > 0:
return XDropout.apply(x, self.get_context())
return x
def clear_context(self):
self.count = 0
self.context_stack = None
def init_context(self, reuse_mask=True, scale=1):
if self.context_stack is None:
self.context_stack = []
self.count = 0
for c in self.context_stack:
c.reuse_mask = reuse_mask
c.scale = scale
def get_context(self):
if self.context_stack is not None:
if self.count >= len(self.context_stack):
self.context_stack.append(DropoutContext())
ctx = self.context_stack[self.count]
ctx.dropout = self.drop_prob
self.count += 1
return ctx
else:
return self.drop_prob
class SelfAttentionBlock(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
return self.norm(x + self.dropout(attn_output))
class CrossAttentionBlock(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
attn_output, _ = self.cross_attn(query, key, value, attn_mask=mask)
return self.norm(query + self.dropout(attn_output))
class Fuser(nn.Module):
def __init__(self, d_model, num_heads, num_layers, dropout=0.1):
super().__init__()
self.d_model = d_model
self.layers = nn.ModuleList(
[
nn.ModuleList(
[SelfAttentionBlock(d_model, num_heads, dropout), CrossAttentionBlock(d_model, num_heads, dropout)]
)
for _ in range(num_layers)
]
)
self.fc = nn.Linear(d_model, d_model)
def forward(self, query, key, query_mask=None, key_mask=None):
if query_mask is not None and key_mask is not None:
self_attn_mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2)
cross_attn_mask = query_mask.unsqueeze(-1) * key_mask.unsqueeze(1)
else:
self_attn_mask = None
cross_attn_mask = None
value = self.fc(key)
for self_attn, cross_attn in self.layers:
query = self_attn(query, mask=self_attn_mask)
query = cross_attn(query, key, value, mask=cross_attn_mask)
return query
class LayerwiseAttention(nn.Module):
def __init__(self, num_layers, hidden_size, output_size=None):
super().__init__()
self.num_layers = num_layers
self.hidden_size = hidden_size
self.output_size = output_size if output_size is not None else hidden_size
# Squeeze operation
self.squeeze = nn.Linear(hidden_size, 1)
# Excitation operation
self.W1 = nn.Linear(num_layers, num_layers // 2)
self.W2 = nn.Linear(num_layers // 2, num_layers)
# Final projection
self.output_projection = nn.Linear(self.hidden_size, self.output_size)
def forward(self, encoder_outputs):
# encoder_outputs is a list of tensors, each of shape [B, L, D]
_B, _L, _D = encoder_outputs[0].shape
# Concatenate all layers
U = torch.stack(encoder_outputs, dim=1) # [B, K, L, D]
# Squeeze operation
Z = self.squeeze(U).squeeze(-1) # [B, K, L]
Z = Z.mean(dim=2) # [B, K]
# Excitation operation
s = self.W2(F.relu(self.W1(Z))) # [B, K]
s = torch.sigmoid(s) # [B, K]
# Apply attention weights
U_weighted = U * s.unsqueeze(-1).unsqueeze(-1) # [B, K, L, D]
# Sum across layers
U_sum = U_weighted.sum(dim=1) # [B, L, D]
# Final projection
output = self.output_projection(U_sum) # [B, L, output_size]
return output
================================================
FILE: gliclass/loss_functions.py
================================================
import torch
import torch.nn.functional as F
def sequence_contrastive_loss(embeddings, mask):
# embeddings shape: (B, L, D)
# mask shape: (B, L)
B, L, _D = embeddings.shape
# Normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=-1)
# Compute similarity matrix
sim_matrix = torch.matmul(embeddings, embeddings.transpose(1, 2)) # / self.temperature
# Create labels for cross entropy (diagonal indices)
labels = torch.arange(L, device=embeddings.device).unsqueeze(0).expand(B, -1)
# Compute loss for each element in the batch
loss = F.cross_entropy(sim_matrix.reshape(B * L, L), labels.reshape(-1), reduction="none")
# Apply mask to loss
loss = loss.view(B, L) * mask
# Compute mean loss over non-padded elements
loss = loss.sum() / mask.sum()
return loss
def focal_loss_with_logits(
inputs: torch.Tensor,
targets: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2,
reduction: str = "none",
label_smoothing: float = 0.0,
ignore_index: int = -100, # default value for ignored index
) -> torch.Tensor:
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs (Tensor): A float tensor of arbitrary shape.
The predictions for each example.
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha (float): Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default: ``0.25``.
gamma (float): Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples. Default: ``2``.
reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``.
label_smoothing (float): Specifies the amount of smoothing when computing the loss,
where 0.0 means no smoothing.
ignore_index (int): Specifies a target value that is ignored and does not contribute
to the input gradient. Default: ``-100``.
Returns:
Loss tensor with the reduction option applied.
"""
# Create a mask to ignore specified index
valid_mask = targets != ignore_index
# Apply label smoothing if needed
if label_smoothing != 0:
with torch.no_grad():
targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing
# Apply sigmoid activation to inputs
p = torch.sigmoid(inputs)
# Compute the binary cross-entropy loss without reduction
loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# Apply the valid mask to the loss
loss = loss * valid_mask
# Apply focal loss modulation if gamma is greater than 0
if gamma > 0:
p_t = p * targets + (1 - p) * (1 - targets)
loss = loss * ((1 - p_t) ** gamma)
# Apply alpha weighting if alpha is specified
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
# Apply reduction method
if reduction == "none":
return loss
elif reduction == "mean":
return loss.sum() / valid_mask.sum() # Normalize by the number of valid (non-ignored) elements
elif reduction == "sum":
return loss.sum()
else:
raise ValueError(
f"Invalid value for argument 'reduction': '{reduction}'. Supported reduction modes: 'none', 'mean', 'sum'"
)
================================================
FILE: gliclass/model.py
================================================
import os
import warnings
from typing import Tuple
from pathlib import Path
from dataclasses import dataclass
import torch
import transformers
from torch import nn
from packaging import version
from transformers import AutoModel, AutoConfig, PreTrainedModel
from transformers.utils import logging
from transformers.modeling_outputs import SequenceClassifierOutput
# Import initialization module (transformers 5.0+) or fallback to torch.nn.init
try:
from transformers import initialization as init
except ImportError:
# transformers < 5.0 doesn't have this module, use torch.nn.init instead
from torch.nn import init
from .utils import MissedPackageException, is_module_available
from .config import GLiClassModelConfig
from .layers import FeaturesProjector, BiEncoderProjector, LayerwiseAttention, LstmSeq2SeqEncoder
from .scorers import SCORER2OBJECT
from .poolings import POOLING2OBJECT
from .loss_functions import focal_loss_with_logits, sequence_contrastive_loss
IS_LLM2VEC = is_module_available("llm2vec")
IS_PEFT = is_module_available("peft")
IS_TURBOT5 = is_module_available("turbot5")
IS_FLASHDEBERTA = is_module_available("flashdeberta")
logger = logging.get_logger(__name__)
if IS_LLM2VEC:
from llm2vec.models import GemmaBiModel, LlamaBiModel, Qwen2BiModel, MistralBiModel
DECODER_MODEL_MAPPING = {
"MistralConfig": MistralBiModel,
"LlamaConfig": LlamaBiModel,
"GemmaConfig": GemmaBiModel,
"Qwen2Config": Qwen2BiModel,
}
else:
DECODER_MODEL_MAPPING = {}
if IS_TURBOT5:
from turbot5.model.modeling import T5EncoderModel as FlashT5EncoderModel
from transformers import T5EncoderModel, UMT5EncoderModel
if IS_FLASHDEBERTA:
from flashdeberta import FlashDebertaV2Model
from transformers import DebertaV2Model
if IS_PEFT:
from peft import LoraConfig, get_peft_model
@dataclass
class GLiClassOutput(SequenceClassifierOutput):
text_embeddings: torch.Tensor | None = None
class_embeddings: torch.Tensor | None = None
class GLiClassPreTrainedModel(PreTrainedModel):
config_class = GLiClassModelConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_supports_sdpa = False
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
def _initialize_weights(self, module, is_remote_code: bool = False):
"""
Initialize weights if not already initialized.
This method is called by transformers 5.0+ during post_init().
It uses the _is_hf_initialized flag to prevent reinitializing weights
that were already loaded from a checkpoint.
For transformers 4.x, this method is not called, maintaining backward compatibility.
"""
if getattr(module, "_is_hf_initialized", False):
return
self._init_weights(module)
module._is_hf_initialized = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.encoder_config.initializer_range
)
if hasattr(module, "class_embedding"):
init.normal_(module.class_embedding, mean=0.0, std=std)
if hasattr(module, "segment_embeddings"):
init.normal_(module.segment_embeddings.weight, mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
init.normal_(module.weight, mean=0.0, std=std)
if module.padding_idx is not None:
init.zeros_(module.weight[module.padding_idx])
elif isinstance(module, nn.LSTM):
for name, param in module.named_parameters():
if "weight_ih" in name or "weight_hh" in name:
init.normal_(param, mean=0.0, std=std)
elif "bias" in name:
init.zeros_(param)
class GLiClassBaseModel(nn.Module): # ):
def __init__(self, config: GLiClassModelConfig, device="cpu", **kwargs):
super().__init__()
self.config = config
self.text_projector = FeaturesProjector(config)
self.classes_projector = FeaturesProjector(config)
if config.pooling_strategy not in POOLING2OBJECT:
raise NotImplementedError(f"{config.pooling_strategy} is not implemented pooling type.")
else:
self.pooler = POOLING2OBJECT[config.pooling_strategy]()
if config.pooling_strategy not in POOLING2OBJECT:
raise NotImplementedError(
f"{config.scorer_type} is not implemented. Choose one of this: 'dot', 'weighted-dot'"
)
else:
self.scorer = SCORER2OBJECT[config.scorer_type](
config.hidden_size,
num_heads=config.scorer_num_heads,
scorer_mlp_hidden_size=config.scorer_mlp_hidden_size,
attn_dropout=config.scorer_attn_dropout,
)
if config.use_lstm:
self.lstm = LstmSeq2SeqEncoder(config.hidden_size, config.hidden_size // 2, bidirectional=True)
if config.squeeze_layers:
self.layer_wise_attention = LayerwiseAttention(
config.encoder_config.num_hidden_layers, config.encoder_config.hidden_size
)
drop_out = getattr(config, "dropout", 0.0)
# self.dropout = StableDropout(drop_out)
self.dropout = nn.Dropout(drop_out)
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.epsilon = 1e-8
self.vocab_size = config.vocab_size
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.num_labels = -1
self.device = torch.device(device)
def _extract_class_features(self, token_embeds, input_ids, attention_mask, max_num_classes=None):
batch_size, _sequence_length, embed_dim = token_embeds.shape
class_token_mask = input_ids == self.config.class_token_index
num_class_tokens = torch.sum(class_token_mask, dim=-1, keepdim=True)
# max_num_classes from caller (CPU int) avoids GPU→CPU sync via .item()
max_embed_dim = max_num_classes if max_num_classes is not None else self.config.max_num_classes
# Get class token pooling method from config (default to "first" for backward compatibility)
class_token_pooling = getattr(self.config, "class_token_pooling", "first")
if class_token_pooling == "average":
# Average all tokens belonging to each class label
classes_embedding, classes_embedding_mask = self._extract_class_features_averaged(
token_embeds,
input_ids,
attention_mask,
class_token_mask,
num_class_tokens,
max_embed_dim,
batch_size,
embed_dim,
)
else:
# Original behavior: use only the class token (or token after it)
classes_embedding, classes_embedding_mask = self._extract_class_features_first(
token_embeds,
input_ids,
attention_mask,
class_token_mask,
num_class_tokens,
max_embed_dim,
batch_size,
embed_dim,
)
# Text features extraction
if self.config.extract_text_features:
text_token_mask = input_ids == self.config.text_token_index
text_token_indices = text_token_mask.int().argmax(dim=-1) # (batch,)
max_text_length = input_ids.shape[-1] # static, no GPU→CPU sync
# (batch, max_text_length): source position in token_embeds for each target slot
aranged_target_idx = (
torch.arange(max_text_length, device=token_embeds.device).unsqueeze(0).expand(batch_size, -1)
)
valid_mask = aranged_target_idx < (input_ids.shape[-1] - text_token_indices).unsqueeze(1)
source_indices = (text_token_indices.unsqueeze(1) + aranged_target_idx).clamp(max=input_ids.shape[-1] - 1)
batch_arange = torch.arange(batch_size, device=token_embeds.device).unsqueeze(1)
# Gather then zero-out invalid positions — no nonzero/scatter needed
text_tokens_embeddings = token_embeds[batch_arange, source_indices] * valid_mask.unsqueeze(-1).to(
token_embeds.dtype
)
text_tokens_mask = attention_mask[batch_arange, source_indices] * valid_mask
else:
text_tokens_embeddings = token_embeds
text_tokens_mask = attention_mask
return classes_embedding, classes_embedding_mask, text_tokens_embeddings, text_tokens_mask
def _extract_class_features_first(
self,
token_embeds,
input_ids,
attention_mask,
class_token_mask,
num_class_tokens,
max_embed_dim,
batch_size,
embed_dim,
):
"""Extract only the class token embedding (or token after it). Fully vectorized."""
class_cum = class_token_mask.long().cumsum(dim=-1) # (batch, seq)
k_range = torch.arange(max_embed_dim, device=token_embeds.device).view(1, -1, 1)
# select_mask[b, k, s] = True at the position of the k-th class token
select_mask = class_token_mask.unsqueeze(1) & ((class_cum.unsqueeze(1) - 1) == k_range)
if not self.config.embed_class_token:
# Shift right by 1: select the token immediately after each class token
shifted = torch.zeros_like(select_mask)
shifted[:, :, 1:] = select_mask[:, :, :-1]
select_mask = shifted
classes_embedding = torch.einsum("bks,bsd->bkd", select_mask.to(token_embeds.dtype), token_embeds)
arange_k = torch.arange(max_embed_dim, device=token_embeds.device).unsqueeze(0)
classes_embedding_mask = (arange_k < num_class_tokens).to(attention_mask.dtype)
return classes_embedding, classes_embedding_mask
def _extract_class_features_averaged(
self,
token_embeds,
input_ids,
attention_mask,
class_token_mask,
num_class_tokens,
max_embed_dim,
batch_size,
embed_dim,
):
"""Average all tokens belonging to each class label. Fully vectorized."""
# class_cum[b, s] = cumulative count of class tokens up to position s
class_cum = class_token_mask.long().cumsum(dim=-1) # (batch, seq)
if self.config.extract_text_features:
text_token_mask = input_ids == self.config.text_token_index
else:
text_token_mask = torch.zeros_like(class_token_mask)
# text_cum[b, s] >= 1 at and after the text token → use as exclusion boundary
text_cum = text_token_mask.long().cumsum(dim=-1) # (batch, seq)
# span_mask[b, k, s] = True if token s belongs to the span of class k
k_range = torch.arange(max_embed_dim, device=token_embeds.device).view(1, -1, 1)
span_mask = (
(class_cum.unsqueeze(1) == (k_range + 1)) # in the span of class k
& (text_cum.unsqueeze(1) == 0) # before the text boundary
& attention_mask.unsqueeze(1).bool() # real token (not padding)
)
if not self.config.embed_class_token:
span_mask = span_mask & ~class_token_mask.unsqueeze(1)
span_float = span_mask.to(token_embeds.dtype) # (batch, max_embed_dim, seq)
class_counts = span_float.sum(dim=-1, keepdim=True).clamp(min=1)
classes_embedding = torch.einsum("bks,bsd->bkd", span_float, token_embeds) / class_counts
arange_k = torch.arange(max_embed_dim, device=token_embeds.device).unsqueeze(0)
classes_embedding_mask = (arange_k < num_class_tokens).to(attention_mask.dtype)
return classes_embedding, classes_embedding_mask
def get_loss(self, logits, labels, classes_embedding=None, classes_embedding_mask=None):
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
# regression task
loss_fn = nn.MSELoss()
logits = logits.view(-1).to(labels.dtype)
loss = loss_fn(logits, labels.view(-1))
elif labels.dim() == 1 or labels.size(-1) == 1:
label_index = (labels >= 0).nonzero()
labels = labels.long()
if label_index.size(0) > 0:
labeled_logits = torch.gather(
logits, 0, label_index.expand(label_index.size(0), logits.size(1))
)
labels = torch.gather(labels, 0, label_index.view(-1))
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
else:
loss = torch.tensor(0).to(logits)
else:
log_softmax = nn.LogSoftmax(-1)
loss = -((log_softmax(logits) * labels).sum(-1)).mean()
elif self.config.problem_type == "regression":
loss_fct = nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
all_losses = focal_loss_with_logits(
logits,
labels,
self.config.focal_loss_alpha,
self.config.focal_loss_gamma,
self.config.focal_loss_reduction,
)
if classes_embedding_mask is not None:
all_losses = all_losses * classes_embedding_mask.float()
loss = all_losses.mean()
if self.config.contrastive_loss_coef > 0 and classes_embedding is not None:
contrastive_loss = sequence_contrastive_loss(classes_embedding, classes_embedding_mask)
loss = loss + contrastive_loss * self.config.contrastive_loss_coef
return loss
class GLiClassUniEncoder(GLiClassBaseModel):
def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
super().__init__(config)
if config.encoder_config is None:
if config.encoder_model_name is None:
raise ValueError("You need to specify encoder model name to use it as a backbone.")
config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
config_name = config.encoder_config.__class__.__name__
model_kwargs = {}
if config_name in DECODER_MODEL_MAPPING:
if not IS_LLM2VEC:
raise MissedPackageException(
f"The llm2vec package must be installed to use this decoder model: {config_name}"
)
else:
print("Loading decoder model using LLM2Vec...")
ModelClass = DECODER_MODEL_MAPPING[config_name]
decoder = True
elif config_name in {"T5Config", "MT5Config", "UMT5Config"}:
decoder = False
turbot5_type = os.environ.get("TURBOT5_ATTN_TYPE", "")
if turbot5_type and IS_TURBOT5:
ModelClass = FlashT5EncoderModel
model_kwargs = {"attention_type": turbot5_type}
elif config_name == "UMT5Config":
ModelClass = UMT5EncoderModel
else:
ModelClass = T5EncoderModel
elif config_name in {"DebertaV2Config"}:
decoder = False
if os.environ.get("USE_FLASHDEBERTA", "") and IS_FLASHDEBERTA:
print("Using FlashDeberta backend.")
ModelClass = FlashDebertaV2Model
else:
ModelClass = DebertaV2Model
else:
decoder = False
ModelClass = AutoModel
if from_pretrained:
self.encoder_model = ModelClass.from_pretrained(config.encoder_model_name, **model_kwargs)
elif decoder:
self.encoder_model = ModelClass(config.encoder_config)
elif config_name in {"T5Config", "MT5Config", "UMT5Config", "DebertaV2Config"}:
self.encoder_model = ModelClass._from_config(config.encoder_config)
else:
self.encoder_model = ModelClass.from_config(config.encoder_config)
if config.vocab_size is not None and hasattr(self.encoder_model, "resize_token_embeddings"):
current_vocab = self.encoder_model.config.vocab_size
if current_vocab != config.vocab_size:
self.encoder_model.resize_token_embeddings(config.vocab_size)
adapter_config_file = Path(config.encoder_model_name) / "adapter_config.json"
if adapter_config_file.exists():
if not IS_PEFT:
warnings.warn(
"Adapter configs were detected, if you want to apply them you need to install peft package.",
stacklevel=2,
)
else:
adapter_config = LoraConfig.from_pretrained(config.encoder_model_name)
self.encoder_model = get_peft_model(self.encoder_model, adapter_config)
if config.use_segment_embeddings:
self.segment_embeddings = nn.Embedding(3, config.encoder_config.hidden_size)
nn.init.normal_(self.segment_embeddings.weight, mean=0.0, std=config.initializer_range)
def _create_segment_ids(self, input_ids):
batch_size, _seq_length = input_ids.shape
segment_ids = torch.zeros_like(input_ids) # Default: segment 0 (labels)
# Find example token positions
example_token_mask = input_ids == self.config.example_token_index
example_token_indices = example_token_mask.int().argmin(dim=-1)
has_example = example_token_mask.any(dim=-1)
text_token_mask = input_ids == self.config.text_token_index
text_token_indices = text_token_mask.int().argmax(dim=-1)
for batch_idx in range(batch_size):
text_start = text_token_indices[batch_idx].item()
# If examples exist, assign segment 1 to example section
if has_example[batch_idx]:
example_start = example_token_indices[batch_idx].item()
segment_ids[batch_idx, text_start:example_start] = 1
segment_ids[batch_idx, example_start:] = 2
else:
segment_ids[batch_idx, text_start:] = 1
return segment_ids
def process_encoder_output(self, input_ids, attention_mask, encoder_layer, labels=None, max_num_classes=None):
classes_embedding, classes_embedding_mask, text_token_embeddings, text_mask = self._extract_class_features(
encoder_layer, input_ids, attention_mask, max_num_classes
)
if self.config.use_lstm:
text_token_embeddings = self.lstm(text_token_embeddings, text_mask)
pooled_output = self.pooler(text_token_embeddings)
pooled_output = self.text_projector(pooled_output)
pooled_output = self.dropout(pooled_output)
if self.config.normalize_features:
pooled_output = pooled_output / (pooled_output.norm(p=2, dim=-1, keepdim=True) + self.epsilon)
classes_embedding = self.classes_projector(classes_embedding)
if self.config.normalize_features:
classes_embedding = classes_embedding / (classes_embedding.norm(p=2, dim=-1, keepdim=True) + self.epsilon)
logits = self.scorer(pooled_output, classes_embedding, text_mask=text_mask)
if self.config.normalize_features:
logits = logits * self.logit_scale.to(classes_embedding.device)
loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)
return (logits, loss, pooled_output, classes_embedding)
def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
output_text_embeddings: bool | None = None,
output_class_embeddings: bool | None = None,
return_dict: bool | None = None,
max_num_classes: int | None = None,
**kwargs,
) -> Tuple | GLiClassOutput:
r"""
Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.squeeze_layers or self.config.layer_wise:
output_hidden_states = True
return_dict = True
if self.config.use_segment_embeddings:
embedding_layer = self.encoder_model.get_input_embeddings()
token_embeds = embedding_layer(input_ids)
segment_ids = self._create_segment_ids(input_ids)
segment_embeds = self.segment_embeddings(segment_ids)
inputs_embeds = token_embeds + segment_embeds
outputs = self.encoder_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
else:
outputs = self.encoder_model(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
if self.config.layer_wise and labels is not None:
hidden_states = outputs.hidden_states
loss = 0
for encoder_layer in hidden_states:
logits, layer_loss, pooled_output, classes_embedding = self.process_encoder_output(
input_ids, attention_mask, encoder_layer, labels, max_num_classes
)
loss += layer_loss
else:
if self.config.encoder_layer_id == -1:
if self.config.squeeze_layers:
encoder_layer = self.layer_wise_attention(outputs.hidden_states)
else:
encoder_layer = outputs[0]
else:
encoder_layer = outputs.hidden_states[self.config.encoder_layer_id]
logits, loss, pooled_output, classes_embedding = self.process_encoder_output(
input_ids, attention_mask, encoder_layer, labels, max_num_classes
)
if not return_dict:
output = (logits, *outputs[1:])
return ((loss, *output)) if loss is not None else output
return GLiClassOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
text_embeddings=pooled_output if output_text_embeddings else None,
class_embeddings=classes_embedding if output_class_embeddings else None,
)
class GLiClassEncoderDecoder(GLiClassBaseModel):
def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
super().__init__(config)
if config.encoder_config is None:
if config.encoder_model_name is None:
raise ValueError("You need to specify encoder model name to use it as a backbone.")
config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
if not config.encoder_config.is_encoder_decoder:
raise ValueError("You need to choose encoder-decoder model as a backbone.")
if from_pretrained:
self.encoder_decoder_model = AutoModel.from_pretrained(config.encoder_model_name)
else:
self.encoder_decoder_model = AutoModel.from_config(config.encoder_config)
@staticmethod
def _make_bidirectional_4d_mask(attention_mask_2d, dtype):
"""Convert a 2D padding mask into a 4D bidirectional attention mask.
When a 4D mask is passed to the decoder, the model uses it as-is
without applying its default causal pattern, enabling bidirectional
self-attention in the decoder.
Args:
attention_mask_2d: (batch_size, seq_length) with 1 for real tokens, 0 for padding.
dtype: The dtype of the model (needed for the min-value fill).
Returns:
4D mask of shape (batch_size, 1, seq_length, seq_length).
Values are 0.0 for attended positions and a large negative value for masked positions.
"""
batch_size, seq_length = attention_mask_2d.shape
# (batch_size, 1, 1, seq_length) - masks out padding columns
padding_mask = (1.0 - attention_mask_2d.to(dtype))[:, None, None, :] * torch.finfo(dtype).min
return padding_mask.expand(batch_size, 1, seq_length, seq_length)
def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
class_input_ids: torch.Tensor | None = None,
class_attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
output_text_embeddings: bool | None = None,
output_class_embeddings: bool | None = None,
return_dict: bool | None = True,
**kwargs,
) -> Tuple | SequenceClassifierOutput:
r"""
Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Build a 4D bidirectional mask for the decoder so it attends to
# all non-padding positions instead of using causal masking.
decoder_4d_mask = None
if class_attention_mask is not None:
model_dtype = next(self.encoder_decoder_model.parameters()).dtype
decoder_4d_mask = self._make_bidirectional_4d_mask(class_attention_mask, model_dtype)
outputs = self.encoder_decoder_model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=class_input_ids,
decoder_attention_mask=decoder_4d_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
text_token_embeddings = outputs.encoder_last_hidden_state
decoder_token_embeddings = outputs.last_hidden_state
classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(
decoder_token_embeddings, class_input_ids, class_attention_mask
)
if self.config.use_lstm:
text_token_embeddings = self.lstm(text_token_embeddings, attention_mask)
pooled_output = self.pooler(text_token_embeddings)
pooled_output = self.text_projector(pooled_output)
pooled_output = self.dropout(pooled_output)
if self.config.normalize_features:
pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)
classes_embedding = self.classes_projector(classes_embedding)
if self.config.normalize_features:
classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)
logits = self.scorer(pooled_output, classes_embedding)
if self.config.normalize_features:
logits = logits * self.logit_scale.to(classes_embedding.device)
loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)
if not return_dict:
output = (logits, *outputs[1:])
return ((loss, *output)) if loss is not None else output
return GLiClassOutput(
loss=loss,
logits=logits,
hidden_states=outputs.decoder_hidden_states,
attentions=outputs.decoder_attentions,
text_embeddings=pooled_output if output_text_embeddings else None,
class_embeddings=classes_embedding if output_class_embeddings else None,
)
class GLiClassEncoderDecoderCLS(GLiClassBaseModel):
"""Encoder-decoder architecture where labels go to the encoder and text goes to the decoder.
Class features are extracted from encoder output using _extract_class_features().
Text features are extracted from the last non-padding token of the decoder output.
"""
def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
super().__init__(config)
if config.encoder_config is None:
if config.encoder_model_name is None:
raise ValueError("You need to specify encoder model name to use it as a backbone.")
config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
if not config.encoder_config.is_encoder_decoder:
raise ValueError("You need to choose encoder-decoder model as a backbone.")
if from_pretrained:
self.encoder_decoder_model = AutoModel.from_pretrained(config.encoder_model_name)
else:
self.encoder_decoder_model = AutoModel.from_config(config.encoder_config)
def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
class_input_ids: torch.Tensor | None = None,
class_attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
output_text_embeddings: bool | None = None,
output_class_embeddings: bool | None = None,
return_dict: bool | None = True,
**kwargs,
) -> Tuple | SequenceClassifierOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Labels → encoder, Text → decoder
outputs = self.encoder_decoder_model(
input_ids=class_input_ids,
attention_mask=class_attention_mask,
decoder_input_ids=input_ids,
decoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
# Class features from encoder output
encoder_token_embeddings = outputs.encoder_last_hidden_state
classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(
encoder_token_embeddings, class_input_ids, class_attention_mask
)
# Text features from decoder's last non-padding token
decoder_output = outputs.last_hidden_state
batch_size = decoder_output.shape[0]
last_non_pad_idx = attention_mask.sum(dim=1) - 1
pooled_output = decoder_output[torch.arange(batch_size, device=decoder_output.device), last_non_pad_idx]
pooled_output = self.text_projector(pooled_output)
pooled_output = self.dropout(pooled_output)
if self.config.normalize_features:
pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)
classes_embedding = self.classes_projector(classes_embedding)
if self.config.normalize_features:
classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)
logits = self.scorer(pooled_output, classes_embedding)
if self.config.normalize_features:
logits = logits * self.logit_scale.to(classes_embedding.device)
loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)
if not return_dict:
output = (logits, *outputs[1:])
return ((loss, *output)) if loss is not None else output
return GLiClassOutput(
loss=loss,
logits=logits,
hidden_states=outputs.decoder_hidden_states,
attentions=outputs.decoder_attentions,
text_embeddings=pooled_output if output_text_embeddings else None,
class_embeddings=classes_embedding if output_class_embeddings else None,
)
class GLiClassBiEncoder(GLiClassBaseModel):
def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
super().__init__(config)
if config.encoder_config is None:
if config.encoder_model_name is None:
raise ValueError("You need to specify encoder model name to use it as a backbone.")
config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
if config.label_model_config is None:
if config.label_model_name is None:
raise ValueError("You need to specify label model name to use it as a backbone.")
config.label_model_config = AutoConfig.from_pretrained(config.label_model_name)
def initialize_encoder(configs, model_name, from_pretrained):
if from_pretrained:
return AutoModel.from_pretrained(model_name)
else:
return AutoModel.from_config(configs)
self.encoder_model = initialize_encoder(config.encoder_config, config.encoder_model_name, from_pretrained)
self.label_encoder = initialize_encoder(config.label_model_config, config.label_model_name, from_pretrained)
self.biencoder_projector = BiEncoderProjector(config)
def pool_outputs(self, encoder_outputs):
text_embeddings = self.pooler(encoder_outputs[0])
text_embeddings = self.text_projector(text_embeddings)
text_embeddings = self.dropout(text_embeddings)
if self.config.normalize_features:
text_embeddings = nn.functional.normalize(text_embeddings, p=2, dim=-1, eps=self.epsilon)
return text_embeddings
def encode_text(self, input_ids, attention_mask):
outputs = self.encoder_model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))
text_embeddings = self.pool_outputs(outputs)
return text_embeddings
def encode_classes(self, class_input_ids, class_attention_mask, labels_mask=None):
batch_size = class_input_ids.shape[0]
num_classes = class_input_ids.shape[1]
if labels_mask is not None:
batch_indices, indices = torch.where(labels_mask == 1)
selected_input_ids = class_input_ids[batch_indices, indices]
selected_attention_mask = class_attention_mask[batch_indices, indices]
outputs = self.label_encoder(selected_input_ids, attention_mask=selected_attention_mask)
class_embeddings_filtered = self.pooler(outputs[0])
class_embeddings = torch.zeros(
batch_size,
num_classes,
class_embeddings_filtered.shape[-1],
dtype=class_embeddings_filtered.dtype,
device=class_embeddings_filtered.device,
)
class_embeddings[batch_indices, indices] = class_embeddings_filtered
else:
class_input_ids = class_input_ids.view(-1, class_input_ids.shape[-1])
class_attention_mask = class_attention_mask.view(-1, class_input_ids.shape[-1])
outputs = self.label_encoder(class_input_ids, attention_mask=class_attention_mask)
class_embeddings = self.pooler(outputs[0])
class_embeddings = class_embeddings.reshape(batch_size, num_classes, -1)
class_embeddings = self.biencoder_projector(class_embeddings)
class_embeddings = self.classes_projector(class_embeddings)
if self.config.normalize_features:
class_embeddings = nn.functional.normalize(class_embeddings, p=2, dim=-1, eps=self.epsilon)
return class_embeddings
def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
class_input_ids: torch.Tensor | None = None,
class_attention_mask: torch.Tensor | None = None,
labels_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
output_text_embeddings: bool | None = None,
output_class_embeddings: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> Tuple | SequenceClassifierOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
text_embeddings = self.encode_text(input_ids, attention_mask)
class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)
logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)
if labels_mask is not None:
logits = torch.where(labels_mask == 0, -1e3, logits)
loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)
if not return_dict:
output = (logits,)
return ((loss, *output)) if loss is not None else output
return GLiClassOutput(
loss=loss,
logits=logits,
text_embeddings=text_embeddings if output_text_embeddings else None,
class_embeddings=class_embeddings if output_class_embeddings else None,
)
class GLiClassBiEncoderFused(GLiClassBiEncoder):
def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
super().__init__(config, from_pretrained)
def encode_text(self, input_ids, attention_mask, class_embeddings, labels_mask):
embedding_layer = self.encoder_model.get_input_embeddings()
inputs_embeds = embedding_layer(input_ids)
class_token_mask = input_ids == self.config.class_token_index
batch_indices, class_token_indices = torch.where(class_token_mask)
labels_batch_indices, labels_indices = torch.where(labels_mask == 1)
selected_class_embeddings = class_embeddings[labels_batch_indices, labels_indices]
inputs_embeds[batch_indices, class_token_indices] = selected_class_embeddings
encoder_outputs = self.encoder_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask.squeeze(1))
post_class_embeddings = torch.zeros_like(class_embeddings)
post_class_embeddings[labels_batch_indices, labels_indices] = encoder_outputs[0][
batch_indices, class_token_indices
]
return encoder_outputs, post_class_embeddings
def forward(
self,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
class_input_ids: torch.Tensor | None = None,
class_attention_mask: torch.Tensor | None = None,
labels_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
output_text_embeddings: bool | None = None,
output_class_embeddings: bool | None = None,
return_dict: bool | None = None,
**kwargs,
) -> Tuple | SequenceClassifierOutput:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
raw_class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)
encoder_outputs, class_embeddings = self.encode_text(
input_ids, attention_mask, raw_class_embeddings, labels_mask
)
text_embeddings = self.pool_outputs(encoder_outputs)
logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)
if labels_mask is not None:
logits = torch.where(labels_mask == 0, -1e3, logits)
loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)
if not return_dict:
output = (logits,)
return ((loss, *output)) if loss is not None else output
return GLiClassOutput(
loss=loss,
logits=logits,
text_embeddings=text_embeddings if output_text_embeddings else None,
class_embeddings=class_embeddings if output_class_embeddings else None,
)
class GLiClassModel(GLiClassPreTrainedModel):
def __init__(self, config, from_pretrained=False):
super().__init__(config)
if config.architecture_type == "uni-encoder":
self.model = GLiClassUniEncoder(config, from_pretrained)
elif config.architecture_type == "bi-encoder":
self.model = GLiClassBiEncoder(config, from_pretrained)
elif config.architecture_type == "bi-encoder-fused":
self.model = GLiClassBiEncoderFused(config, from_pretrained)
elif config.architecture_type == "encoder-decoder":
self.model = GLiClassEncoderDecoder(config, from_pretrained)
elif config.architecture_type == "encoder-decoder-cls":
self.model = GLiClassEncoderDecoderCLS(config, from_pretrained)
self.post_init()
def get_input_embeddings(self):
if self.config.architecture_type in {"uni-encoder"}:
return self.model.encoder_model.get_input_embeddings()
elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
return self.model.encoder_decoder_model.get_input_embeddings()
else:
raise NotImplementedError("Getting input embeddings is not implemented for bi-encoder architecture")
def set_input_embeddings(self, value):
if self.config.architecture_type in {"uni-encoder"}:
self.model.encoder_model.set_input_embeddings(value)
return None
elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
self.model.encoder_decoder_model.set_input_embeddings(value)
elif self.config.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
self.model.encoder_model.set_input_embeddings(value)
else:
raise NotImplementedError("Setting input embeddings is not implemented for bi-encoder architecture")
def tie_weights(self, recompute_mapping=True, missing_keys=None):
"""
Tie model weights for architectures that share parameters.
This method handles:
- Version compatibility between transformers v4 and v5
- Different GLiClass architecture types
- Special handling for T5/MT5 models in transformers v5+ where encoder.embed_tokens
may be incorrectly initialized instead of being tied to shared.weight
Args:
recompute_mapping: Whether to recompute weight mapping (transformers v5+)
missing_keys: Keys that are missing from checkpoint (transformers v5+)
"""
# Get encoder model based on architecture type
encoder_model = None
if self.config.architecture_type in {"uni-encoder"}:
encoder_model = self.model.encoder_model
elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
encoder_model = self.model.encoder_decoder_model
elif self.config.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
encoder_model = self.model.encoder_model
else:
raise NotImplementedError("Tie weights is not implemented for this architecture type")
# Call base tie_weights with version-appropriate parameters
if version.parse(transformers.__version__) >= version.parse("5.0.0"):
result = encoder_model.tie_weights(recompute_mapping=recompute_mapping, missing_keys=missing_keys)
else:
result = encoder_model.tie_weights()
# Fix for T5/MT5/UMT5 models in transformers v5+
# In v5, if encoder.embed_tokens.weight is missing from checkpoint, it gets randomly
# initialized instead of being tied to shared.weight. We explicitly ensure proper tying.
if (
encoder_model is not None
and hasattr(encoder_model, "shared")
and hasattr(encoder_model, "encoder")
and hasattr(encoder_model.encoder, "embed_tokens")
):
shared_weight = encoder_model.shared.weight
embed_weight = encoder_model.encoder.embed_tokens.weight
# Only tie if they're not already the same tensor
if shared_weight is not embed_weight:
encoder_model.encoder.embed_tokens.weight = shared_weight
if version.parse(transformers.__version__) >= version.parse("5.0.0"):
logger.info(
"Applied transformers v5 compatibility fix: tied encoder.embed_tokens.weight "
"to shared.weight for T5-based model"
)
return result
def resize_token_embeddings(self, new_num_tokens: int | None = None, pad_to_multiple_of=None) -> nn.Embedding:
if self.config.architecture_type in {"uni-encoder"}:
model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
elif self.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
model_embeds = self.model.encoder_decoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
elif self.config.architecture_type in {"bi-encoder-fused"}:
model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
else:
raise NotImplementedError("Resizing is not implemented for bi-encoder architecture")
self.config.encoder_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def forward(self, *args, **kwargs):
outputs = self.model(*args, **kwargs)
return outputs
================================================
FILE: gliclass/ops.py
================================================
import torch
import torch.nn.functional as F
# ─── Attention (padded) ───────────────────────────────────────────────────────
def attn_padded(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_padding_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
) -> torch.Tensor:
"""
Padded attention via F.scaled_dot_product_attention.
Uses FlashAttention backend automatically on CUDA when available.
Args:
q: [batch, nq, nheads, head_dim]
k: [batch, nk, nheads, head_dim]
v: [batch, nk, nheads, head_dim]
key_padding_mask: [batch, nk] bool, True = real token
Returns:
[batch, nq, nheads, head_dim]
"""
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_mask = None
if key_padding_mask is not None:
attn_mask = key_padding_mask[:, None, None, :].bool()
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,
)
return out.transpose(1, 2) # [batch, nq, nheads, head_dim]
================================================
FILE: gliclass/pipeline.py
================================================
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from .model import GLiClassModel, GLiClassBiEncoder
from .utils import retrieval_augmented_text
def flatten_hierarchical_labels(
labels: List[str] | Dict[str, Any], prefix: str = "", separator: str = "."
) -> List[str]:
"""
Flatten hierarchical labels into dot notation.
Supports arbitrary nesting depth. Examples:
Input: {"sentiment": ["positive", "negative", "neutral"], "topic": ["product", "service", "shipping"]}
Output: ["sentiment.positive", "sentiment.negative", "sentiment.neutral",
"topic.product", "topic.service", "topic.shipping"]
Input: {
"category": {
"electronics": ["phone", "laptop"],
"clothing": ["shirt", "pants"]
}
}
Output: [
"category.electronics.phone",
"category.electronics.laptop",
"category.clothing.shirt",
"category.clothing.pants"
]
Input: ["label1", "label2"] # Already flat
Output: ["label1", "label2"]
Args:
labels: Either a list of string labels or a hierarchical dict
prefix: Current prefix for recursion (internal use)
separator: Separator to use between hierarchy levels (default: ".")
Returns:
List of flattened label strings with dot notation
"""
if isinstance(labels, list):
if prefix:
return [f"{prefix}{separator}{label}" for label in labels]
return labels
elif isinstance(labels, dict):
flattened = []
for key, value in labels.items():
new_prefix = f"{prefix}{separator}{key}" if prefix else key
flattened.extend(flatten_hierarchical_labels(value, new_prefix, separator))
return flattened
elif isinstance(labels, str):
if prefix:
return [f"{prefix}{separator}{labels}"]
return [labels]
else:
raise ValueError(f"Unsupported label type: {type(labels)}. Expected list, dict, or str.")
def build_hierarchical_output(
predictions: List[Dict[str, float]],
original_labels: List[str] | Dict[str, Any],
separator: str = ".",
all_scores: Dict[str, float] | None = None,
) -> Dict[str, float] | Dict[str, Any]:
"""
Build hierarchical output structure matching the input labels structure.
Args:
predictions: List of prediction dicts with 'label' and 'score'
original_labels: Original hierarchical labels structure
separator: Separator used in flattened labels
all_scores: Optional dict of all label scores (for complete output)
Returns:
Hierarchical structure with scores matching the input format
Example:
Input predictions: [
{'label': 'sentiment.positive', 'score': 0.85},
{'label': 'topic.product', 'score': 0.72}
]
Input original_labels: {
"sentiment": ["positive", "negative", "neutral"],
"topic": ["product", "service", "shipping"]
}
Output: {
"sentiment": {"positive": 0.85, "negative": 0.0, "neutral": 0.0},
"topic": {"product": 0.72, "service": 0.0, "shipping": 0.0}
}
"""
score_lookup = {pred["label"]: pred["score"] for pred in predictions}
if all_scores:
for k, v in all_scores.items():
if k not in score_lookup:
score_lookup[k] = v
def _build_recursive(structure: List[str] | Dict[str, Any], prefix: str = "") -> Dict[str, float] | Dict[str, Any]:
if isinstance(structure, list):
result = {}
for label in structure:
full_label = f"{prefix}{separator}{label}" if prefix else label
result[label] = score_lookup.get(full_label, 0.0)
return result
elif isinstance(structure, dict):
result = {}
for key, value in structure.items():
new_prefix = f"{prefix}{separator}{key}" if prefix else key
result[key] = _build_recursive(value, new_prefix)
return result
elif isinstance(structure, str):
full_label = f"{prefix}{separator}{structure}" if prefix else structure
return {structure: score_lookup.get(full_label, 0.0)}
return {}
if isinstance(original_labels, list):
return {label: score_lookup.get(label, 0.0) for label in original_labels}
return _build_recursive(original_labels)
def format_examples_prompt(
examples: List[Dict[str, Any]], example_token: str = "<<EXAMPLE>>", sep_token: str = "<<SEP>>"
) -> str:
r"""
Format few-shot examples into a prompt string using <<EXAMPLE>> token.
Format matches training: <<EXAMPLE>>text \nLabels:\n label1, label2
with a single <<SEP>> after all examples.
Args:
examples: List of example dicts with 'text' and 'labels'/'true_labels' keys
example_token: Token to mark examples (default: "<<EXAMPLE>>")
sep_token: Separator token after all examples (default: "<<SEP>>")
Returns:
Formatted examples string
"""
if not examples:
return ""
formatted_parts = []
for example in examples:
text = example.get("text", "")
labels = example.get("labels", example.get("true_labels", []))
if isinstance(labels, list):
labels_str = ", ".join(labels)
else:
labels_str = str(labels)
# Match training format: " \nLabels:\n " instead of "\nLabels: "
formatted_parts.append(f"{example_token}{text} \nLabels:\n {labels_str}")
# Add single SEP token after all examples (matching training)
formatted_parts.append(sep_token)
return "".join(formatted_parts)
class BaseZeroShotClassificationPipeline(ABC):
def __init__(
self,
model,
tokenizer,
max_classes=25,
max_length=1024,
classification_type="multi-label",
device="cuda:0",
progress_bar=True,
label_separator: str = ".",
):
self.model = model
if isinstance(tokenizer, str):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
else:
self.tokenizer = tokenizer
self.max_classes = max_classes
self.classification_type = classification_type
self.max_length = max_length
self.progress_bar = progress_bar
self.label_separator = label_separator
self._max_labels_alloc = getattr(model.config, "max_labels_alloc", "dynamic")
self.example_token = "<<EXAMPLE>>"
self.label_token = "<<LABEL>>"
self.sep_token = "<<SEP>>"
if not isinstance(device, torch.device):
if torch.cuda.is_available() and "cuda" in device:
self.device = torch.device(device)
else:
self.device = torch.device("cpu")
else:
self.device = device
if self.model.device != self.device:
self.model.to(self.device)
# Ensure model is in evaluation mode for inference
self.model.eval()
def _normalize_classification_type(self, classification_type: str | None) -> str:
if classification_type is None:
return self.classification_type
normalized = classification_type.strip().lower()
if normalized in {"single", "single-label", "single_label"}:
return "single-label"
if normalized in {"multi", "multi-label", "multi_label"}:
return "multi-label"
raise ValueError("Unsupported classification type: choose 'single-label' or 'multi-label'")
def _normalize_texts(self, texts: str | List[str]) -> List[str]:
if isinstance(texts, str):
return [texts]
return texts
def _normalize_thresholds(self, threshold: float | List[float], num_texts: int) -> List[float]:
if isinstance(threshold, list):
if len(threshold) != num_texts:
raise ValueError("Length of threshold list must match number of texts.")
return threshold
return [threshold] * num_texts
def _normalize_classification_types(
self,
classification_type: str | List[str] | None,
num_texts: int,
) -> List[str]:
if isinstance(classification_type, list):
if len(classification_type) != num_texts:
raise ValueError("Length of classification_type list must match number of texts.")
return [self._normalize_classification_type(item) for item in classification_type]
normalized = self._normalize_classification_type(classification_type)
return [normalized] * num_texts
def _process_labels(
self, labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]]
) -> List[str] | List[List[str]]:
"""Process labels to handle hierarchical structures."""
if not labels:
return labels
if isinstance(labels, dict):
return flatten_hierarchical_labels(labels, separator=self.label_separator)
if isinstance(labels, list):
if len(labels) == 0:
return labels
first_elem = labels[0]
if isinstance(first_elem, str):
return labels
if isinstance(first_elem, dict):
return [flatten_hierarchical_labels(lbl, separator=self.label_separator) for lbl in labels]
if isinstance(first_elem, list):
if first_elem and isinstance(first_elem[0], dict):
return [flatten_hierarchical_labels(lbl, separator=self.label_separator) for lbl in labels]
return labels
return labels
def _format_examples_for_input(self, examples: List[Dict[str, Any]] | None = None) -> str:
"""Format few-shot examples using <<EXAMPLE>> and <<SEP>> tokens."""
if not examples:
return ""
examples = [example for example in examples if example is not None]
if not examples:
return ""
return format_examples_prompt(examples, example_token=self.example_token, sep_token=self.sep_token)
def _examples_are_per_text(self, examples) -> bool:
"""Detect whether examples are provided per text rather than shared."""
if not isinstance(examples, list) or len(examples) == 0:
return False
if all(isinstance(example, dict) for example in examples):
return False
return all(example is None or isinstance(example, list) for example in examples)
def _get_text_examples(self, examples, index: int):
"""Get examples for a single text from shared or per-text input."""
if not examples:
return None
if self._examples_are_per_text(examples):
return examples[index] if index < len(examples) else None
return examples
def _format_prompt(self, prompt: str | List[str] | None = None, index: int = 0) -> str:
"""Format the task description prompt."""
if prompt is None:
return ""
if isinstance(prompt, str):
return prompt
if isinstance(prompt, list):
if index < len(prompt):
return prompt[index]
return prompt[0] if prompt else ""
return ""
def _resolve_max_num_classes(self, batch_labels, same_labels: bool):
if self._max_labels_alloc == "dynamic":
return len(batch_labels) if same_labels else max(len(labels) for labels in batch_labels)
if isinstance(self._max_labels_alloc, int):
return self._max_labels_alloc
return None # 'fixed': model uses config.max_num_classes
@abstractmethod
def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
pass
def _get_batch_examples(self, examples, start_idx, batch_size):
"""Get examples for current batch."""
if not examples:
return None
if self._examples_are_per_text(examples):
return examples[start_idx : start_idx + batch_size]
return examples
def _get_batch_prompt(self, prompt, start_idx, batch_size):
"""Get prompt for current batch."""
if not prompt:
return None
if isinstance(prompt, list):
return prompt[start_idx : start_idx + batch_size]
return prompt
@torch.no_grad()
def get_embeddings(self, texts, labels, batch_size=8, examples=None, prompt=None):
if isinstance(texts, str):
texts = [texts]
labels = self._process_labels(labels)
if isinstance(labels[0], str):
same_labels = True
else:
same_labels = False
results = []
iterable = range(0, len(texts), batch_size)
if self.progress_bar:
iterable = tqdm(iterable)
for idx in iterable:
batch_texts = texts[idx : idx + batch_size]
batch_examples = self._get_batch_examples(examples, idx, len(batch_texts))
batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))
tokenized_inputs = self.prepare_inputs(
batch_texts, labels, same_labels, examples=batch_examples, prompt=batch_prompt
)
max_num_classes = self._resolve_max_num_classes(labels, same_labels)
model_output = self.model(
**tokenized_inputs,
max_num_classes=max_num_classes,
output_text_embeddings=True,
output_class_embeddings=True,
)
logits = model_output.logits
text_embeddings = model_output.text_embeddings
class_embeddings = model_output.class_embeddings
batch_size_actual = logits.shape[0]
for i in range(batch_size_actual):
result = {
"logits": logits[i].cpu().numpy(),
"text_embedding": text_embeddings[i].cpu().numpy(),
"class_embeddings": class_embeddings[i].cpu().numpy(),
}
results.append(result)
return results
@torch.no_grad()
def __call__(
self,
texts: str | List[str],
labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]],
threshold: float | List[float] = 0.5,
batch_size: int = 8,
classification_type: str | List[str] | None = None,
rac_examples: List | None = None,
examples: List[Dict[str, Any]] | None = None,
prompt: str | List[str] | None = None,
return_hierarchical: bool = False,
):
"""
Perform zero-shot classification.
Args:
texts: Single text or list of texts to classify
labels: Labels in various formats (flat list or hierarchical dict)
threshold: Classification threshold for multi-label, either one
value for all texts or one value per text
batch_size: Batch size for processing
classification_type: Override classification mode globally or per text.
If None, uses the pipeline's configured classification_type
rac_examples: Retrieval augmented examples (legacy)
examples: Few-shot examples with 'text' and 'labels'/'true_labels' keys
prompt: Task description - string (same for all) or list (per-text)
return_hierarchical: If True, return hierarchical structure with all scores
Returns:
List of classification results or hierarchical dict structure.
"""
original_labels = labels
texts = self._normalize_texts(texts)
thresholds = self._normalize_thresholds(threshold, len(texts))
classification_types = self._normalize_classification_types(classification_type, len(texts))
if rac_examples:
if len(texts) == 1 and not isinstance(rac_examples[0], list):
texts = [retrieval_augmented_text(texts[0], rac_examples)]
else:
texts = [retrieval_augmented_text(text, ex) for text, ex in zip(texts, rac_examples)]
processed_labels = self._process_labels(labels)
if isinstance(processed_labels[0], str):
same_labels = True
else:
same_labels = False
results = []
all_scores_list = []
iterable = range(0, len(texts), batch_size)
if self.progress_bar:
iterable = tqdm(iterable)
for idx in iterable:
batch_texts = texts[idx : idx + batch_size]
if not same_labels:
batch_labels = processed_labels[idx : idx + batch_size]
else:
batch_labels = processed_labels
batch_examples = self._get_batch_examples(examples, idx, len(batch_texts))
batch_prompt = self._get_batch_prompt(prompt, idx, len(batch_texts))
tokenized_inputs = self.prepare_inputs(
batch_texts, batch_labels, same_labels, examples=batch_examples, prompt=batch_prompt
)
max_num_classes = self._resolve_max_num_classes(batch_labels, same_labels)
model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)
logits = model_output.logits
probs = torch.sigmoid(logits)
for i in range(len(batch_texts)):
global_idx = idx + i
item_classification_type = classification_types[global_idx]
item_threshold = thresholds[global_idx]
if same_labels:
curr_labels = batch_labels
else:
curr_labels = batch_labels[i]
if item_classification_type == "single-label":
score = torch.softmax(logits[i][: len(curr_labels)], dim=-1)
if return_hierarchical:
all_scores = {curr_labels[j]: score[j].item() for j in range(len(curr_labels))}
all_scores_list.append(all_scores)
pred_label = curr_labels[torch.argmax(score).item()]
results.append([{"label": pred_label, "score": score.max().item()}])
elif item_classification_type == "multi-label":
text_results = []
if return_hierarchical:
all_scores = {curr_labels[j]: probs[i][j].item() for j in range(len(curr_labels))}
all_scores_list.append(all_scores)
for j, prob in enumerate(probs[i][: len(curr_labels)]):
score = prob.item()
if score >= item_threshold:
text_results.append({"label": curr_labels[j], "score": score})
results.append(text_results)
else:
raise ValueError("Unsupported classification type: choose 'single-label' or 'multi-label'")
if return_hierarchical:
hierarchical_results = []
for i, (result, all_scores) in enumerate(zip(results, all_scores_list)):
if same_labels:
orig_lbl = original_labels
else:
orig_lbl = original_labels[i] if i < len(original_labels) else original_labels
hierarchical_results.append(
build_hierarchical_output(result, orig_lbl, self.label_separator, all_scores)
)
return hierarchical_results
return results
class UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
def __init__(
self,
model,
tokenizer,
max_classes=25,
max_length=1024,
classification_type="multi-label",
device="cuda:0",
progress_bar=True,
label_separator: str = ".",
):
super().__init__(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
def prepare_input(self, text, labels, examples=None, prompt=None):
"""
Prepare input matching training format from data_processing.py:
Order: Labels → SEP → Prompt → Text → Examples.
"""
input_parts = []
# 1. Add labels
for label in labels:
label_tag = f"{self.label_token}{label}"
input_parts.append(label_tag)
input_parts.append(self.sep_token)
# 2. Add task description prompt
if prompt:
input_parts.append(prompt)
# 3. Format examples to go after text
examples_str = ""
if examples:
examples_str = self._format_examples_for_input(examples)
if self.model.config.prompt_first:
return "".join(input_parts) + text + examples_str
else:
return text + "".join(input_parts) + examples_str
def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
inputs = []
if same_labels:
for i, text in enumerate(texts):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))
else:
for i, (text, labels_) in enumerate(zip(texts, labels)):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))
tokenized_inputs = self.tokenizer(
inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
).to(self.device)
return tokenized_inputs
class EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
def __init__(
self,
model,
tokenizer,
max_classes=25,
max_length=1024,
classification_type="multi-label",
device="cuda:0",
progress_bar=True,
label_separator: str = ".",
):
super().__init__(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
def prepare_labels_prompt(self, labels, prompt=None):
"""Match training format: Labels → SEP → Prompt."""
input_parts = []
for label in labels:
# label_tag = f"{label}{self.label_token}"
label_tag = f"{self.label_token}{label}"
input_parts.append(label_tag)
input_parts.append(self.sep_token)
if prompt:
input_parts.append(prompt)
return "".join(input_parts)
def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
prompts = []
processed_texts = []
if same_labels:
for i, text in enumerate(texts):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
prompts.append(self.prepare_labels_prompt(labels, text_prompt))
examples_str = self._format_examples_for_input(text_examples) if text_examples else ""
processed_texts.append(text + examples_str)
else:
for i, labels_ in enumerate(labels):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
prompts.append(self.prepare_labels_prompt(labels_, text_prompt))
examples_str = self._format_examples_for_input(text_examples) if text_examples else ""
processed_texts.append(texts[i] + examples_str)
tokenized_inputs = self.tokenizer(
processed_texts, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
).to(self.device)
tokenized_classes = self.tokenizer(
prompts, max_length=self.max_length, truncation=True, padding="longest", return_tensors="pt"
).to(self.device)
tokenized_inputs["class_input_ids"] = tokenized_classes["input_ids"]
tokenized_inputs["class_attention_mask"] = tokenized_classes["attention_mask"]
return tokenized_inputs
class BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
def __init__(
self,
model,
tokenizer,
max_classes=25,
max_length=1024,
classification_type="multi-label",
device="cuda:0",
progress_bar=True,
label_separator: str = ".",
):
super().__init__(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
self.labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)
def prepare_input(self, text, labels, examples=None, prompt=None):
input_parts = []
if prompt:
input_parts.append(prompt)
input_parts.append(" ")
for _label in labels:
input_parts.append(self.label_token)
input_parts.append(self.sep_token)
examples_str = ""
if examples:
examples_str = self._format_examples_for_input(examples)
if self.model.config.prompt_first:
return "".join(input_parts) + text + examples_str
else:
return text + "".join(input_parts) + examples_str
def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
if self.model.config.architecture_type == "bi-encoder-fused":
inputs = []
if same_labels:
for i, text in enumerate(texts):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))
else:
for i, (text, labels_) in enumerate(zip(texts, labels)):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))
else:
inputs = []
for i, text in enumerate(texts):
text_prompt = self._format_prompt(prompt, i)
if text_prompt:
inputs.append(f"{text_prompt} {text}")
else:
inputs.append(text)
if same_labels:
tokenized_inputs = self.tokenizer(
inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
).to(self.device)
tokenized_labels = self.labels_tokenizer(
labels, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
).to(self.device)
tokenized_inputs["class_input_ids"] = tokenized_labels["input_ids"].expand(len(texts), -1, -1)
tokenized_inputs["class_attention_mask"] = tokenized_labels["attention_mask"].expand(len(texts), -1, -1)
labels_mask = [[1 for _ in range(len(labels))] for _ in range(len(texts))]
tokenized_inputs["labels_mask"] = torch.tensor(labels_mask).to(self.device)
else:
tokenized_inputs = self.tokenizer(
inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
).to(self.device)
class_input_ids = []
class_attention_mask = []
for labels_set in labels:
tokenized_labels = self.labels_tokenizer(
labels_set, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt"
).to(self.device)
class_input_ids.append(tokenized_labels["input_ids"])
class_attention_mask.append(tokenized_labels["attention_mask"])
tokenized_inputs["class_input_ids"] = torch.stack(class_input_ids)
tokenized_inputs["class_attention_mask"] = torch.stack(class_attention_mask)
labels_mask = [[1 for _ in range(len(labels[j]))] for j in range(len(texts))]
tokenized_inputs["labels_mask"] = torch.tensor(labels_mask).to(self.device)
return tokenized_inputs
class ZeroShotClassificationPipeline:
"""
Main pipeline class for zero-shot classification with GLiClass models.
Supports:
- Hierarchical labels with dot notation (e.g., {"sentiment": ["positive", "negative"]})
- Few-shot examples with <<EXAMPLE>> token
- Task description prompts
- Hierarchical output format matching input structure
Example usage:
```python
from gliclass import ZeroShotClassificationPipeline
pipeline = ZeroShotClassificationPipeline(model, tokenizer)
# === Hierarchical Labels for Review Classification ===
hierarchical_labels = {
"sentiment": ["positive", "negative", "neutral"],
"topic": ["product", "service", "shipping"],
}
# Basic classification
results = pipeline("The product quality is amazing but delivery was slow", hierarchical_labels)
# Results: [
# {'label': 'sentiment.positive', 'score': 0.89},
# {'label': 'topic.product', 'score': 0.92},
# {'label': 'topic.shipping', 'score': 0.76}
# ]
# === With Task Description Prompt ===
results = pipeline(
"The product quality is amazing but delivery was slow",
hierarchical_labels,
prompt="Classify this customer review by sentiment and topic:",
)
# === With Few-Shot Examples (uses <<EXAMPLE>> token) ===
examples = [
{"text": "Love this item, great quality!", "labels": ["sentiment.positive", "topic.product"]},
{"text": "Customer support was unhelpful and rude", "labels": ["sentiment.negative", "topic.service"]},
{"text": "Package arrived damaged after 2 weeks", "labels": ["sentiment.negative", "topic.shipping"]},
]
results = pipeline(
"Fast delivery and the item works perfectly!",
hierarchical_labels,
examples=examples,
prompt="Classify customer feedback:",
)
# === Hierarchical Output (matches input structure) ===
results = pipeline(
"The product quality is amazing but delivery was slow", hierarchical_labels, return_hierarchical=True
)
# Returns:
# {
# "sentiment": {
# "positive": 0.89,
# "negative": 0.05,
# "neutral": 0.12
# },
# "topic": {
# "product": 0.92,
# "service": 0.15,
# "shipping": 0.76
# }
# }
# === Per-Text Prompts ===
results = pipeline(
["Electronics review text", "Clothing review text"],
hierarchical_labels,
prompt=["Analyze this electronics review:", "Analyze this clothing review:"],
)
```
"""
def __init__(
self,
model,
tokenizer,
max_classes: int = 25,
max_length: int = 1024,
classification_type: str = "multi-label",
device: str = "cuda:0",
progress_bar: bool = True,
label_separator: str = ".",
):
"""
Initialize the classification pipeline.
Args:
model: GLiClass model or path to model
tokenizer: Tokenizer or path to tokenizer
max_classes: Maximum number of classes to process
max_length: Maximum sequence length
classification_type: 'single-label' or 'multi-label'
device: Device to run inference on
progress_bar: Whether to show progress bar
label_separator: Separator for hierarchical label notation (default: ".")
"""
if isinstance(model, str):
model = GLiClassBiEncoder.from_pretrained(model)
self.label_separator = label_separator
if model.config.architecture_type == "uni-encoder":
self.pipe = UniEncoderZeroShotClassificationPipeline(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
elif model.config.architecture_type in {"encoder-decoder", "encoder-decoder-cls"}:
self.pipe = EncoderDecoderZeroShotClassificationPipeline(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
elif model.config.architecture_type in {"bi-encoder", "bi-encoder-fused"}:
self.pipe = BiEncoderZeroShotClassificationPipeline(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
else:
raise NotImplementedError("This architecture is not implemented")
def flatten_labels(self, labels: List[str] | Dict[str, Any]) -> List[str]:
"""
Flatten hierarchical labels to dot notation.
Example:
>>> pipeline.flatten_labels(
... {"sentiment": ["positive", "negative", "neutral"], "topic": ["product", "service", "shipping"]}
... )
["sentiment.positive", "sentiment.negative", "sentiment.neutral",
"topic.product", "topic.service", "topic.shipping"]
"""
return flatten_hierarchical_labels(labels, separator=self.label_separator)
def get_embeddings(self, *args, **kwargs):
"""Get embeddings for texts and labels."""
return self.pipe.get_embeddings(*args, **kwargs)
def __call__(
self,
texts: str | List[str],
labels: List[str] | Dict[str, Any] | List[List[str]] | List[Dict[str, Any]],
threshold: float | List[float] = 0.5,
batch_size: int = 8,
classification_type: str | List[str] | None = None,
rac_examples: List | None = None,
examples: List[Dict[str, Any]] | None = None,
prompt: str | List[str] | None = None,
return_hierarchical: bool = False,
):
"""
Perform zero-shot classification.
Args:
texts: Single text or list of texts to classify
labels: Labels - flat list or hierarchical dict
Examples:
- ["positive", "negative"] - flat labels
- {"sentiment": ["positive", "negative"], "topic": ["product", "service"]}
threshold: Classification threshold for multi-label, either one
value for all texts or one value per text
batch_size: Batch size for processing
classification_type: Override classification mode globally or per text.
If None, uses the pipeline's configured classification_type
rac_examples: Retrieval augmented examples (legacy)
examples: Few-shot examples, each with 'text' and 'labels' keys
prompt: Task description - string or list of strings (per-text)
return_hierarchical: If True, return structure matching input labels
Returns:
List of predictions (flat) or hierarchical dicts with all scores
"""
return self.pipe(
texts,
labels,
threshold=threshold,
batch_size=batch_size,
classification_type=classification_type,
rac_examples=rac_examples,
examples=examples,
prompt=prompt,
return_hierarchical=return_hierarchical,
)
class ZeroShotClassificationWithChunkingPipeline(BaseZeroShotClassificationPipeline):
"""Pipeline with long text chunking support."""
def __init__(
self,
model,
tokenizer,
max_classes: int = 25,
max_length: int = 1024,
classification_type: str = "multi-label",
device: str = "cuda:0",
progress_bar: bool = True,
text_chunk_size: int = 8192,
text_chunk_overlap: int = 256,
labels_chunk_size: int = 8,
label_separator: str = ".",
):
if isinstance(model, str):
model = GLiClassModel.from_pretrained(model)
super().__init__(
model, tokenizer, max_classes, max_length, classification_type, device, progress_bar, label_separator
)
self.text_chunk_size = text_chunk_size
self.text_chunk_overlap = text_chunk_overlap
self.labels_chunk_size = labels_chunk_size
def chunk_text(self, text, chunk_size=None, overlap=None):
"""Split text into overlapping chunks."""
if chunk_size is None:
chunk_size = self.text_chunk_size
if overlap is None:
overlap = self.text_chunk_overlap
if len(text) <= chunk_size:
return [text]
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunk = text[start:end]
chunks.append(chunk)
if end >= len(text):
break
start = end - overlap
return chunks
def prepare_input(self, text, labels, examples=None, prompt=None):
"""
Prepare input matching training format from data_processing.py:
Order: Labels → SEP → Prompt → Text → Examples.
"""
input_parts = []
# 1. Add labels
for label in labels:
label_tag = f"{self.label_token}{label}"
input_parts.append(label_tag)
input_parts.append(self.sep_token)
# 2. Add task description prompt
if prompt:
input_parts.append(prompt)
# 3. Format examples to go after text
examples_str = ""
if examples:
examples_str = self._format_examples_for_input(examples)
if self.model.config.prompt_first:
return "".join(input_parts) + text + examples_str
else:
return text + "".join(input_parts) + examples_str
def prepare_inputs(self, texts, labels, same_labels=False, examples=None, prompt=None):
inputs = []
if same_labels:
for i, text in enumerate(texts):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
inputs.append(self.prepare_input(text, labels, text_examples, text_prompt))
else:
for i, (text, labels_) in enumerate(zip(texts, labels)):
text_examples = self._get_text_examples(examples, i)
text_prompt = self._format_prompt(prompt, i)
inputs.append(self.prepare_input(text, labels_, text_examples, text_prompt))
tokenized_inputs = self.tokenizer(
inputs, truncation=True, max_length=self.max_length, padding="longest", return_tensors="pt"
).to(self.device)
return tokenized_inputs
def aggregate_chunk_scores(self, chunk_scores: List[Dict[str, float]], labels: List[str]) -> Dict[str, float]:
"""Aggregate scores across text chunks using max pooling."""
aggregated = dict.fromkeys(labels, 0.0)
for scores in chunk_scores:
for label, score in scores.items():
aggregated[label] = max(aggregated[label], score)
return aggregated
@torch.no_grad()
def process_single_text(self, text, labels, threshold=0.5, examples=None, prompt=None):
"""Process a single long text through chunks."""
text_chunks = self.chunk_text(text)
all_chunk_scores = []
for text_chunk in text_chunks:
chunk_logits = []
all_labels = []
for labels_idx in range(0, len(labels), self.labels_chunk_size):
curr_labels = labels[labels_idx : labels_idx + self.labels_chunk_size]
if labels_idx == 0:
all_labels = []
all_labels.extend(curr_labels)
tokenized_inputs = self.prepare_inputs(
[text_chunk], curr_labels, same_labels=True, examples=examples, prompt=prompt
)
max_num_classes = self._resolve_max_num_classes(curr_labels, same_labels=True)
model_output = self.model(**tokenized_inputs, max_num_classes=max_num_classes)
logits = model_output.logits
chunk_logits.extend(logits[0][: len(curr_labels)].tolist())
text_logits = torch.tensor(chunk_logits)
if self.classification_type == "single-label":
scores = torch.softmax(text_logits, dim=-1)
else:
scores = torch.sigmoid(text_logits)
chunk_score_dict = {label: scores[i].item() for i, label in enumerate(all_labels)}
all_chunk_scores.append(chunk_score_dict)
aggregated_scores = self.aggregate_chunk_scores(all_chunk_scores, labels)
if self.classification_type == "single-label":
total = sum(aggregated_scores.values())
if total > 0:
aggregated_scores = {k: v / total for k, v in aggregated_scores.i
gitextract_2n3b6llp/ ├── .github/ │ └── workflows/ │ ├── release.yaml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── demo.py ├── gliclass/ │ ├── __init__.py │ ├── config.py │ ├── data_processing.py │ ├── layers.py │ ├── loss_functions.py │ ├── model.py │ ├── ops.py │ ├── pipeline.py │ ├── poolings.py │ ├── scorers.py │ ├── serve/ │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── client.py │ │ ├── config.py │ │ ├── memory.py │ │ └── server.py │ ├── training.py │ └── utils.py ├── notebooks/ │ └── finetuning.ipynb ├── pyproject.toml ├── serve_configs/ │ └── serve_config.yaml ├── test_gliclass.py ├── tests/ │ ├── test_data_processing.py │ ├── test_loss_functions.py │ ├── test_poolings.py │ ├── test_scorers.py │ └── test_utils.py ├── train.py └── train_rl.py
SYMBOL INDEX (409 symbols across 25 files)
FILE: demo.py
function parse_labels_input (line 204) | def parse_labels_input(labels_input: str) -> Union[List[str], Dict[str, ...
function parse_examples_input (line 226) | def parse_examples_input(examples_input: str) -> Optional[List[Dict[str,...
function format_output (line 257) | def format_output(
function format_as_json (line 277) | def format_as_json(results: Union[List[Dict], Dict], hierarchical: bool ...
function format_hierarchical_dict (line 297) | def format_hierarchical_dict(d: Dict, indent: int = 0) -> str:
function classification (line 313) | def classification(
function update_output_visibility (line 821) | def update_output_visibility(hierarchical: bool, fmt: str):
function classify_wrapper (line 842) | def classify_wrapper(text, labels, threshold, multi_label, prompt, examp...
FILE: gliclass/config.py
class GLiClassModelConfig (line 19) | class GLiClassModelConfig(PretrainedConfig):
method __init__ (line 23) | def __init__(
FILE: gliclass/data_processing.py
class AugmentationConfig (line 11) | class AugmentationConfig:
class DataAugmenter (line 26) | class DataAugmenter:
method __init__ (line 27) | def __init__(self, config, examples, labels, label2description=None):
method remove_labels (line 34) | def remove_labels(self, true_labels, all_labels):
method add_random_labels (line 42) | def add_random_labels(self, all_labels):
method add_random_text (line 51) | def add_random_text(self, text, all_labels):
method add_random_synonyms (line 66) | def add_random_synonyms(self, all_labels):
method add_random_descriptions (line 86) | def add_random_descriptions(self, item):
method add_random_examples (line 111) | def add_random_examples(self, item):
method augment (line 145) | def augment(self, item):
class GLiClassDataset (line 184) | class GLiClassDataset(Dataset):
method __init__ (line 185) | def __init__(
method get_diversity (line 222) | def get_diversity(self):
method collect_dataset_labels (line 225) | def collect_dataset_labels(self):
method prepare_labels (line 231) | def prepare_labels(self, example, label2idx, problem_type):
method prepare_prompt (line 243) | def prepare_prompt(self, item, label_token_first=True):
method format_examples (line 256) | def format_examples(self, item):
method tokenize (line 270) | def tokenize(self, texts):
method tokenize_labels (line 274) | def tokenize_labels(self, labels):
method tokenize_and_prepare_labels_for_uniencoder (line 278) | def tokenize_and_prepare_labels_for_uniencoder(self, example):
method tokenize_and_prepare_labels_for_encoder_decoder (line 295) | def tokenize_and_prepare_labels_for_encoder_decoder(self, example):
method tokenize_and_prepare_labels_for_biencoder (line 312) | def tokenize_and_prepare_labels_for_biencoder(self, example):
method __len__ (line 346) | def __len__(self):
method __getitem__ (line 349) | def __getitem__(self, idx):
function pad_2d_tensor (line 365) | def pad_2d_tensor(key_data):
class DataCollatorWithPadding (line 395) | class DataCollatorWithPadding:
method __init__ (line 396) | def __init__(self, device="cuda:0", config=None):
method _resolve_max_num_classes (line 400) | def _resolve_max_num_classes(self, batch):
method __call__ (line 415) | def __call__(self, batch):
FILE: gliclass/layers.py
class LstmSeq2SeqEncoder (line 23) | class LstmSeq2SeqEncoder(nn.Module):
method __init__ (line 24) | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0,...
method forward (line 35) | def forward(self, x, mask, hidden=None):
class FeaturesProjector (line 49) | class FeaturesProjector(nn.Module):
method __init__ (line 50) | def __init__(self, config: GLiClassModelConfig):
method forward (line 58) | def forward(self, features):
class BiEncoderProjector (line 66) | class BiEncoderProjector(nn.Module):
method __init__ (line 67) | def __init__(self, config: GLiClassModelConfig):
method forward (line 74) | def forward(self, features):
class DropoutContext (line 82) | class DropoutContext:
method __init__ (line 83) | def __init__(self):
function get_mask (line 91) | def get_mask(input, local_context):
class XDropout (line 110) | class XDropout(torch.autograd.Function):
method forward (line 114) | def forward(ctx, input, local_ctx):
method backward (line 124) | def backward(ctx, grad_output):
method symbolic (line 132) | def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: floa...
class StableDropout (line 150) | class StableDropout(nn.Module):
method __init__ (line 158) | def __init__(self, drop_prob):
method forward (line 164) | def forward(self, x):
method clear_context (line 175) | def clear_context(self):
method init_context (line 179) | def init_context(self, reuse_mask=True, scale=1):
method get_context (line 187) | def get_context(self):
class SelfAttentionBlock (line 199) | class SelfAttentionBlock(nn.Module):
method __init__ (line 200) | def __init__(self, d_model, num_heads, dropout=0.1):
method forward (line 206) | def forward(self, x, mask=None):
class CrossAttentionBlock (line 211) | class CrossAttentionBlock(nn.Module):
method __init__ (line 212) | def __init__(self, d_model, num_heads, dropout=0.1):
method forward (line 218) | def forward(self, query, key, value, mask=None):
class Fuser (line 223) | class Fuser(nn.Module):
method __init__ (line 224) | def __init__(self, d_model, num_heads, num_layers, dropout=0.1):
method forward (line 237) | def forward(self, query, key, query_mask=None, key_mask=None):
class LayerwiseAttention (line 254) | class LayerwiseAttention(nn.Module):
method __init__ (line 255) | def __init__(self, num_layers, hidden_size, output_size=None):
method forward (line 271) | def forward(self, encoder_outputs):
FILE: gliclass/loss_functions.py
function sequence_contrastive_loss (line 5) | def sequence_contrastive_loss(embeddings, mask):
function focal_loss_with_logits (line 31) | def focal_loss_with_logits(
FILE: gliclass/model.py
class GLiClassOutput (line 60) | class GLiClassOutput(SequenceClassifierOutput):
class GLiClassPreTrainedModel (line 65) | class GLiClassPreTrainedModel(PreTrainedModel):
method _initialize_weights (line 72) | def _initialize_weights(self, module, is_remote_code: bool = False):
method _init_weights (line 88) | def _init_weights(self, module):
class GLiClassBaseModel (line 117) | class GLiClassBaseModel(nn.Module): # ):
method __init__ (line 118) | def __init__(self, config: GLiClassModelConfig, device="cpu", **kwargs):
method _extract_class_features (line 162) | def _extract_class_features(self, token_embeds, input_ids, attention_m...
method _extract_class_features_first (line 224) | def _extract_class_features_first(
method _extract_class_features_averaged (line 255) | def _extract_class_features_averaged(
method get_loss (line 296) | def get_loss(self, logits, labels, classes_embedding=None, classes_emb...
class GLiClassUniEncoder (line 347) | class GLiClassUniEncoder(GLiClassBaseModel):
method __init__ (line 348) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
method _create_segment_ids (line 419) | def _create_segment_ids(self, input_ids):
method process_encoder_output (line 444) | def process_encoder_output(self, input_ids, attention_mask, encoder_la...
method forward (line 469) | def forward(
class GLiClassEncoderDecoder (line 556) | class GLiClassEncoderDecoder(GLiClassBaseModel):
method __init__ (line 557) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
method _make_bidirectional_4d_mask (line 573) | def _make_bidirectional_4d_mask(attention_mask_2d, dtype):
method forward (line 593) | def forward(
class GLiClassEncoderDecoderCLS (line 674) | class GLiClassEncoderDecoderCLS(GLiClassBaseModel):
method __init__ (line 681) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
method forward (line 696) | def forward(
class GLiClassBiEncoder (line 767) | class GLiClassBiEncoder(GLiClassBaseModel):
method __init__ (line 768) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
method pool_outputs (line 790) | def pool_outputs(self, encoder_outputs):
method encode_text (line 798) | def encode_text(self, input_ids, attention_mask):
method encode_classes (line 803) | def encode_classes(self, class_input_ids, class_attention_mask, labels...
method forward (line 835) | def forward(
class GLiClassBiEncoderFused (line 871) | class GLiClassBiEncoderFused(GLiClassBiEncoder):
method __init__ (line 872) | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
method encode_text (line 875) | def encode_text(self, input_ids, attention_mask, class_embeddings, lab...
method forward (line 895) | def forward(
class GLiClassModel (line 937) | class GLiClassModel(GLiClassPreTrainedModel):
method __init__ (line 938) | def __init__(self, config, from_pretrained=False):
method get_input_embeddings (line 952) | def get_input_embeddings(self):
method set_input_embeddings (line 960) | def set_input_embeddings(self, value):
method tie_weights (line 971) | def tie_weights(self, recompute_mapping=True, missing_keys=None):
method resize_token_embeddings (line 1025) | def resize_token_embeddings(self, new_num_tokens: int | None = None, p...
method forward (line 1039) | def forward(self, *args, **kwargs):
FILE: gliclass/ops.py
function attn_padded (line 7) | def attn_padded(
FILE: gliclass/pipeline.py
function flatten_hierarchical_labels (line 12) | def flatten_hierarchical_labels(
function build_hierarchical_output (line 69) | def build_hierarchical_output(
function format_examples_prompt (line 135) | def format_examples_prompt(
class BaseZeroShotClassificationPipeline (line 174) | class BaseZeroShotClassificationPipeline(ABC):
method __init__ (line 175) | def __init__(
method _normalize_classification_type (line 216) | def _normalize_classification_type(self, classification_type: str | No...
method _normalize_texts (line 227) | def _normalize_texts(self, texts: str | List[str]) -> List[str]:
method _normalize_thresholds (line 232) | def _normalize_thresholds(self, threshold: float | List[float], num_te...
method _normalize_classification_types (line 239) | def _normalize_classification_types(
method _process_labels (line 252) | def _process_labels(
method _format_examples_for_input (line 281) | def _format_examples_for_input(self, examples: List[Dict[str, Any]] | ...
method _examples_are_per_text (line 290) | def _examples_are_per_text(self, examples) -> bool:
method _get_text_examples (line 298) | def _get_text_examples(self, examples, index: int):
method _format_prompt (line 306) | def _format_prompt(self, prompt: str | List[str] | None = None, index:...
method _resolve_max_num_classes (line 321) | def _resolve_max_num_classes(self, batch_labels, same_labels: bool):
method prepare_inputs (line 329) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
method _get_batch_examples (line 332) | def _get_batch_examples(self, examples, start_idx, batch_size):
method _get_batch_prompt (line 340) | def _get_batch_prompt(self, prompt, start_idx, batch_size):
method get_embeddings (line 349) | def get_embeddings(self, texts, labels, batch_size=8, examples=None, p...
method __call__ (line 397) | def __call__(
class UniEncoderZeroShotClassificationPipeline (line 522) | class UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificatio...
method __init__ (line 523) | def __init__(
method prepare_input (line 538) | def prepare_input(self, text, labels, examples=None, prompt=None):
method prepare_inputs (line 565) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
class EncoderDecoderZeroShotClassificationPipeline (line 586) | class EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassific...
method __init__ (line 587) | def __init__(
method prepare_labels_prompt (line 602) | def prepare_labels_prompt(self, labels, prompt=None):
method prepare_inputs (line 617) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
class BiEncoderZeroShotClassificationPipeline (line 650) | class BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassification...
method __init__ (line 651) | def __init__(
method prepare_input (line 667) | def prepare_input(self, text, labels, examples=None, prompt=None):
method prepare_inputs (line 687) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
class ZeroShotClassificationPipeline (line 746) | class ZeroShotClassificationPipeline:
method __init__ (line 825) | def __init__(
method flatten_labels (line 869) | def flatten_labels(self, labels: List[str] | Dict[str, Any]) -> List[s...
method get_embeddings (line 882) | def get_embeddings(self, *args, **kwargs):
method __call__ (line 886) | def __call__(
class ZeroShotClassificationWithChunkingPipeline (line 933) | class ZeroShotClassificationWithChunkingPipeline(BaseZeroShotClassificat...
method __init__ (line 936) | def __init__(
method chunk_text (line 960) | def chunk_text(self, text, chunk_size=None, overlap=None):
method prepare_input (line 984) | def prepare_input(self, text, labels, examples=None, prompt=None):
method prepare_inputs (line 1011) | def prepare_inputs(self, texts, labels, same_labels=False, examples=No...
method aggregate_chunk_scores (line 1030) | def aggregate_chunk_scores(self, chunk_scores: List[Dict[str, float]],...
method process_single_text (line 1041) | def process_single_text(self, text, labels, threshold=0.5, examples=No...
method __call__ (line 1095) | def __call__(
function parse_hierarchical_prediction (line 1221) | def parse_hierarchical_prediction(prediction: str, separator: str = ".")...
function group_predictions_by_hierarchy (line 1230) | def group_predictions_by_hierarchy(
function get_best_per_category (line 1250) | def get_best_per_category(predictions: List[Dict[str, Any]], separator: ...
FILE: gliclass/poolings.py
class GlobalMaxPooling1D (line 5) | class GlobalMaxPooling1D(nn.Module):
method forward (line 8) | def forward(self, x: torch.Tensor):
class FirstTokenPooling1D (line 12) | class FirstTokenPooling1D(nn.Module):
method forward (line 15) | def forward(self, x: torch.Tensor):
class LastTokenPooling1D (line 19) | class LastTokenPooling1D(nn.Module):
method forward (line 22) | def forward(self, x: torch.Tensor):
class GlobalAvgPooling1D (line 26) | class GlobalAvgPooling1D(nn.Module):
method forward (line 29) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
class GlobalSumPooling1D (line 38) | class GlobalSumPooling1D(nn.Module):
method forward (line 41) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
class GlobalRMSPooling1D (line 47) | class GlobalRMSPooling1D(nn.Module):
method forward (line 50) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
class GlobalAbsMaxPooling1D (line 59) | class GlobalAbsMaxPooling1D(nn.Module):
method forward (line 62) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
class GlobalAbsAvgPooling1D (line 69) | class GlobalAbsAvgPooling1D(nn.Module):
method forward (line 72) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
class PassPooling1D (line 81) | class PassPooling1D(nn.Module):
method forward (line 84) | def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None...
FILE: gliclass/scorers.py
class ScorerWeightedDot (line 7) | class ScorerWeightedDot(nn.Module):
method __init__ (line 8) | def __init__(self, hidden_size, dropout=0.1, **kwargs):
method forward (line 21) | def forward(self, text_rep, label_rep, **kwargs):
class ScorerDot (line 42) | class ScorerDot(nn.Module):
method __init__ (line 43) | def __init__(self, *args, **kwargs):
method forward (line 47) | def forward(self, text_rep, label_rep, **kwargs):
class MLPScorer (line 53) | class MLPScorer(nn.Module):
method __init__ (line 54) | def __init__(self, hidden_size, mlp_hidden_size=256, **kwargs):
method forward (line 69) | def forward(self, text_rep, label_rep, **kwargs):
class HopfieldScorer (line 81) | class HopfieldScorer(nn.Module):
method __init__ (line 82) | def __init__(self, hidden_size, mlp_hidden_size=256, beta=4, num_itera...
method forward (line 101) | def forward(self, text_rep, label_rep, **kwargs):
class CrossAttnScorer (line 128) | class CrossAttnScorer(nn.Module):
method __init__ (line 129) | def __init__(self, hidden_size, num_heads=16, attn_dropout=0.1, scorer...
method forward (line 154) | def forward(self, text_rep, label_rep, text_mask=None, **kwargs):
FILE: gliclass/serve/__main__.py
function main (line 21) | def main():
FILE: gliclass/serve/client.py
class GLiClassClient (line 6) | class GLiClassClient:
method __init__ (line 9) | def __init__(self, url: str = "http://localhost:8000/gliclass"):
method __call__ (line 17) | def __call__(
method classify (line 55) | def classify(
method health_check (line 80) | def health_check(self) -> bool:
FILE: gliclass/serve/config.py
class GLiClassServeConfig (line 10) | class GLiClassServeConfig:
method __post_init__ (line 63) | def __post_init__(self):
method to_env_vars (line 68) | def to_env_vars(self) -> dict:
method from_yaml (line 76) | def from_yaml(cls, config_path: str | Path) -> "GLiClassServeConfig":
method to_yaml (line 90) | def to_yaml(self, config_path: str | Path) -> None:
method update (line 101) | def update(self, **kwargs) -> "GLiClassServeConfig":
FILE: gliclass/serve/memory.py
function _power_of_two_seq_lens (line 21) | def _power_of_two_seq_lens(max_seq_len: int, min_seq_len: int = 64) -> L...
class GLiClassMemoryEstimator (line 32) | class GLiClassMemoryEstimator:
method __init__ (line 35) | def __init__(
method measure_cuda_context (line 51) | def measure_cuda_context(self) -> None:
method measure_model_memory (line 61) | def measure_model_memory(self) -> None:
method available_memory (line 73) | def available_memory(self) -> int:
method calibrate (line 80) | def calibrate(
method _measure_peak (line 106) | def _measure_peak(
method _lookup_seq_len (line 124) | def _lookup_seq_len(self, seq_len: int) -> int:
method per_sample_at (line 133) | def per_sample_at(self, seq_len: int) -> int:
method batch_size_fn (line 138) | def batch_size_fn(
FILE: gliclass/serve/server.py
class GLiClassServer (line 20) | class GLiClassServer:
method __init__ (line 23) | def __init__(self, config: GLiClassServeConfig):
method _precompile (line 92) | def _precompile(self) -> None:
method _calibrate_memory (line 118) | def _calibrate_memory(self) -> None:
method batch_size_fn (line 129) | def batch_size_fn(self, seq_len: int | None = None) -> int:
method observed_seq_len (line 149) | def observed_seq_len(
method _filter_labels (line 176) | def _filter_labels(self, labels: list[str]) -> list[str]:
method _run_batch_internal (line 183) | def _run_batch_internal(
method predict (line 220) | def predict(
function _build_deployment (line 249) | def _build_deployment(config: GLiClassServeConfig):
function serve_gliclass (line 359) | def serve_gliclass(
function shutdown (line 396) | def shutdown() -> None:
class GLiClassFactory (line 400) | class GLiClassFactory:
method __init__ (line 420) | def __init__(
method handle (line 441) | def handle(self):
method predict (line 445) | def predict(
method predict_async (line 472) | async def predict_async(
method shutdown (line 501) | def shutdown(self) -> None:
method __enter__ (line 516) | def __enter__(self):
method __exit__ (line 519) | def __exit__(self, exc_type, exc_val, exc_tb):
method __del__ (line 523) | def __del__(self):
FILE: gliclass/training.py
class EWC (line 32) | class EWC:
method __init__ (line 35) | def __init__(
method _compute_fisher (line 82) | def _compute_fisher(self, dataset: Dataset) -> Dict[str, torch.Tensor]:
method _normalize_fisher (line 187) | def _normalize_fisher(self):
method ewc_loss (line 199) | def ewc_loss(self, batch_size: int | None = None) -> torch.Tensor:
method get_importance_scores (line 226) | def get_importance_scores(self) -> Dict[str, float]:
method update_lambda (line 237) | def update_lambda(self, new_lambda: float):
method consolidate (line 245) | def consolidate(self, dataset: Dataset, alpha: float = 0.5):
class TrainingArguments (line 273) | class TrainingArguments(transformers.TrainingArguments):
class Trainer (line 294) | class Trainer(transformers.Trainer):
method __init__ (line 297) | def __init__(self, ewc: EWC | None = None, prev_dataset=None, *args, *...
method _maybe_initialize_ewc (line 316) | def _maybe_initialize_ewc(self):
method compute_loss (line 352) | def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
method train (line 381) | def train(self, *args, **kwargs):
method training_step (line 387) | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
method prediction_step (line 437) | def prediction_step(
method create_optimizer (line 498) | def create_optimizer(self):
class RLTrainerConfig (line 577) | class RLTrainerConfig(TrainingArguments):
class RLTrainer (line 607) | class RLTrainer(Trainer):
method __init__ (line 608) | def __init__(
method _init_metrics (line 625) | def _init_metrics(self):
method compute_rewards (line 634) | def compute_rewards(
method get_reference_scores (line 646) | def get_reference_scores(self, input_texts, labels_text):
method compute_loss (line 671) | def compute_loss(
method _inner_training_loop (line 757) | def _inner_training_loop(self, *args, **kwargs):
method log_metrics (line 845) | def log_metrics(self):
method _save_checkpoint (line 857) | def _save_checkpoint(self, model, step=None):
FILE: gliclass/utils.py
function is_module_available (line 4) | def is_module_available(module_name):
class MissedPackageException (line 21) | class MissedPackageException(Exception):
function retrieval_augmented_text (line 27) | def retrieval_augmented_text(text: str, examples: list) -> str:
function default_f1_reward (line 64) | def default_f1_reward(
FILE: test_gliclass.py
class TestModel (line 14) | class TestModel:
method __init__ (line 16) | def __init__(self, model, token):
method load_model (line 28) | def load_model(self):
method prepare_dataset (line 34) | def prepare_dataset(self, dataset, classes=None, text_column='text', l...
method prepare_nomapping (line 55) | def prepare_nomapping(self, dataset, classes=None, text_column='text',...
method get_gliclass_predictions (line 79) | def get_gliclass_predictions(self, test_texts, classes, batch_size=8):
method evaluate (line 84) | def evaluate(self, predicts, true_labels):
method process (line 90) | def process(self):
FILE: tests/test_data_processing.py
class TestPad2DTensor (line 9) | class TestPad2DTensor:
method sample_tensors (line 13) | def sample_tensors(self):
method test_pads_to_maximum_dimensions (line 21) | def test_pads_to_maximum_dimensions(self, sample_tensors):
method test_preserves_original_values (line 28) | def test_preserves_original_values(self, sample_tensors):
method test_pads_with_zeros (line 48) | def test_pads_with_zeros(self, sample_tensors):
method test_single_tensor (line 58) | def test_single_tensor(self):
method test_uniform_size_tensors (line 67) | def test_uniform_size_tensors(self):
method test_empty_tensor_handling (line 81) | def test_empty_tensor_handling(self):
method test_preserves_dtype (line 94) | def test_preserves_dtype(self):
method test_varying_row_counts (line 105) | def test_varying_row_counts(self):
method test_varying_column_counts (line 118) | def test_varying_column_counts(self):
method test_batch_consistency (line 131) | def test_batch_consistency(self):
FILE: tests/test_loss_functions.py
class TestSequenceContrastiveLoss (line 9) | class TestSequenceContrastiveLoss:
method sample_embeddings (line 13) | def sample_embeddings(self):
method sample_mask (line 21) | def sample_mask(self):
method test_returns_scalar_loss (line 25) | def test_returns_scalar_loss(self, sample_embeddings, sample_mask):
method test_loss_is_positive (line 32) | def test_loss_is_positive(self, sample_embeddings, sample_mask):
method test_identical_sequences_low_loss (line 38) | def test_identical_sequences_low_loss(self):
method test_handles_masked_positions (line 48) | def test_handles_masked_positions(self):
method test_gradient_flows_through_loss (line 58) | def test_gradient_flows_through_loss(self, sample_embeddings, sample_m...
class TestFocalLossWithLogits (line 68) | class TestFocalLossWithLogits:
method sample_logits (line 72) | def sample_logits(self):
method sample_targets (line 77) | def sample_targets(self):
method test_returns_tensor_with_reduction_none (line 81) | def test_returns_tensor_with_reduction_none(self, sample_logits, sampl...
method test_returns_scalar_with_reduction_mean (line 88) | def test_returns_scalar_with_reduction_mean(self, sample_logits, sampl...
method test_loss_is_positive (line 94) | def test_loss_is_positive(self, sample_logits, sample_targets):
method test_perfect_predictions_low_loss (line 100) | def test_perfect_predictions_low_loss(self):
method test_wrong_predictions_high_loss (line 110) | def test_wrong_predictions_high_loss(self):
method test_alpha_parameter_effect (line 120) | def test_alpha_parameter_effect(self, sample_logits, sample_targets):
method test_gamma_parameter_effect (line 128) | def test_gamma_parameter_effect(self, sample_logits, sample_targets):
method test_reduction_sum (line 137) | def test_reduction_sum(self, sample_logits, sample_targets):
method test_reduction_none (line 143) | def test_reduction_none(self, sample_logits, sample_targets):
method test_handles_extreme_logits (line 149) | def test_handles_extreme_logits(self):
method test_gradient_flows_through_loss (line 159) | def test_gradient_flows_through_loss(self, sample_logits, sample_targe...
method test_all_zeros_targets (line 168) | def test_all_zeros_targets(self):
method test_all_ones_targets (line 178) | def test_all_ones_targets(self):
FILE: tests/test_poolings.py
class TestGlobalMaxPooling1D (line 19) | class TestGlobalMaxPooling1D:
method pooling_layer (line 23) | def pooling_layer(self):
method sample_input (line 28) | def sample_input(self):
method test_returns_max_across_sequence (line 32) | def test_returns_max_across_sequence(self, pooling_layer, sample_input):
method test_output_shape (line 39) | def test_output_shape(self, pooling_layer, sample_input):
class TestGlobalAvgPooling1D (line 46) | class TestGlobalAvgPooling1D:
method pooling_layer (line 50) | def pooling_layer(self):
method test_returns_average_across_sequence (line 54) | def test_returns_average_across_sequence(self, pooling_layer):
method test_handles_attention_mask (line 63) | def test_handles_attention_mask(self, pooling_layer):
method test_output_shape (line 73) | def test_output_shape(self, pooling_layer):
class TestGlobalSumPooling1D (line 82) | class TestGlobalSumPooling1D:
method pooling_layer (line 86) | def pooling_layer(self):
method test_returns_sum_across_sequence (line 90) | def test_returns_sum_across_sequence(self, pooling_layer):
method test_handles_attention_mask (line 99) | def test_handles_attention_mask(self, pooling_layer):
class TestFirstTokenPooling1D (line 110) | class TestFirstTokenPooling1D:
method pooling_layer (line 114) | def pooling_layer(self):
method test_returns_first_token (line 118) | def test_returns_first_token(self, pooling_layer):
method test_works_with_batch (line 127) | def test_works_with_batch(self, pooling_layer):
method test_output_shape (line 136) | def test_output_shape(self, pooling_layer):
class TestLastTokenPooling1D (line 145) | class TestLastTokenPooling1D:
method pooling_layer (line 149) | def pooling_layer(self):
method test_returns_last_token (line 153) | def test_returns_last_token(self, pooling_layer):
method test_works_with_batch (line 162) | def test_works_with_batch(self, pooling_layer):
method test_output_shape (line 171) | def test_output_shape(self, pooling_layer):
class TestGlobalRMSPooling1D (line 180) | class TestGlobalRMSPooling1D:
method pooling_layer (line 184) | def pooling_layer(self):
method test_returns_rms_across_sequence (line 188) | def test_returns_rms_across_sequence(self, pooling_layer):
method test_handles_attention_mask (line 197) | def test_handles_attention_mask(self, pooling_layer):
method test_output_shape (line 207) | def test_output_shape(self, pooling_layer):
class TestGlobalAbsMaxPooling1D (line 216) | class TestGlobalAbsMaxPooling1D:
method pooling_layer (line 220) | def pooling_layer(self):
method test_returns_abs_max_across_sequence (line 224) | def test_returns_abs_max_across_sequence(self, pooling_layer):
method test_handles_attention_mask (line 233) | def test_handles_attention_mask(self, pooling_layer):
method test_output_shape (line 243) | def test_output_shape(self, pooling_layer):
class TestGlobalAbsAvgPooling1D (line 252) | class TestGlobalAbsAvgPooling1D:
method pooling_layer (line 256) | def pooling_layer(self):
method test_returns_abs_avg_across_sequence (line 260) | def test_returns_abs_avg_across_sequence(self, pooling_layer):
method test_handles_attention_mask (line 269) | def test_handles_attention_mask(self, pooling_layer):
method test_output_shape (line 279) | def test_output_shape(self, pooling_layer):
class TestPassPooling1D (line 288) | class TestPassPooling1D:
method pooling_layer (line 292) | def pooling_layer(self):
method test_returns_input_unchanged (line 296) | def test_returns_input_unchanged(self, pooling_layer):
method test_ignores_attention_mask (line 304) | def test_ignores_attention_mask(self, pooling_layer):
method test_maintains_shape (line 313) | def test_maintains_shape(self, pooling_layer):
FILE: tests/test_scorers.py
class TestScorerWeightedDot (line 15) | class TestScorerWeightedDot:
method scorer (line 17) | def scorer(self):
method test_forward_pass (line 20) | def test_forward_pass(self, scorer):
method test_gradient_flow (line 29) | def test_gradient_flow(self, scorer):
class TestScorerDot (line 41) | class TestScorerDot:
method scorer (line 43) | def scorer(self):
method test_forward_pass (line 46) | def test_forward_pass(self, scorer):
method test_gradient_flow (line 55) | def test_gradient_flow(self, scorer):
class TestMLPScorer (line 67) | class TestMLPScorer:
method scorer (line 69) | def scorer(self):
method test_forward_pass (line 72) | def test_forward_pass(self, scorer):
method test_different_batch_sizes (line 81) | def test_different_batch_sizes(self, scorer):
method test_gradient_flow (line 90) | def test_gradient_flow(self, scorer):
class TestHopfieldScorer (line 102) | class TestHopfieldScorer:
method scorer (line 104) | def scorer(self):
method test_forward_pass (line 107) | def test_forward_pass(self, scorer):
method test_multiple_iterations (line 116) | def test_multiple_iterations(self):
method test_gradient_flow (line 125) | def test_gradient_flow(self, scorer):
class TestCrossAttnScorer (line 137) | class TestCrossAttnScorer:
method scorer (line 139) | def scorer(self):
method test_forward_pass_with_text_mask (line 142) | def test_forward_pass_with_text_mask(self, scorer):
method test_forward_pass_without_text_mask (line 153) | def test_forward_pass_without_text_mask(self, scorer):
method test_different_seq_lengths (line 162) | def test_different_seq_lengths(self, scorer):
method test_gradient_flow (line 171) | def test_gradient_flow(self, scorer):
method test_eval_mode (line 182) | def test_eval_mode(self, scorer):
FILE: tests/test_utils.py
class TestIsModuleAvailable (line 9) | class TestIsModuleAvailable:
method test_detects_installed_module (line 12) | def test_detects_installed_module(self):
method test_detects_missing_module (line 17) | def test_detects_missing_module(self):
method test_handles_submodules (line 21) | def test_handles_submodules(self):
class TestRetrievalAugmentedText (line 26) | class TestRetrievalAugmentedText:
method test_with_structured_examples (line 29) | def test_with_structured_examples(self):
method test_empty_examples_returns_original_text (line 42) | def test_empty_examples_returns_original_text(self):
method test_includes_true_label_markers (line 51) | def test_includes_true_label_markers(self):
class TestDefaultF1Reward (line 62) | class TestDefaultF1Reward:
method sample_inputs (line 66) | def sample_inputs(self):
method test_returns_tensor (line 77) | def test_returns_tensor(self, sample_inputs):
method test_output_shape (line 83) | def test_output_shape(self, sample_inputs):
method test_perfect_predictions (line 89) | def test_perfect_predictions(self):
method test_zero_f1_for_wrong_predictions (line 100) | def test_zero_f1_for_wrong_predictions(self):
method test_handles_valid_mask (line 111) | def test_handles_valid_mask(self):
FILE: train.py
class CustomTrainer (line 20) | class CustomTrainer(Trainer):
method __init__ (line 23) | def __init__(self, *args, use_weighted_sampling=False, **kwargs):
method _get_train_sampler (line 27) | def _get_train_sampler(self, train_dataset) -> torch.utils.data.Sampler:
function compute_metrics (line 38) | def compute_metrics(p, problem_type='multi_label_classification'):
function load_dataset (line 78) | def load_dataset(data_path: str) -> list:
function main (line 92) | def main(args):
FILE: train_rl.py
function accuracy_reward (line 20) | def accuracy_reward(probs, actions, targets, valid_mask):
function recall_reward (line 27) | def recall_reward(
function compute_metrics (line 43) | def compute_metrics(p):
function main (line 71) | def main(args):
Condensed preview — 35 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (390K chars).
[
{
"path": ".github/workflows/release.yaml",
"chars": 1962,
"preview": "name: Release GLiClass to PyPI\n\non:\n push:\n tags:\n - 'v*' # Trigger on version tags (e.g., v1.0.0, v2.1.3)\n\nco"
},
{
"path": ".github/workflows/tests.yml",
"chars": 1580,
"preview": "name: Tests\n\non:\n push:\n branches:\n - main\n pull_request:\n branches:\n - main\n workflow_dispatch:\n\ncon"
},
{
"path": ".gitignore",
"chars": 3236,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n#custom\nmodels/\nwandb/\ng"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 10900,
"preview": "# ⭐ GLiClass: Generalist and Lightweight Model for Sequence Classification\n\n**GLiClass** is an efficient, zero-shot sequ"
},
{
"path": "demo.py",
"chars": 41524,
"preview": "\"\"\"\nGLiClass Enhanced Demo with Advanced Features\n\nFeatures:\n- Task description prompts\n- Hierarchical label inputs (JSO"
},
{
"path": "gliclass/__init__.py",
"chars": 418,
"preview": "from .model import GLiClassModel, GLiClassBiEncoder, GLiClassUniEncoder, GLiClassEncoderDecoderCLS\nfrom .config import G"
},
{
"path": "gliclass/config.py",
"chars": 5522,
"preview": "from transformers import AutoConfig\nfrom transformers.utils import logging\nfrom transformers.models.auto import CONFIG_M"
},
{
"path": "gliclass/data_processing.py",
"chars": 17349,
"preview": "import copy\nimport random\nfrom dataclasses import dataclass\n\nimport torch\nfrom torch.utils.data import Dataset\nfrom torc"
},
{
"path": "gliclass/layers.py",
"chars": 10366,
"preview": "# Copyright 2020 Microsoft and the Hugging Face Inc. team and Knowledgator.\n#\n# Licensed under the Apache License, Versi"
},
{
"path": "gliclass/loss_functions.py",
"chars": 3851,
"preview": "import torch\nimport torch.nn.functional as F\n\n\ndef sequence_contrastive_loss(embeddings, mask):\n # embeddings shape: "
},
{
"path": "gliclass/model.py",
"chars": 47469,
"preview": "import os\nimport warnings\nfrom typing import Tuple\nfrom pathlib import Path\nfrom dataclasses import dataclass\n\nimport to"
},
{
"path": "gliclass/ops.py",
"chars": 1180,
"preview": "import torch\nimport torch.nn.functional as F\n\n# ─── Attention (padded) ─────────────────────────────────────────────────"
},
{
"path": "gliclass/pipeline.py",
"chars": 48526,
"preview": "from abc import ABC, abstractmethod\nfrom typing import Any, Dict, List\n\nimport torch\nfrom tqdm import tqdm\nfrom transfor"
},
{
"path": "gliclass/poolings.py",
"chars": 3154,
"preview": "import torch\nfrom torch import nn\n\n\nclass GlobalMaxPooling1D(nn.Module):\n \"\"\"Applies Global Max Pooling on the timest"
},
{
"path": "gliclass/scorers.py",
"chars": 6508,
"preview": "import torch\nfrom torch import nn\n\nfrom .ops import attn_padded\n\n\nclass ScorerWeightedDot(nn.Module):\n def __init__(s"
},
{
"path": "gliclass/serve/__init__.py",
"chars": 407,
"preview": "\"\"\"GLiClass serving module.\"\"\"\n\nfrom .client import GLiClassClient\nfrom .config import GLiClassServeConfig\nfrom .memory "
},
{
"path": "gliclass/serve/__main__.py",
"chars": 6961,
"preview": "\"\"\"CLI entry point for GLiClass serving.\"\"\"\n\nimport sys\nimport signal\nimport logging\nimport argparse\n\nimport ray\nfrom ra"
},
{
"path": "gliclass/serve/client.py",
"chars": 3428,
"preview": "\"\"\"Client for GLiClass serving endpoint.\"\"\"\n\nimport requests\n\n\nclass GLiClassClient:\n \"\"\"Client for interacting with "
},
{
"path": "gliclass/serve/config.py",
"chars": 3302,
"preview": "\"\"\"Configuration for GLiClass Ray Serve deployment.\"\"\"\n\nfrom pathlib import Path\nfrom dataclasses import field, asdict, "
},
{
"path": "gliclass/serve/memory.py",
"chars": 5778,
"preview": "\"\"\"Memory estimation for GLiClass via precomputed calibration table.\n\nStartup calibration runs the model on probe batche"
},
{
"path": "gliclass/serve/server.py",
"chars": 17381,
"preview": "\"\"\"Ray Serve deployment for GLiClass with dynamic batching.\"\"\"\n\nimport os\nimport logging\nfrom typing import Any\n\nimport "
},
{
"path": "gliclass/training.py",
"chars": 34982,
"preview": "import os\nfrom typing import Any, Dict, List, Tuple, Callable\nfrom dataclasses import field, dataclass\n\nimport numpy as "
},
{
"path": "gliclass/utils.py",
"chars": 3932,
"preview": "import torch\n\n\ndef is_module_available(module_name):\n \"\"\"\n Checks whether the specified Python module is available"
},
{
"path": "notebooks/finetuning.ipynb",
"chars": 8995,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": null,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": "
},
{
"path": "pyproject.toml",
"chars": 4481,
"preview": "[build-system]\nrequires = [\"setuptools>=61.0.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.setuptools.packages.find"
},
{
"path": "serve_configs/serve_config.yaml",
"chars": 1248,
"preview": "# Model configuration\nmodel: knowledgator/gliclass-edge-v3.0\ndevice: cuda\ndtype: float16\n\n# Limits\nmax_model_len: 2048\nm"
},
{
"path": "test_gliclass.py",
"chars": 7806,
"preview": "from gliclass import GLiClassModel, ZeroShotClassificationPipeline\nfrom transformers import AutoTokenizer\nfrom datasets "
},
{
"path": "tests/test_data_processing.py",
"chars": 4650,
"preview": "\"\"\"Tests for gliclass.data_processing module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.data_processing import pad_2"
},
{
"path": "tests/test_loss_functions.py",
"chars": 7025,
"preview": "\"\"\"Tests for gliclass.loss_functions module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.loss_functions import sequenc"
},
{
"path": "tests/test_poolings.py",
"chars": 10182,
"preview": "\"\"\"Tests for gliclass.poolings module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.poolings import (\n GlobalMaxPool"
},
{
"path": "tests/test_scorers.py",
"chars": 5367,
"preview": "\"\"\"Tests for gliclass.scorers module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.scorers import (\n ScorerWeightedD"
},
{
"path": "tests/test_utils.py",
"chars": 4415,
"preview": "\"\"\"Tests for gliclass.utils module.\"\"\"\n\nimport pytest\nimport torch\n\nfrom gliclass.utils import is_module_available, retr"
},
{
"path": "train.py",
"chars": 16566,
"preview": "import os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\nimport numpy as np\nimport argparse\nimport json\n\nfrom sklearn.met"
},
{
"path": "train_rl.py",
"chars": 11175,
"preview": "import os\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\nimport numpy as np\nimport argparse\nimport json\n\nfrom sklearn.met"
}
]
About this extraction
This page contains the full source code of the Knowledgator/GLiClass GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (364.2 KB), approximately 83.2k tokens, and a symbol index with 409 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.