Full Code of bytedance/XVerse for AI

Repository: bytedance/XVerse
Branch: main
Commit: 65e2581aac0f
Files: 90
Total size: 3.7 MB

Directory structure:
gitextract_6xpfra6z/

├── .gitignore
├── .gradio/
│   └── certificate.pem
├── LICENCE
├── README.md
├── assets/
│   ├── ReadMe.md
│   ├── crop_faces.py
│   ├── rename.py
│   └── segmentation.py
├── eval/
│   ├── eval_scripts/
│   │   ├── run_eval_multi.sh
│   │   └── run_eval_single.sh
│   ├── grounded_sam/
│   │   ├── florence2/
│   │   │   ├── config.json
│   │   │   ├── configuration_florence2.py
│   │   │   ├── generation_config.json
│   │   │   ├── modeling_florence2.py
│   │   │   ├── preprocessor_config.json
│   │   │   ├── processing_florence2.py
│   │   │   ├── tokenizer.json
│   │   │   ├── tokenizer_config.json
│   │   │   └── vocab.json
│   │   ├── grounded_sam2_florence2_autolabel_pipeline.py
│   │   └── sam2/
│   │       ├── __init__.py
│   │       ├── automatic_mask_generator.py
│   │       ├── build_sam.py
│   │       ├── configs/
│   │       │   ├── sam2/
│   │       │   │   ├── sam2_hiera_b+.yaml
│   │       │   │   ├── sam2_hiera_l.yaml
│   │       │   │   ├── sam2_hiera_s.yaml
│   │       │   │   └── sam2_hiera_t.yaml
│   │       │   ├── sam2.1/
│   │       │   │   ├── sam2.1_hiera_b+.yaml
│   │       │   │   ├── sam2.1_hiera_l.yaml
│   │       │   │   ├── sam2.1_hiera_s.yaml
│   │       │   │   └── sam2.1_hiera_t.yaml
│   │       │   └── sam2.1_training/
│   │       │       └── sam2.1_hiera_b+_MOSE_finetune.yaml
│   │       ├── csrc/
│   │       │   └── connected_components.cu
│   │       ├── modeling/
│   │       │   ├── __init__.py
│   │       │   ├── backbones/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── hieradet.py
│   │       │   │   ├── image_encoder.py
│   │       │   │   └── utils.py
│   │       │   ├── memory_attention.py
│   │       │   ├── memory_encoder.py
│   │       │   ├── position_encoding.py
│   │       │   ├── sam/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── mask_decoder.py
│   │       │   │   ├── prompt_encoder.py
│   │       │   │   └── transformer.py
│   │       │   ├── sam2_base.py
│   │       │   └── sam2_utils.py
│   │       ├── sam2_hiera_b+.yaml
│   │       ├── sam2_hiera_l.yaml
│   │       ├── sam2_hiera_s.yaml
│   │       ├── sam2_hiera_t.yaml
│   │       ├── sam2_image_predictor.py
│   │       ├── sam2_video_predictor.py
│   │       └── utils/
│   │           ├── __init__.py
│   │           ├── amg.py
│   │           ├── misc.py
│   │           └── transforms.py
│   └── tools/
│       ├── XVerseBench_multi.json
│       ├── XVerseBench_multi_DSG.json
│       ├── XVerseBench_single.json
│       ├── XVerseBench_single_DSG.json
│       ├── dino.py
│       ├── dpg_score.py
│       ├── face_id.py
│       ├── face_utils/
│       │   ├── face.py
│       │   └── face_recg.py
│       ├── florence_sam.py
│       ├── idip_aes_score.py
│       ├── idip_dpg_score.py
│       ├── idip_face_score.py
│       ├── idip_gen_split_idip.py
│       ├── idip_sam-dino_score.py
│       └── log_scores.py
├── inference_single_sample.py
├── requirements.txt
├── run_demo.sh
├── run_gradio.py
├── src/
│   ├── adapters/
│   │   ├── __init__.py
│   │   └── mod_adapters.py
│   ├── flux/
│   │   ├── block.py
│   │   ├── condition.py
│   │   ├── generate.py
│   │   ├── lora_controller.py
│   │   ├── pipeline_tools.py
│   │   └── transformer.py
│   └── utils/
│       ├── data_utils.py
│       ├── gpu_momory_utils.py
│       └── modulation_utils.py
└── train/
    └── config/
        ├── XVerse_config_INF.yaml
        └── XVerse_config_demo.yaml

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
assets/XVerseBench/animal/*
assets/XVerseBench/object/*
__pycache__
checkpoints/*
generated_*
tmp
*.png

================================================
FILE: .gradio/certificate.pem
================================================
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----


================================================
FILE: LICENCE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation

<p align="center">
    <a href="https://arxiv.org/abs/2506.21416">
            <img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2506.21416-b31b1b.svg">
    </a>
    <a href="https://bytedance.github.io/XVerse/">
        <img alt="Project Page" src="https://img.shields.io/badge/Project-Page-blue">
    </a>
    <a href="https://github.com/bytedance/XVerse/tree/main/assets">
        <img alt="Build" src="https://img.shields.io/badge/XVerseBench-Dataset-green">
    </a>
    <a href="https://huggingface.co/ByteDance/XVerse">
        <img alt="Build" src="https://img.shields.io/badge/🤗-HF%20Model-yellow">
    </a>    
    <a href="https://huggingface.co/spaces/ByteDance/XVerse">
        <img alt="Build" src="https://img.shields.io/badge/🤗-HF%20Demo-yellow">
    </a>
</p>

## 🔥 News
- **2025.9.19**: 🎉 Congratulations! XVerse has been accepted by NeurIPS 2025! 🎉
- **2025.7.18**: Support quantized diffusion models and add group offload, enabling the XVerse model to run in 16GB of VRAM.
- **2025.7.10**: Release the Hugging Face Space demo.
- **2025.7.8**: Support low-VRAM inference; the XVerse model can run in 24GB of VRAM.
- **2025.6.26**: The code has been released!

![XVerse's capability in single/multi-subject personalization and semantic attribute control (pose, style, lighting)](sample/first_page.png)

## 📖 Introduction

**XVerse** introduces a novel approach to multi-subject image synthesis, offering **precise and independent control over individual subjects** without disrupting the overall image latents or features. We achieve this by transforming reference images into offsets for token-specific text-stream modulation.

This innovation enables high-fidelity, editable image generation where you can robustly control both **individual subject characteristics** (identity) and their **semantic attributes**. XVerse significantly enhances capabilities for personalized and complex scene generation.

## ⚡️ Quick Start

### Requirements and Installation

First, install the necessary dependencies:

```bash
# Create a conda environment named XVerse with Python version 3.10.16
conda create -n XVerse python=3.10.16 -y
# Activate the XVerse environment
conda activate XVerse
# Install the correct version of pytorch (According to your machine)
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# Use pip to install the dependencies specified in requirements.txt
pip install -r requirements.txt
# Install flash-attn
pip install flash-attn==2.7.4.post1 --no-build-isolation
# Update version of httpx
pip install httpx==0.23.3
```

Next, download the required checkpoints:
```bash
cd checkpoints
bash ./download_ckpts.sh
cd ..
```
**Important**: You'll also need to download the face recognition model `model_ir_se50.pth` from [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) and place it directly into the `./checkpoints/` folder.

After that, you can export the model paths as environment variables. This step ensures that the subsequent inference scripts can locate the necessary models correctly:
``` bash
export FLORENCE2_MODEL_PATH="./checkpoints/Florence-2-large"
export SAM2_MODEL_PATH="./checkpoints/sam2.1_hiera_large.pt"
export FACE_ID_MODEL_PATH="./checkpoints/model_ir_se50.pth"
export CLIP_MODEL_PATH="./checkpoints/clip-vit-large-patch14"
export FLUX_MODEL_PATH="./checkpoints/FLUX.1-dev"
export DPG_VQA_MODEL_PATH="./checkpoints/mplug_visual-question-answering_coco_large_en"
export DINO_MODEL_PATH="./checkpoints/dino-vits16"
```

### Local Gradio Demo

To run the interactive Gradio demo locally, execute the following command:
```bash
python run_gradio.py
```

#### Input Settings Explained
The Gradio demo provides several parameters to control your image generation process:
* **Prompt**: The textual description guiding the image generation.
* **Generated Height/Width**: Use the sliders to set the shape of the output image.
* **Weight_id/ip**: Adjust these weight parameters. Higher values generally lead to better subject consistency but might slightly impact the naturalness of the generated image.
* **latent_lora_scale and vae_lora_scale**: Control the LoRA scale. Similar to Weight_id/ip, larger LoRA values can improve subject consistency but may reduce image naturalness.
* **vae_skip_iter_before and vae_skip_iter_after**: Configure VAE skip iterations. Skipping more steps can result in better naturalness but might compromise subject consistency.

#### Input Images

The demo provides detailed control over your input images:

* **Expand Panel**: Click "Input Image X" to reveal the options for each image.
* **Upload Image**: Click "Image X" to upload your desired reference image.
* **Image Description**: Enter a description in the "Caption X" input box. You can also click "Auto Caption" to generate a description automatically.
* **Detection & Segmentation**: Click "Det & Seg" to perform detection and segmentation on the uploaded image.
* **Crop Face**: Use "Crop Face" to automatically crop the face from the image.
* **ID Checkbox**: Check or uncheck "ID or not" to determine whether to use ID-related weights for that specific input image.

> **⚠️ Important Usage Notes:**
>
> * **Prompt Construction**: The main text prompt **MUST** include the exact text you entered in the `Image Description` field for each active image. **Generation will fail if this description is missing from the prompt.**
>     * *Example*: If you upload two images and set their descriptions to "a man with red hair" (for Image 1) and "a woman with blue eyes" (for Image 2), your main prompt might be: "`a man with red hair` walking beside `a woman with blue eyes` in a park."
>     * Alternatively, you can write your main prompt simply as: "`ENT1` walking beside `ENT2` in a park." The code will **automatically replace** these placeholders with the full description text before generation.
> * **Active Images**: Only images in **expanded** (un-collapsed) panels will be fed into the model. Collapsed image panels are ignored.
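
The `ENT` placeholder substitution described in the notes above can be sketched as follows. This is a minimal illustration only: `expand_prompt` is a hypothetical helper, not the repository's actual implementation, and it assumes captions are supplied in the same order as the image panels.

```python
def expand_prompt(prompt: str, captions: list[str]) -> str:
    """Replace ENT1, ENT2, ... with the corresponding image captions.

    Hypothetical sketch of the substitution step; the real logic lives in
    the demo/inference scripts and may differ in detail.
    """
    # Replace higher-numbered placeholders first so "ENT10" is never
    # partially matched by the "ENT1" replacement.
    for i in range(len(captions), 0, -1):
        prompt = prompt.replace(f"ENT{i}", captions[i - 1])
    return prompt

print(expand_prompt(
    "ENT1 walking beside ENT2 in a park.",
    ["a man with red hair", "a woman with blue eyes"],
))
# → a man with red hair walking beside a woman with blue eyes in a park.
```

Whichever form you use, the expanded prompt must contain each active image's description verbatim, since that text is what ties each reference image to its tokens in the prompt.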

### Inference with Single Sample

To perform inference on a single sample, run the following command. You can customize the image generation by adjusting the parameters such as the prompt, seed, and output size:
```bash
python inference_single_sample.py --prompt "ENT1 wearing a tiny hat" --seed 42 --cond_size 256 --target_height 768 --target_width 768 --weight_id 3 --weight_ip 5 --latent_lora_scale 0.85 --vae_lora_scale 1.3 --vae_skip_iter_s1 0.05 --vae_skip_iter_s2 0.8 --images "sample/hamster.jpg" --captions "a hamster" --idips false --save_path "generated_image_1.png" --num_images 1
```

For inference with multiple condition images, use the command below. This allows you to incorporate multiple reference images into the generation process. Make sure the numbers of `--images`, `--captions`, and `--idips` values match:
```bash
python inference_single_sample.py --prompt "ENT1, and ENT2 standing together in a park." --seed 42 --cond_size 256 --target_height 768 --target_width 768 --weight_id 2 --weight_ip 5 --latent_lora_scale 0.85 --vae_lora_scale 1.3 --vae_skip_iter_s1 0.05 --vae_skip_iter_s2 0.8 --images "sample/woman.jpg" "sample/girl.jpg" --captions "a woman" "a girl" --idips true true --save_path "generated_image_2.png" --num_images 1
```

## ⚡️ Low-VRAM Inference

### Offload Modules to CPU

- During single-sample inference or when running the Gradio demo, you can enable low-VRAM mode by adding the parameter `--use_low_vram True` or `--use_lower_vram True`.
- `use_low_vram` allows you to perform inference with up to two conditional images on a GPU equipped with 24GB of VRAM.
- `use_lower_vram` allows you to perform inference with up to three conditional images on a GPU equipped with 16GB of VRAM. 
- Note that CPU offload significantly reduces inference speed and should only be enabled when necessary.

### Quantized Diffusion Models

- You can download the quantized model from [here](https://huggingface.co/collections/diffusers/flux-quantized-checkpoints-682c951aebd378a2462984a0) into the checkpoints folder. Using the bnb-nf4 quantized model, you can run single-condition inference on 32GB of VRAM, or three-condition inference on 24GB of VRAM by enabling CPU offloading. You need to modify the `FLUX_MODEL_PATH` environment variable and add the parameter `--dit_quant None`.
- You can also download the GGUF quantized model from [here](https://huggingface.co/city96/FLUX.1-dev-gguf) into the checkpoints folder. Using this GGUF quantized model, you can run two-condition inference on 32GB of VRAM, or four-condition inference on 24GB of VRAM by enabling CPU offloading. Run inference with the following commands:
```bash
export FLUX_TRANSFORMERS_PATH="./checkpoints/FLUX.1-dev-gguf/flux1-dev-Q3_K_S.gguf"
export FLUX_MODEL_PATH="./checkpoints/FLUX.1-dev"
python inference_single_sample.py --prompt "ENT1, and ENT2 standing together in a park." --seed 42 --cond_size 256 --target_height 768 --target_width 768 --weight_id 3 --weight_ip 5 --latent_lora_scale 0.7 --vae_lora_scale 1.2 --vae_skip_iter_s1 0.05 --vae_skip_iter_s2 0.8 --images "sample/woman.jpg" "sample/girl.jpg" --captions "a woman" "a girl" --idips true true --save_path "generated_image_2-GGUF.png" --num_images 1 --dit_quant GGUF
```
**Note**: Quantized models may degrade the model's performance to some extent, and parameters like `weight_id`, `weight_ip`, and `lora_scale` may need to be re-adjusted.

## Inference with XVerseBench

![XVerseBench](sample/XVerseBench.png)

First, download XVerseBench as described in the `assets` folder. Then run the evaluation with the following command:
```bash
bash ./eval/eval_scripts/run_eval_single.sh  # use run_eval_multi.sh for the multi-subject benchmark
```
The script will automatically evaluate the model on the XVerseBench dataset and save the results in the `./results` folder.

## 📌 ToDo

- [x] Release github repo.
- [x] Release arXiv paper.
- [x] Release model checkpoints.
- [x] Release inference data: XVerseBench.
- [x] Release inference code for XVerseBench.
- [x] Release inference code for gradio demo.
- [x] Release inference code for single sample.
- [x] Support inference in consumer-grade GPUs.
- [x] Release huggingface space demo.
- [x] Support quantized diffusion models.
- [ ] Release Benchmark Leaderboard.
- [ ] Release ComfyUI implementation.

## License
    
The code in this project is licensed under Apache 2.0; the dataset is licensed under CC0, subject to the intellectual property owned by Bytedance. Since the dataset is adapted from [dreambench++](https://dreambenchplus.github.io/), you should also comply with the dreambench++ license.

## Acknowledgments
We sincerely thank Alex Nasa for deploying the Hugging Face demo with the FLUX.1-schnell model. You can experience this online demo by clicking [here](https://huggingface.co/spaces/alexnasa/XVerse).
    
##  Citation
If XVerse is helpful, please help to ⭐ the repo.

If you find this project useful for your research, please consider citing our paper:
```bibtex
@article{chen2025xverse,
  title={XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation},
  author={Chen, Bowen and Zhao, Mengyi and Sun, Haomiao and Chen, Li and Wang, Xu and Du, Kang and Wu, Xinglong},
  journal={arXiv preprint arXiv:2506.21416},
  year={2025}
}
```

================================================
FILE: assets/ReadMe.md
================================================
# Install of XVerseBench

Existing controlled image generation benchmarks often focus on either maintaining identity or object appearance consistency, rarely encompassing datasets that rigorously test both aspects. To comprehensively assess the models' single-subject and multi-subject conditional generation and editing capabilities, we constructed a new benchmark by merging and curating data from DreamBench++ and some generated human images.

Our resulting benchmark XVerseBench comprises 20 distinct human identities, 74 unique objects, and 45 different animal species/individuals. To thoroughly evaluate model effectiveness in subject-driven generation tasks, we developed test sets specifically for single-subject, dual-subject, and triple-subject control scenarios. This benchmark includes 300 unique test prompts covering diverse combinations of humans, objects, and animals. 

<p align="center">
  <img src="../sample/XVerseBench.png" alt="XVerseBench">
</p>
<p align="center"><strong>Figure 1. XVerseBench</strong></p>

The figure above shows more detailed information and samples for each category. For evaluation, we employ a suite of metrics to quantify different aspects of generation quality and control fidelity: the DPG score assesses the model's editing capability, Face ID similarity and DINOv2 similarity assess its preservation of human identities and objects, and the Aesthetic Score evaluates the aesthetics of the generated images. XVerseBench aims to provide a more challenging and holistic evaluation framework for state-of-the-art multi-subject controllable text-to-image generation models.

## Usage

1. Download **DreamBench++** from [https://dreambenchplus.github.io/](https://dreambenchplus.github.io/) and place it into the `data/DreamBench++` directory.
2. Run the following commands to rename and segment the images:
   ```bash
   python assets/rename.py
   python assets/segmentation.py
   ```

## Citation
If XVerseBench is helpful, please help to ⭐ the repo.

If you find this project useful for your research, please consider citing our paper:
```bibtex
@article{chen2025xverse,
  title={XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation},
  author={Chen, Bowen and Zhao, Mengyi and Sun, Haomiao and Chen, Li and Wang, Xu and Du, Kang and Wu, Xinglong},
  journal={arXiv preprint arXiv:2506.21416},
  year={2025}
}
```


> Disclaimer:
>
> Your access to and use of this dataset are at your own risk. We do not guarantee the accuracy of this dataset. The dataset is provided “as is” and we make no warranty or representation to you with respect to it and we expressly disclaim, and hereby expressly waive, all warranties, express, implied, statutory or otherwise. This includes, without limitation, warranties of quality, performance, merchantability or fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable.
> 
> In no event will we be liable to you on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this public license or use of the licensed material.
>
> The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.



================================================
FILE: assets/crop_faces.py
================================================
import os
import face_recognition
from PIL import Image, ImageOps
import numpy as np

def detect_and_crop_faces(input_dir, output_dir):
    # Ensure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Iterate over all files in the input directory
    for filename in os.listdir(input_dir):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename.replace('.png', '.jpg'))

            # Load the image and composite any transparent background onto white
            image = Image.open(input_path).convert("RGBA")
            background = Image.new("RGBA", image.size, "WHITE")
            alpha_composite = Image.alpha_composite(background, image).convert("RGB")

            # Add a white border; the padding is 10 pixels here and can be adjusted as needed
            padded_image = ImageOps.expand(alpha_composite, border=10, fill='white')

            # Try face detection at several image scales
            scales = [0.6, 0.4, 0.2]
            face_locations = []
            for scale in scales:
                resized_image = padded_image.resize((int(padded_image.width * scale), int(padded_image.height * scale)), Image.LANCZOS)
                image_np = np.array(resized_image)
                # Use the cnn model for detection
                face_locations = face_recognition.face_locations(image_np, model="cnn")
                if face_locations:
                    # Adjust the detected face positions to the original image size
                    face_locations = [(int(top / scale), int(right / scale), int(bottom / scale), int(left / scale)) for top, right, bottom, left in face_locations]
                    break

            if face_locations:
                # Assume the first detected face is the one to crop
                top, right, bottom, left = face_locations[0]
                height = bottom - top
                width = right - left

                # Compute the expanded crop region
                new_top = max(0, int(top - height * 0.3))
                new_bottom = min(np.array(padded_image).shape[0], int(bottom + height * 0.3))
                new_left = max(0, int(left - width * 0.3))
                new_right = min(np.array(padded_image).shape[1], int(right + width * 0.3))

                face_image = np.array(padded_image)[new_top:new_bottom, new_left:new_right]
                # Convert the NumPy array back to a PIL image
                face_pil = Image.fromarray(face_image)
                # Save the cropped face image
                face_pil.save(output_path)
                print(f"Cropped and saved: {output_path}")
            else:
                print(f"No face detected in {input_path}")

if __name__ == "__main__":
    input_directory = "/mnt/bn/yg-butterfly-algo/personal/sunhm/code/XVerse/assets/XVerseBench_seg/human_seg"
    output_directory = "/mnt/bn/yg-butterfly-algo/personal/sunhm/code/XVerse/assets/XVerseBench_seg/human"
    detect_and_crop_faces(input_directory, output_directory)


================================================
FILE: assets/rename.py
================================================
import os
import shutil

split = [("live_subject/animal", "animal"), ("object", "object")]

# Define directory paths
caption_dir_base = './data/DreamBench_plus/captions'
image_dir_base = './data/DreamBench_plus/images'
new_image_dir_base = './data/XVerseBench_rename'

for s, ts in split:
    caption_dir = os.path.join(caption_dir_base, s)
    image_dir = os.path.join(image_dir_base, s)
    new_image_dir = os.path.join(new_image_dir_base, ts)

    # Create the target directory if it does not exist
    if not os.path.exists(new_image_dir):
        os.makedirs(new_image_dir)

    # Collect all caption files
    caption_files = sorted([f for f in os.listdir(caption_dir) if f.endswith('.txt')])

    for caption_file in caption_files:
        # Extract the index
        index = os.path.splitext(caption_file)[0]
        # Build the full path of the caption file
        caption_file_path = os.path.join(caption_dir, caption_file)
        # Build the path of the corresponding image file
        image_file_name = f'{index}.jpg'
        image_file_path = os.path.join(image_dir, image_file_name)

        # Check whether the image file exists
        if os.path.exists(image_file_path):
            # Read the caption file content
            with open(caption_file_path, 'r', encoding='utf-8') as f:
                caption = f.read().split('\n')[0].strip()

            # Generate the new filename
            new_file_name = f'{index}_{caption}.jpg'
            new_file_path_in_new_dir = os.path.join(new_image_dir, new_file_name)

            # Copy and rename the file
            shutil.copy2(image_file_path, new_file_path_in_new_dir)
            print(f'File {image_file_path} copied and renamed to {new_file_path_in_new_dir}')
        else:
            print(f'Corresponding image file not found: {image_file_path}')


old_human_index = ['00', '05', '06', '09', '12', '13', '14', '16', '17']

# Additional file mappings
new_files = [
    "object/65_anime space ranger.jpg", "object/66_anime girl.jpg", "object/67_pixelated warrior.jpg",
    "object/68_anime girl.jpg", "object/69_anime samurai.jpg", "object/70_anime girl.jpg",
    "object/71_anime Spider-Man.jpg", "object/72_Avatar.jpg", "object/73_anime man.jpg"
]

# Copy the additional files (the loop variable is renamed to avoid
# shadowing the old_human_index list defined above)
for old_index, new_file in zip(old_human_index, new_files):
    # Build the original image file path
    original_image_path = os.path.join(image_dir_base, "live_subject/human", f"{old_index}.jpg")
    # Build the new image file path
    new_image_path = os.path.join(new_image_dir_base, new_file)
    
    # Create the directory for the new file if it does not exist
    new_image_dir = os.path.dirname(new_image_path)
    if not os.path.exists(new_image_dir):
        os.makedirs(new_image_dir)
    
    # Check whether the original image file exists
    if os.path.exists(original_image_path):
        # Copy the file
        shutil.copy2(original_image_path, new_image_path)
        print(f'File {original_image_path} copied to {new_image_path}')
    else:
        print(f'Corresponding image file not found: {original_image_path}')

================================================
FILE: assets/segmentation.py
================================================
from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxes
from eval.tools.florence_sam import ObjectDetector
import torch
import os
from PIL import Image
import numpy as np

def merge_instances(orig_img, indices, ins_bboxes, ins_images):
    orig_image_width, orig_image_height = orig_img.width, orig_img.height
    final_img = Image.new("RGB", (orig_image_width, orig_image_height), color=(255, 255, 255))
    bboxes = []
    for i in indices:
        bbox = np.array(ins_bboxes[i], dtype=int).tolist()
        bboxes.append(bbox)
        
        img = cv2pil(ins_images[i])
        mask = (np.array(img)[..., :3] != 255).any(axis=-1)
        mask = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
        final_img.paste(img, (bbox[0], bbox[1]), mask)
    
    bbox = merge_bboxes(bboxes)
    img = final_img.crop(bbox)
    return img, bbox

dtype = torch.bfloat16
device = "cuda"
detector = ObjectDetector(device)
def det_seg_img(image, label):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    instance_result_dict = detector.get_multiple_instances(image, label, min_size=image.size[0]//20)
    indices = list(range(len(instance_result_dict["instance_images"])))
    ins, bbox = merge_instances(image, indices, instance_result_dict["instance_bboxes"], instance_result_dict["instance_images"])
    return ins

def segment_images_in_folder(input_folder, output_folder):
    """
    Segment every image in the input folder and save the results to the output folder.

    :param input_folder: path to the folder of input images
    :param output_folder: path to the folder where segmentation results are saved
    """
    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # Walk the input folder and all of its subfolders
    for root, _, filenames in os.walk(input_folder):
        for filename in filenames:
            # Check whether the file is an image
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                file_path = os.path.join(root, filename)
                try:
                    # Extract the label from the filename, assuming the format "<index>_<label>.png"
                    label = filename.split('_')[-1].rsplit('.', 1)[0].strip()
                    # Run image segmentation
                    segmentation_result = det_seg_img(file_path, label)
                    # Build the output path, keeping the original filename
                    relative_path = os.path.relpath(root, input_folder)
                    output_subfolder = os.path.join(output_folder, relative_path)
                    os.makedirs(output_subfolder, exist_ok=True)
                    output_path = os.path.join(output_subfolder, filename)
                    # Save the segmentation result
                    if isinstance(segmentation_result, Image.Image):
                        segmentation_result.save(output_path)
                    else:
                        # Assume segmentation_result can be converted to a PIL Image
                        Image.fromarray(segmentation_result).save(output_path)
                except Exception as e:
                    print(f"Error while processing file {file_path}: {e}")


# Usage example
if __name__ == "__main__":
    input_folder = "./assets/XVerseBench_rename"
    output_folder = "./assets/XVerseBench"
    segment_images_in_folder(input_folder, output_folder)


================================================
FILE: eval/eval_scripts/run_eval_multi.sh
================================================
export config_path="./train/config/XVerse_config_INF.yaml"
export model_checkpoint="./checkpoints/XVerse"
export target_size=768
export condition_size=256
export test_list_name="XVerseBench_multi"
export save_name="./eval/XVerseBench_multi"

ports=(`echo $METIS_WORKER_0_PORT | tr ',' ' '`)
port=${ports[-1]}

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_gen_split_idip \
    --config_name "$config_path" \
    --model_path "$model_checkpoint" \
    --target_size "$target_size" \
    --condition_size "$condition_size" \
    --save_name "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_dpg_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_aes_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_face_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_sam-dino_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

python \
    -m eval.tools.log_scores \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"


================================================
FILE: eval/eval_scripts/run_eval_single.sh
================================================
export config_path="./train/config/XVerse_config_INF.yaml"
export model_checkpoint="./checkpoints/XVerse"
export target_size=768
export condition_size=256
export test_list_name="XVerseBench_single"
export save_name="./eval/XVerseBench_singleidip"

ports=(`echo $METIS_WORKER_0_PORT | tr ',' ' '`)
port=${ports[-1]}

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_gen_split_idip \
    --config_name "$config_path" \
    --model_path "$model_checkpoint" \
    --target_size "$target_size" \
    --condition_size "$condition_size" \
    --save_name "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_dpg_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_aes_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_face_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

accelerate launch \
    --main_process_port $port \
    -m eval.tools.idip_sam-dino_score \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"

python \
    -m eval.tools.log_scores \
    --input_dir "$save_name" \
    --test_list_name "$test_list_name"


================================================
FILE: eval/grounded_sam/florence2/config.json
================================================
{
  "_name_or_path": "florence2",
  "architectures": [
    "Florence2ForConditionalGeneration"
  ],
  "auto_map": {
    "AutoConfig": "configuration_florence2.Florence2Config",
    "AutoModelForCausalLM": "modeling_florence2.Florence2ForConditionalGeneration"
  },
  "bos_token_id": 0,
  "eos_token_id": 2,
  "ignore_index": -100,
  "model_type": "florence2",
  "pad_token_id": 1,
  "projection_dim": 1024,
  "text_config": {
      "vocab_size": 51289,
      "activation_dropout": 0.1,
      "activation_function": "gelu",
      "add_bias_logits": false,
      "add_final_layer_norm": false,
      "attention_dropout": 0.1,
      "bos_token_id": 0,
      "classif_dropout": 0.1,
      "classifier_dropout": 0.0,
      "d_model": 1024,
      "decoder_attention_heads": 16,
      "decoder_ffn_dim": 4096,
      "decoder_layerdrop": 0.0,
      "decoder_layers": 12,
      "decoder_start_token_id": 2,
      "dropout": 0.1,
      "early_stopping": true,
      "encoder_attention_heads": 16,
      "encoder_ffn_dim": 4096,
      "encoder_layerdrop": 0.0,
      "encoder_layers": 12,
      "eos_token_id": 2,
      "forced_eos_token_id": 2,
      "forced_bos_token_id": 0,
      "gradient_checkpointing": false,
      "init_std": 0.02,
      "is_encoder_decoder": true,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1,
        "LABEL_2": 2
      },
      "max_position_embeddings": 1024,
      "no_repeat_ngram_size": 3,
      "normalize_before": false,
      "num_hidden_layers": 12,
      "pad_token_id": 1,
      "scale_embedding": false,
      "num_beams": 3
  },
  "vision_config": {
    "model_type": "davit",
    "drop_path_rate": 0.1,  
    "patch_size": [7, 3, 3, 3],  
    "patch_stride": [4, 2, 2, 2],  
    "patch_padding": [3, 1, 1, 1],  
    "patch_prenorm": [false, true, true, true],  
    "enable_checkpoint": false,  
    "dim_embed": [256, 512, 1024, 2048],  
    "num_heads": [8, 16, 32, 64],  
    "num_groups": [8, 16, 32, 64],  
    "depths": [1, 1, 9, 1],  
    "window_size": 12,
    "projection_dim": 1024,
    "visual_temporal_embedding": {
        "type": "COSINE",
        "max_temporal_embeddings": 100
    },
    "image_pos_embed": {
        "type": "learned_abs_2d",
        "max_pos_embeddings": 50
    },
    "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
  },
  "vocab_size": 51289,
  "torch_dtype": "float16",
  "transformers_version": "4.41.0.dev0",
  "is_encoder_decoder": true
}

================================================
FILE: eval/grounded_sam/florence2/configuration_florence2.py
================================================
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Florence-2 configuration"""

import warnings

from typing import Optional

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class Florence2VisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the 
    defaults will yield a similar configuration to that of the Florence2VisionModel architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            The dropout rate of the drop path layer.
        patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
            The patch size of the image.
        patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
            The patch stride of the image.
        patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
            The patch padding of the image.
        patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
            Whether to apply layer normalization before the patch embedding layer.
        enable_checkpoint (`bool`, *optional*, defaults to False):
            Whether to enable checkpointing.
        dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
            The dimension of the embedding layer.
        num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
            The number of attention heads.
        num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
            The number of groups.
        depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
            The depth of the model.
        window_size (`int`, *optional*, defaults to 12):
            The window size of the model.
        projection_dim (`int`, *optional*, defaults to 1024):
            The dimension of the projection layer.
        visual_temporal_embedding (`dict`, *optional*):
            The configuration of the visual temporal embedding.
        image_pos_embed (`dict`, *optional*):
            The configuration of the image position embedding.
        image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
            The source of the image feature.
    Example:

    ```python
    >>> from transformers import Florence2VisionConfig, Florence2VisionModel

    >>> # Initializing a Florence2 Vision style configuration
    >>> configuration = Florence2VisionConfig()

    >>> # Initializing a model (with random weights)
    >>> model = Florence2VisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "florence2_vision"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        drop_path_rate=0.1,
        patch_size=[7, 3, 3, 3],
        patch_stride=[4, 2, 2, 2],
        patch_padding=[3, 1, 1, 1],
        patch_prenorm=[False, True, True, True],
        enable_checkpoint=False,
        dim_embed=[256, 512, 1024, 2048],
        num_heads=[8, 16, 32, 64],
        num_groups=[8, 16, 32, 64],
        depths=[1, 1, 9, 1],
        window_size=12,
        projection_dim=1024,
        visual_temporal_embedding=None,
        image_pos_embed=None,
        image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
        **kwargs,
    ):
        self.drop_path_rate = drop_path_rate
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.patch_padding = patch_padding
        self.patch_prenorm = patch_prenorm
        self.enable_checkpoint = enable_checkpoint
        self.dim_embed = dim_embed
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.depths = depths
        self.window_size = window_size
        self.projection_dim = projection_dim
        self.visual_temporal_embedding = visual_temporal_embedding
        self.image_pos_embed = image_pos_embed
        self.image_feature_source = image_feature_source

        super().__init__(**kwargs)



class Florence2LanguageConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the BART
    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Florence2LanguageModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        num_labels (`int`, *optional*, defaults to 3):
            The number of labels to use in [`Florence2LanguageForSequenceClassification`].
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.

    Example:

    ```python
    >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel

    >>> # Initializing a Florence2 Language style configuration
    >>> configuration = Florence2LanguageConfig()

    >>> # Initializing a model (with random weights)
    >>> model = Florence2LanguageModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "florence2_language"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=51289,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )

class Florence2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
    Florence-2 model according to the specified arguments, defining the model architecture. 

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Florence2VisionConfig`,  *optional*):
            Custom vision config or dict
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. 
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
        projection_dim (`int`, *optional*, defaults to 1024):
            Dimension of the multimodal projection space.

    Example:

    ```python
    >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig

    >>> # Initializing a clip-like vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Bart config
    >>> text_config = BartConfig()

    >>> # Initializing a Florence-2 configuration
    >>> configuration = Florence2Config(vision_config, text_config)

    >>> # Initializing a model from the florence-2 configuration
    >>> model = Florence2ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "florence2"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        vocab_size=51289,
        projection_dim=1024,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        if vision_config is not None:
            vision_config = PretrainedConfig(**vision_config)
        self.vision_config = vision_config

        self.text_config = text_config
        if text_config is not None:
            self.text_config = Florence2LanguageConfig(**text_config)


        super().__init__(**kwargs)



================================================
FILE: eval/grounded_sam/florence2/generation_config.json
================================================
{
    "num_beams": 3,
    "early_stopping": false
}

================================================
FILE: eval/grounded_sam/florence2/modeling_florence2.py
================================================
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" PyTorch Florence-2 model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import math
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import CrossEntropyLoss 
from collections import OrderedDict
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_

from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_florence2 import Florence2Config 
from .configuration_florence2 import Florence2LanguageConfig
from .configuration_florence2 import Florence2VisionConfig


from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)


if is_flash_attn_2_available():
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Florence2Config"

class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
        self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels) 
        returns: (batch_size, height, width, embedding_dim * 2)
        """
        if len(pixel_values.shape) != 4:
            raise ValueError('pixel_values must be a 4D tensor')
        height, width = pixel_values.shape[1:3]
        width_values = torch.arange(width, device=pixel_values.device)
        height_values = torch.arange(height, device=pixel_values.device)
        x_emb = self.column_embeddings(width_values)
        y_emb = self.row_embeddings(height_values)
        # (height, width, embedding_dim)
        pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
        # (embedding_dim, height, width)
        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)
        # (batch_size, embedding_dim, height, width)
        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
        # (batch_size, height, width, embedding_dim)
        pos = pos.permute(0, 2, 3, 1)
        return pos
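# The row and column tables above split the channel budget in two, so the
# concatenated output has exactly `embedding_dim` channels per position. A
# pure-Python sketch of that bookkeeping (stdlib only, no torch; the helper
# name is illustrative, not part of the model code):

```python
# How LearnedAbsolutePositionEmbedding2D splits channels: rows get
# embedding_dim // 2, columns get the remainder (which absorbs the odd
# channel, if any), and concatenation restores embedding_dim exactly.
def split_channels(embedding_dim: int) -> tuple[int, int]:
    row_dim = embedding_dim // 2
    col_dim = embedding_dim - row_dim
    return row_dim, col_dim

for dim in (256, 257):
    row_dim, col_dim = split_channels(dim)
    assert row_dim + col_dim == dim
```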

class PositionalEmbeddingCosine1D(nn.Module):
    """
    This class implements a very simple positional encoding. It follows closely
    the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html

    Args:
        embed_dim: The dimension of the embeddings.
        max_seq_len: The maximum sequence length for which to precompute the
            positional encodings.
    """
    def __init__(
            self,
            embed_dim: int = 512,
            max_seq_len: int = 1024) -> None:
        super(PositionalEmbeddingCosine1D, self).__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        # Generate the sinusoidal arrays.
        factor = math.log(10000)
        denominator = torch.exp(
            -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)
        # Matrix where rows correspond to a positional embedding as a function
        # of the position index (i.e., the row index).
        frequencies = \
            torch.arange(0, self.max_seq_len) \
            .reshape(self.max_seq_len, 1) * denominator
        pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
        # Even entries get the sine values, odd entries the cosine values.
        pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
        pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
        # Save the positional embeddings in a constant buffer.
        self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.

        Returns the positional embeddings, of shape [1, T, D] for batched
        input or [T, D] otherwise.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len
        pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
        # Adapt pre-computed positional embeddings to the input.
        if shape_len == 3:
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
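# The buffer construction above can be mirrored without torch. A stdlib
# sketch (assuming an even `embed_dim`, as the interleaved sin/cos assignment
# requires) that computes the same table entry-by-entry:

```python
import math

def sinusoidal_table(max_seq_len: int, embed_dim: int) -> list[list[float]]:
    """Pure-Python version of the table in PositionalEmbeddingCosine1D:
    even columns hold sin(pos * w_i), odd columns hold cos(pos * w_i),
    with w_i = exp(-ln(10000) * i / embed_dim) for i = 0, 2, 4, ..."""
    table = []
    for pos in range(max_seq_len):
        row = [0.0] * embed_dim
        for i in range(0, embed_dim, 2):
            w = math.exp(-math.log(10000.0) * i / embed_dim)
            row[i] = math.sin(pos * w)
            row[i + 1] = math.cos(pos * w)
        table.append(row)
    return table

# position 0 is sin(0) = 0 in even slots and cos(0) = 1 in odd slots
assert sinusoidal_table(4, 6)[0] == [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```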


class LearnedAbsolutePositionEmbedding1D(nn.Module):
    """
    Learnable absolute positional embeddings for 1D sequences.

    Args:
        embedding_dim: The dimension of the embeddings.
        num_pos: The maximum number of positions to embed.
    """
    def __init__(
            self,
            embedding_dim: int = 512,
            num_pos: int = 1024) -> None:
        super(LearnedAbsolutePositionEmbedding1D, self).__init__()
        self.embeddings = nn.Embedding(num_pos, embedding_dim)
        self.num_pos = num_pos

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.

        Returns the positional embeddings, of shape [1, T, D] for batched
        input or [T, D] otherwise.
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.num_pos
        # [T, D]
        pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
        # Adapt pre-computed positional embeddings to the input.
        if shape_len == 3:
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds



class MySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
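# MySequential exists because the vision blocks pass (x, size) pairs between
# stages, which plain nn.Sequential cannot unpack. A minimal torch-free sketch
# of the same convention (AddOne and my_sequential are illustrative names):

```python
# Each module returns a tuple that is splatted into the next module's call.
class AddOne:
    def __call__(self, x, size):
        return x + 1, size

def my_sequential(modules, *inputs):
    for m in modules:
        inputs = m(*inputs) if isinstance(inputs, tuple) else m(inputs)
    return inputs

assert my_sequential([AddOne(), AddOne()], 0, (2, 3)) == (2, (2, 3))
```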


class PreNorm(nn.Module):
    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        shortcut = x
        if self.norm is not None:
            x, size = self.fn(self.norm(x), *args, **kwargs)
        else:
            x, size = self.fn(x, *args, **kwargs)

        if self.drop_path:
            x = self.drop_path(x)

        x = shortcut + x

        return x, size


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.net = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(in_features, hidden_features)),
            ("act", act_layer()),
            ("fc2", nn.Linear(hidden_features, out_features))
        ]))

    def forward(self, x, size):
        return self.net(x), size


class DepthWiseConv2d(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        self.dw = nn.Conv2d(
            dim_in, dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias
        )

    def forward(self, x, size):
        B, N, C = x.shape
        H, W = size
        assert N == H * W

        x = self.dw(x.transpose(1, 2).view(B, C, H, W))
        size = (x.size(-2), x.size(-1))
        x = x.flatten(2).transpose(1, 2)
        return x, size
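# DepthWiseConv2d's forward is mostly bookkeeping: a length-N token sequence
# with N == H * W is viewed as an H x W grid for the convolution, then
# flattened back to tokens. A stdlib sketch of that round trip on nested
# lists, with an identity "conv" (helper names are illustrative):

```python
# tokens -> H x W grid -> tokens, mirroring the view/flatten pair in forward.
def tokens_to_grid(tokens, H, W):
    assert len(tokens) == H * W
    return [tokens[r * W:(r + 1) * W] for r in range(H)]

def grid_to_tokens(grid):
    return [v for row in grid for v in row]

tokens = list(range(6))
assert grid_to_tokens(tokens_to_grid(tokens, 2, 3)) == tokens
```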


class ConvEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(
        self,
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None,
        pre_norm=True
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )

        dim_norm = in_chans if pre_norm else embed_dim
        self.norm = norm_layer(dim_norm) if norm_layer else None

        self.pre_norm = pre_norm

    def forward(self, x, size):
        H, W = size
        if len(x.size()) == 3:
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(
                x, 'b (h w) c -> b c h w',
                h=H, w=W
            )

        x = self.proj(x)

        _, _, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm and not self.pre_norm:
            x = self.norm(x)

        return x, (H, W)


class ChannelAttention(nn.Module):

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()

        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * (float(N) ** -0.5)
        attention = q.transpose(-1, -2) @ k
        attention = attention.softmax(dim=-1)
        x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x, size


class ChannelBlock(nn.Module):

    def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.channel_attn = PreNorm(
            norm_layer(dim),
            ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
            drop_path
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
            drop_path
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size


def window_partition(x, window_size: int):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    B = batch_size 
    # this will cause onnx conversion failed for dynamic axis, because treated as constant
    # int(windows.shape[0] / (H * W / window_size / window_size)) 
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
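# window_partition and window_reverse are exact inverses when H and W are
# multiples of window_size (WindowAttention pads to guarantee this). A
# pure-Python sketch on a single-channel grid of nested lists, showing the
# round trip (function names are illustrative, not the model code):

```python
# Split an H x W grid into ws x ws windows, row-block-major, then rebuild it.
def partition(grid, ws):
    H, W = len(grid), len(grid[0])
    windows = []
    for i in range(0, H, ws):
        for j in range(0, W, ws):
            windows.append([row[j:j + ws] for row in grid[i:i + ws]])
    return windows

def reverse(windows, ws, H, W):
    grid = [[None] * W for _ in range(H)]
    idx = 0
    for i in range(0, H, ws):
        for j in range(0, W, ws):
            win = windows[idx]
            idx += 1
            for r in range(ws):
                for c in range(ws):
                    grid[i + r][j + c] = win[r][c]
    return grid

grid = [[r * 4 + c for c in range(4)] for r in range(4)]
assert reverse(partition(grid, 2), 2, 4, 4) == grid
```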


class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size, qkv_bias=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = float(head_dim) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):

        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x = window_partition(x, self.window_size)
        x = x.view(-1, self.window_size * self.window_size, C)

        # W-MSA/SW-MSA
        # attn_windows = self.attn(x_windows)

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)

        # merge windows
        x = x.view(
            -1, self.window_size, self.window_size, C
        )
        x = window_reverse(x, B, self.window_size, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        return x, size


class SpatialBlock(nn.Module):

    def __init__(self, dim, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
        super().__init__()

        drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
        self.window_attn = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
            drop_path
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
            drop_path
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size


class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        depths (tuple(int)): Number of (spatial, channel) block pairs in each stage. Default: (1, 1, 3, 1).
        num_classes (int): Number of classes for classification head. Default: 1000.
        patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
        patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
        patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
        patch_prenorm (tuple(bool)): If True, perform normalization before the convolution layer. Default: (False, False, False, False).
        embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
        num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
        num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, enable checkpointing. Default: False.
        conv_at_attn (bool): If True, perform a depthwise convolution before the attention layer. Default: True.
        conv_at_ffn (bool): If True, perform a depthwise convolution before the FFN layer. Default: True.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
     ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        num_stages = len(embed_dims)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]

        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i]
            )
            convs.append(conv_embed)

            block = MySequential(
                *[
                    MySequential(OrderedDict([
                        (
                            'spatial_block', SpatialBlock(
                                embed_dims[i],
                                num_heads[i],
                                window_size,
                                drop_path_rate=dpr[depth_offset+j*2],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        ),
                        (
                            'channel_block', ChannelBlock(
                                embed_dims[i],
                                num_groups[i],
                                drop_path_rate=dpr[depth_offset+j*2+1],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                            )
                        )
                    ])) for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i]*2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    @property
    def dim_out(self):
        return self.embed_dims[-1]

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.02)
            for name, _ in m.named_parameters():
                if name in ['bias']:
                    nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def forward_features_unpool(self, x):
        """
        forward until avg pooling 
        Args:
            x (_type_): input image tensor
        """
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                x, input_size = checkpoint.checkpoint(block, x, input_size)
            else:
                x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        x = self.forward_features_unpool(x)

        # (batch_size, embed_dim, 1)
        x = self.avgpool(x.transpose(1, 2))
        # (batch_size, embed_dim)
        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
    
    @classmethod
    def from_config(cls, config):
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )




if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func

# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
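# What _get_unpad_data computes from a 0/1 padding mask, sketched with plain
# lists (no torch): per-sequence lengths, flat indices of real tokens, and
# the cumulative offsets (cu_seqlens) that flash-attn's varlen kernels expect.
# The function name below is illustrative:

```python
def unpad_data(attention_mask):
    seqlens = [sum(row) for row in attention_mask]
    flat = [v for row in attention_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v == 1]
    cu_seqlens = [0]
    for n in seqlens:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return indices, cu_seqlens, max(seqlens)

mask = [[1, 1, 1, 0], [1, 1, 0, 0]]  # two sequences, lengths 3 and 2
indices, cu_seqlens, max_len = unpad_data(mask)
assert indices == [0, 1, 2, 4, 5]
assert cu_seqlens == [0, 3, 5]
assert max_len == 3
```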


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
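# shift_tokens_right builds the decoder input from the labels: shift one step
# right, prepend decoder_start_token_id, and replace any ignore-index (-100)
# labels with pad_token_id. A list-based sketch (the function name is
# illustrative):

```python
def shift_right(input_ids, pad_token_id, decoder_start_token_id):
    shifted = [[decoder_start_token_id] + row[:-1] for row in input_ids]
    return [[pad_token_id if tok == -100 else tok for tok in row] for row in shifted]

labels = [[5, 6, 7, -100]]
assert shift_right(labels, pad_token_id=1, decoder_start_token_id=2) == [[2, 5, 6, 7]]
```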


class Florence2LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Florence2 is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)

        return super().forward(positions + self.offset)
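# The BART-style position-id bookkeeping above, sketched without torch:
# absolute positions start at past_key_values_length (for incremental
# decoding) and every id is shifted by offset=2 before the table lookup,
# which is why __init__ allocates num_embeddings + offset rows. The helper
# name is illustrative:

```python
def position_embedding_ids(seq_len, past_key_values_length=0, offset=2):
    return [past_key_values_length + i + offset for i in range(seq_len)]

# prefill of 4 tokens, then one decode step at absolute position 4
assert position_embedding_ids(4) == [2, 3, 4, 5]
assert position_embedding_ids(1, past_key_values_length=4) == [6]
```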


class Florence2ScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class Florence2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[Florence2LanguageConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class Florence2FlashAttention2(Florence2Attention):
    """
    Florence2 flash attention module. This module inherits from `Florence2Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
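
    # The comment above distinguishes two causal-mask alignments. The hypothetical
    # helper below (illustrative only, not part of the model) shows the difference:
    # with a KV cache (k_len > q_len), "bottom-right" alignment lets the newest
    # query token attend to every cached key, while "top-left" wrongly cuts it off.
    #
    # ```python
    # def causal_mask(q_len, k_len, align="bottom-right"):
    #     """Boolean causal mask of shape (q_len, k_len).
    #
    #     "bottom-right" (flash_attn>=2.1): query i attends to keys
    #     j <= i + (k_len - q_len), so the last query row sees all keys.
    #     "top-left" (flash_attn<2.1): query i attends to keys j <= i only.
    #     """
    #     offset = (k_len - q_len) if align == "bottom-right" else 0
    #     return [[j <= i + offset for j in range(k_len)] for i in range(q_len)]
    # ```
    #
    # For square masks (q_len == k_len) the two conventions coincide, which is why
    # only the cached-decoding case (q_seqlen != k_seqlen) was affected.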

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Florence2FlashAttention2 does not support output_attentions
        if output_attentions:
            raise ValueError("Florence2FlashAttention2 does not support output_attentions")

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        # get query proj
        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
        elif is_cross_attention:
            # cross_attentions
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
        else:
            # self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        # In this flash-attention layout key_states is (bsz, kv_seq_len, num_heads, head_dim),
        # and any cached keys were already concatenated above, so the KV length is simply dim 1.
        kv_seq_len = key_states.shape[1]

        # In PEFT, we usually cast the layer norms to float32 for training stability,
        # so the input hidden states may get silently cast to float32. Hence, we
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended not to cast the
        # LayerNorms to fp32. (LlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast to float32; this might be because"
                f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                The length of the query sequence (number of target tokens).
            dropout (`float`):
                Attention dropout probability
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to `1 / sqrt(head_dim)`
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
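
    # `_upad_input` relies on `_get_unpad_data`-style bookkeeping: flatten the
    # batch, keep only non-pad positions, and record cumulative per-sample lengths
    # for the varlen flash-attention kernels. A minimal pure-Python sketch of that
    # bookkeeping (hypothetical helper, illustrative names only):
    #
    # ```python
    # def get_unpad_data(attention_mask):
    #     """attention_mask: list of lists of 0/1, shape (batch, seq_len).
    #
    #     Returns (indices, cu_seqlens, max_seqlen):
    #       indices    - flat positions of non-pad tokens in the (b * s) layout
    #       cu_seqlens - cumulative per-sample lengths, prefixed with 0, as
    #                    expected by varlen attention kernels
    #       max_seqlen - longest unpadded sequence in the batch
    #     """
    #     seq_len = len(attention_mask[0])
    #     seqlens = [sum(row) for row in attention_mask]
    #     indices = [
    #         b * seq_len + s
    #         for b, row in enumerate(attention_mask)
    #         for s, v in enumerate(row)
    #         if v
    #     ]
    #     cu_seqlens = [0]
    #     for n in seqlens:
    #         cu_seqlens.append(cu_seqlens[-1] + n)
    #     return indices, cu_seqlens, max(seqlens)
    # ```
    #
    # e.g. a mask [[1,1,1,0],[1,1,0,0]] yields indices [0,1,2,4,5] and
    # cu_seqlens [0,3,5]: sample b's tokens live in indices[cu[b]:cu[b+1]].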


class Florence2SdpaAttention(Florence2Attention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        if output_attentions or layer_head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Florence2Model is using Florence2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value
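
    # The `is_causal` dispatch above can be summarized as a tiny predicate
    # (hypothetical helper, for illustration only): SDPA is asked to build its
    # own causal mask only when the layer is causal, no explicit mask was
    # supplied, and there is more than one query token, since single-token
    # decoding needs no mask at all.
    #
    # ```python
    # def sdpa_is_causal(is_causal_attr, attention_mask, tgt_len):
    #     """Mirror the dispatch in Florence2SdpaAttention.forward."""
    #     return bool(is_causal_attr and attention_mask is None and tgt_len > 1)
    # ```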


FLORENCE2_ATTENTION_CLASSES = {
    "eager": Florence2Attention,
    "sdpa": Florence2SdpaAttention,
    "flash_attention_2": Florence2FlashAttention2,
}


class Florence2EncoderLayer(nn.Module):
    def __init__(self, config: Florence2LanguageConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
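
    # The fp16 overflow guard at the end of the forward pass clamps activations
    # just inside the float16 range (max 65504) whenever inf/nan appears. A
    # scalar sketch of the same guard (hypothetical helper, illustrative only):
    #
    # ```python
    # FP16_MAX = 65504.0  # torch.finfo(torch.float16).max
    #
    # def clamp_fp16(x, margin=1000.0):
    #     """Replace overflow-prone magnitudes with a value a safe margin
    #     inside the fp16 range, mirroring the encoder layer's clamp."""
    #     clamp_value = FP16_MAX - margin
    #     return max(-clamp_value, min(clamp_value, x))
    # ```
    #
    # The margin of 1000 leaves headroom so subsequent additions (e.g. the next
    # residual connection) do not immediately overflow again.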


class Florence2DecoderLayer(nn.Module):
    def __init__(self, config: Florence2LanguageConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
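
    # The comments in this layer describe a 4-tuple per-layer cache:
    # (self_k, self_v, cross_k, cross_v). The forward pass slices it with
    # past_key_value[:2] for self-attention and past_key_value[-2:] for
    # cross-attention. A sketch with placeholder values (hypothetical helper):
    #
    # ```python
    # def split_layer_cache(past_key_value):
    #     """Split a per-layer cache tuple the way the forward pass does."""
    #     if past_key_value is None:
    #         return None, None
    #     return past_key_value[:2], past_key_value[-2:]
    #
    # layer_cache = ("self_k", "self_v", "cross_k", "cross_v")
    # self_kv, cross_kv = split_layer_cache(layer_cache)
    # ```
    #
    # Cross-attention K/V never grow during decoding (the encoder output is
    # fixed), which is why the first branch of the attention forward can reuse
    # them wholesale, while self-attention K/V are concatenated step by step.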



class Florence2LanguagePreTrainedModel(PreTrainedModel):
    config_class = Florence2LanguageConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
    _no_split_modules = [r"Florence2EncoderLayer", r"Florence2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs


class Florence2Encoder(Florence2LanguagePreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`Florence2EncoderLayer`].

    Args:
        config: Florence2LanguageConfig
        embed_tokens (nn.Embedding): input token embedding
    """

    def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = Florence2ScaledWordEmbedding(
            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = Florence2LearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(input)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self._use_sdpa and head_mask is None and not output_attentions:
                # output_attentions=True and head_mask cannot be supported when using SDPA; fall back on
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
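# The encoder loop above applies LayerDrop: during training each layer is
# skipped when a uniform draw falls below `self.layerdrop`. A minimal,
# stdlib-only sketch of that selection rule (the function name is
# illustrative and not part of this module):

```python
def layerdrop_kept_layers(num_layers, layerdrop_p, rand_draws):
    """Return indices of layers that survive LayerDrop for one forward pass.

    `rand_draws[i]` stands in for the torch.rand([]) draw at layer i; a
    layer is skipped when its draw is strictly below the drop probability.
    """
    return [i for i, r in enumerate(rand_draws[:num_layers]) if r >= layerdrop_p]
```

# At evaluation time (`self.training` is False) no draw happens and every
# layer runs, so inference is deterministic.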


class Florence2Decoder(Florence2LanguagePreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Florence2DecoderLayer`]

    Args:
        config: Florence2LanguageConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = Florence2ScaledWordEmbedding(
            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = Florence2LearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_shape = input.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
            # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                input_shape,
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
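# The decoder computes positions with a `past_key_values_length` offset so
# that incremental decoding with a KV cache continues the position sequence
# instead of restarting at zero. A stdlib-only sketch of the index
# arithmetic, assuming the BART-style convention in which learned position
# embeddings reserve 2 leading slots (an assumption about
# `Florence2LearnedPositionalEmbedding`, not verified here):

```python
def learned_position_indices(seq_len, past_key_values_length=0, offset=2):
    # Position of new token i in the full sequence is
    # past_key_values_length + i; the learned embedding table additionally
    # reserves `offset` leading slots, shifting every lookup index.
    return [past_key_values_length + i + offset for i in range(seq_len)]
```

# During cached generation `seq_len` is typically 1, so only the single new
# token's position is looked up each step.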


class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: Florence2LanguageConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = Florence2Encoder(config, self.shared)
        self.decoder = Florence2Decoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        # Unlike other models, Florence2 automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
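# `Florence2LanguageModel.forward` derives `decoder_input_ids` from
# `input_ids` via `shift_tokens_right` when none are given. A stdlib-only
# sketch of the standard BART-style shift (hypothetical helper name,
# operating on a plain list; `-100` ignored-label entries are mapped back to
# the pad token, mirroring what this model family expects):

```python
def shift_right(token_ids, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token, drop the final token, and replace
    # any -100 (ignored-label) entries with the pad token id.
    shifted = [decoder_start_token_id] + list(token_ids[:-1])
    return [pad_token_id if t == -100 else t for t in shifted]
```

# The shifted sequence feeds the decoder as teacher-forced input while the
# original sequence serves as the prediction target.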


class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    def __init__(self, config: Florence2LanguageConfig):
        super().__init__(config)
        self.model = Florence2LanguageModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
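The prefix-trimming rule in `prepare_inputs_for_generation` above can be sketched with plain integers. This is an illustrative helper (`trim_decoder_input_ids` is not part of the model), showing how many already-cached positions get dropped from `decoder_input_ids`:

```python
def trim_decoder_input_ids(seq_len: int, past_length: int) -> int:
    """Return how many prefix positions to drop from decoder_input_ids."""
    if seq_len > past_length:
        # The cache already covers past_length tokens; drop exactly those.
        return past_length
    # Some generation methods already pass only the last token; keep it.
    return seq_len - 1

# With a 10-token sequence and 7 cached positions, drop the 7 cached ones.
print(trim_decoder_input_ids(10, 7))  # 7
# If the caller already trimmed down to one token, keep that final token.
print(trim_decoder_input_ids(1, 7))   # 0
```

Either way, the slice `decoder_input_ids[:, remove_prefix_length:]` leaves only tokens whose key/value states are not yet in the cache.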

@dataclass
class Florence2Seq2SeqLMOutput(ModelOutput):
    """
    Base class for Florence-2 model outputs that also contains pre-computed hidden states that can speed up
    sequential decoding.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used, only the last hidden state of the sequence, of shape `(batch_size, 1,
            hidden_size)`, is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
            num_image_tokens, hidden_size)`.

            Image hidden states of the model, produced by the vision encoder.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


FLORENCE2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Florence2Config`] or [`Florence2VisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Florence-2 Model outputting raw hidden-states without any specific head on top.",
    FLORENCE2_START_DOCSTRING,
)
class Florence2PreTrainedModel(PreTrainedModel):
    config_class = Florence2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    @property
    def _supports_flash_attn_2(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        Flash Attention 2 or not.
        """
        return self.language_model._supports_flash_attn_2

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.language_model._supports_sdpa


FLORENCE2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`Florence2Processor`] uses
            [`CLIPImageProcessor`] for processing images).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    """The FLORENCE2 vision model without any head""",
    FLORENCE2_START_DOCSTRING,
)
class Florence2VisionModel(Florence2PreTrainedModel):
    def __init__(self, config: Florence2VisionConfig):
        super().__init__(config)
        assert config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config)

        self.post_init()
    
    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f'invalid image shape {pixel_values.shape}')
        return x


@add_start_docstrings(
    """The FLORENCE2 vision model with projection layer""",
    FLORENCE2_START_DOCSTRING,
)
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
    def __init__(self, config: Florence2VisionConfig):
        super().__init__(config)
        assert config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config)

        self._build_image_projection_layers(config)

        self.post_init()
    
    def _build_image_projection_layers(self, config):
        image_dim_out = config.dim_embed[-1]
        dim_projection = config.projection_dim
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f'invalid image shape {pixel_values.shape}')
        
        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h*w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
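The `forward` above infers a square `h x w` grid from the flattened token count before adding the learned 2D position embedding, and asserts the map really is square. A stand-alone sketch of that assumption (`infer_square_grid` is an illustrative helper, not in the model; the 576-token example assumes a 768x768 input at an effective stride of 32):

```python
import math

def infer_square_grid(num_tokens):
    """Recover the square (h, w) grid behind a flattened token sequence."""
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        # Mirrors the model's assert: only square feature maps are supported.
        raise ValueError("only square feature maps are supported")
    return side, side

print(infer_square_grid(576))  # (24, 24)
```

A non-square count (e.g. 577) raises, which is exactly what the in-model assert guards against.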



@add_start_docstrings(
    """The FLORENCE2 model which consists of a vision backbone and a language model.""",
    FLORENCE2_START_DOCSTRING,
)
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
    def __init__(self, config: Florence2Config):
        super().__init__(config)
        assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
        self.vision_tower = DaViT.from_config(config=config.vision_config)
        # remove unused layers 
        del self.vision_tower.head
        del self.vision_tower.norms

        self.vocab_size = config.vocab_size
        self._attn_implementation = config._attn_implementation
        self._build_image_projection_layers(config)

        language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)

        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()
    
    def _build_image_projection_layers(self, config):
        image_dim_out = config.vision_config.dim_embed[-1]
        dim_projection = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection)
        )
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

        self.image_feature_source = config.vision_config.image_feature_source

        # temporal embedding
        visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings']
            )
        else:
            raise NotImplementedError('Not implemented yet')

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds
    
    def _encode_image(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f'invalid image shape {pixel_values.shape}')
        
        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5)
            assert h * w == num_tokens, 'only support square feature maps for now'
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h*w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x 

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds 
    ):
        batch_size, image_token_length = image_features.size()[:-1]
        device = image_features.device
        image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

        # task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
        # task_prefix_attention_mask: [batch_size, context_length]
        if inputs_embeds is None:
            return image_features, image_attention_mask

        task_prefix_embeds = inputs_embeds
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)

        if len(task_prefix_attention_mask.shape) == 3:
            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

        # concat [image embeds, task prefix embeds]
        inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
        attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)

        return inputs_embeds, attention_mask
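Shape-wise, the merge above simply prepends the image tokens to the text prefix along the sequence dimension, with the attention mask growing by the same amount. A hypothetical shapes-only sketch (the 577-token figure assumes the `image_seq_length` from this repo's preprocessor config; `merged_shapes` is illustrative only):

```python
def merged_shapes(batch, n_img_tokens, n_txt_tokens, hidden):
    # Image embeddings come first, then the task-prefix text embeddings.
    inputs_embeds_shape = (batch, n_img_tokens + n_txt_tokens, hidden)
    # The mask covers both segments; image tokens are never masked out.
    attention_mask_shape = (batch, n_img_tokens + n_txt_tokens)
    return inputs_embeds_shape, attention_mask_shape

# 577 image tokens plus a 12-token task prefix -> 589 merged positions.
print(merged_shapes(2, 577, 12, 1024))  # ((2, 589, 1024), (2, 589))
```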


    @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration

        >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
        >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")

        >>> prompt = "<CAPTION>"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=100)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "A green car parked in front of a yellow building."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_features = None
        if inputs_embeds is None:
            # 1. Extract the input embeddings
            if input_ids is not None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            # 2. Merge text and images
            if pixel_values is not None:
                # (batch_size, num_image_tokens, hidden_size)
                image_features = self._encode_image(pixel_values)
                inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)

        if inputs_embeds is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)
        outputs = self.language_model(
            attention_mask=attention_mask,
            labels=labels,
            inputs_embeds=inputs_embeds,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs.logits
        logits = logits.float()
        loss = outputs.loss
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Florence2Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            image_hidden_states=image_features
        )

    def generate(
        self,
        input_ids, 
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
        ):

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            if input_ids is not None:
                inputs_embeds = self.get_input_embeddings()(input_ids)
            # 2. Merge text and images
            if pixel_values is not None:
                image_features = self._encode_image(pixel_values)
                inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
        
        return self.language_model.generate(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        pixel_values=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
        
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }
    
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Delegate to the language model, which shifts the labels right using
        # its configured pad and decoder-start token ids.
        return self.language_model.prepare_decoder_input_ids_from_labels(labels)

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)

================================================
FILE: eval/grounded_sam/florence2/preprocessor_config.json
================================================
{
  "auto_map": {
    "AutoProcessor": "processing_florence2.Florence2Processor"
   },
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format",
    "do_convert_rgb"
  ],
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "do_center_crop": false,
  "image_processor_type": "CLIPImageProcessor",
  "image_seq_length": 577,
  "image_mean": [0.485, 0.456, 0.406],
  "image_std":  [0.229, 0.224, 0.225],
  "processor_class": "Florence2Processor",
  "resample": 3,
  "size": {
    "height": 768,
    "width": 768
  },
  "crop_size": {
    "height": 768,
    "width": 768
  }
}

================================================
FILE: eval/grounded_sam/florence2/processing_florence2.py
================================================
# coding=utf-8
# Copyright 2024 Microsoft and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Florence-2.
"""

import re
import logging
from typing import List, Optional, Union
import numpy as np
import math

import torch

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, is_valid_image
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import (
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
from transformers import BartTokenizer, BartTokenizerFast
from transformers.utils import TensorType


logger = logging.getLogger(__name__)

# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")

# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem)


def _is_str_or_image(elem):
    return isinstance(elem, (str)) or is_image_or_image_url(elem)


class Florence2Processor(ProcessorMixin):
    r"""
    Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.

    [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
    [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.

    Args:
        image_processor ([`CLIPImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`BartTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")
        if not hasattr(image_processor, "image_seq_length"):
            raise ValueError("Image processor is missing an `image_seq_length` attribute.")

        self.image_seq_length = image_processor.image_seq_length

        tokens_to_add = {
                'additional_special_tokens': \
                    tokenizer.additional_special_tokens + \
                    ['<od>', '</od>', '<ocr>', '</ocr>'] + \
                    [f'<loc_{x}>' for x in range(1000)] + \
                    ['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
            }
        tokenizer.add_special_tokens(tokens_to_add)

        self.tasks_answer_post_processing_type = {
            '<OCR>': 'pure_text',
            '<OCR_WITH_REGION>': 'ocr',
            '<CAPTION>': 'pure_text',
            '<DETAILED_CAPTION>': 'pure_text',
            '<MORE_DETAILED_CAPTION>': 'pure_text',
            '<OD>': 'description_with_bboxes',
            '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
            '<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
            '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
            '<REGION_TO_SEGMENTATION>': 'polygons',
            '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
            '<REGION_TO_CATEGORY>': 'pure_text',
            '<REGION_TO_DESCRIPTION>': 'pure_text',
            '<REGION_TO_OCR>': 'pure_text',
            '<REGION_PROPOSAL>': 'bboxes'
        }

        self.task_prompts_without_inputs = {
            '<OCR>': 'What is the text in the image?',
            '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
            '<CAPTION>': 'What does the image describe?',
            '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
            '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
            '<OD>': 'Locate the objects with category name in the image.',
            '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
            '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
        }

        self.task_prompts_with_input = {
            '<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
            '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
            '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
            '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
            '<REGION_TO_CATEGORY>': 'What is the region {input}?',
            '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
            '<REGION_TO_OCR>': 'What text is in the region {input}?',
        }

        self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)


        super().__init__(image_processor, tokenizer)
    
    def _construct_prompts(self, text):
        # replace the task tokens with the task prompts if task token is in the text
        prompts = []
        for _text in text:
            # 1. fixed task prompts without additional inputs
            for task_token, task_prompt in self.task_prompts_without_inputs.items():
                if task_token in _text:
                    assert _text == task_token, f"Task token {task_token} should be the only token in the text."
                    _text = task_prompt
                    break
            # 2. task prompts with additional inputs 
            for task_token, task_prompt in self.task_prompts_with_input.items():
                if task_token in _text:
                    _text = task_prompt.format(input=_text.replace(task_token, ''))
                    break
            prompts.append(_text)
        return prompts

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        tokenize_newline_separately: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        do_resize: bool = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
        input_data_format: Optional[
            Union[str, "ChannelDimension"]  # noqa: F821
        ] = None,
        resample: "PILImageResampling" = None,  # noqa: F821
        do_convert_rgb: bool = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_rescale: bool = None,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            tokenize_newline_separately (`bool`, defaults to `True`):
                Adds a separately tokenized '\n' at the end of the prompt.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
              is provided, the `input_ids` will also contain the suffix input ids.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **labels** -- Labels compatible with training if `suffix` is not None
        """

        return_token_type_ids = False

        if images is None:
            raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
        if text is None:
            logger.warning_once(
                "You are using Florence-2 without a text prompt."
            )
            text = ""

        if isinstance(text, List) and isinstance(images, List):
            if len(images) < len(text):
                raise ValueError(
                    f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
                )
        if _is_str_or_image(text):
            text = [text]
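
The `_construct_prompts` method above turns Florence-2 task tokens into natural-language prompts: a fixed prompt for input-free tasks, or a template that absorbs the user-supplied text for tasks with input. A minimal self-contained sketch of that logic (dictionaries abbreviated to two entries each for illustration; the real processor defines the full tables in `__init__`):

```python
# Abbreviated versions of task_prompts_without_inputs / task_prompts_with_input.
TASKS_WITHOUT_INPUT = {
    "<CAPTION>": "What does the image describe?",
    "<OD>": "Locate the objects with category name in the image.",
}
TASKS_WITH_INPUT = {
    "<CAPTION_TO_PHRASE_GROUNDING>": "Locate the phrases in the caption: {input}",
    "<OPEN_VOCABULARY_DETECTION>": "Locate {input} in the image.",
}

def construct_prompt(text: str) -> str:
    # 1. Fixed prompts: the task token must be the whole input.
    for token, prompt in TASKS_WITHOUT_INPUT.items():
        if token in text:
            assert text == token, f"{token} must be the only token in the text"
            return prompt
    # 2. Templated prompts: strip the token, substitute the rest.
    for token, template in TASKS_WITH_INPUT.items():
        if token in text:
            return template.format(input=text.replace(token, ""))
    return text  # plain text passes through unchanged

print(construct_prompt("<CAPTION>"))
print(construct_prompt("<OPEN_VOCABULARY_DETECTION>a red car"))
```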
gitextract_6xpfra6z/

├── .gitignore
├── .gradio/
│   └── certificate.pem
├── LICENCE
├── README.md
├── assets/
│   ├── ReadMe.md
│   ├── crop_faces.py
│   ├── rename.py
│   └── segmentation.py
├── eval/
│   ├── eval_scripts/
│   │   ├── run_eval_multi.sh
│   │   └── run_eval_single.sh
│   ├── grounded_sam/
│   │   ├── florence2/
│   │   │   ├── config.json
│   │   │   ├── configuration_florence2.py
│   │   │   ├── generation_config.json
│   │   │   ├── modeling_florence2.py
│   │   │   ├── preprocessor_config.json
│   │   │   ├── processing_florence2.py
│   │   │   ├── tokenizer.json
│   │   │   ├── tokenizer_config.json
│   │   │   └── vocab.json
│   │   ├── grounded_sam2_florence2_autolabel_pipeline.py
│   │   └── sam2/
│   │       ├── __init__.py
│   │       ├── automatic_mask_generator.py
│   │       ├── build_sam.py
│   │       ├── configs/
│   │       │   ├── sam2/
│   │       │   │   ├── sam2_hiera_b+.yaml
│   │       │   │   ├── sam2_hiera_l.yaml
│   │       │   │   ├── sam2_hiera_s.yaml
│   │       │   │   └── sam2_hiera_t.yaml
│   │       │   ├── sam2.1/
│   │       │   │   ├── sam2.1_hiera_b+.yaml
│   │       │   │   ├── sam2.1_hiera_l.yaml
│   │       │   │   ├── sam2.1_hiera_s.yaml
│   │       │   │   └── sam2.1_hiera_t.yaml
│   │       │   └── sam2.1_training/
│   │       │       └── sam2.1_hiera_b+_MOSE_finetune.yaml
│   │       ├── csrc/
│   │       │   └── connected_components.cu
│   │       ├── modeling/
│   │       │   ├── __init__.py
│   │       │   ├── backbones/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── hieradet.py
│   │       │   │   ├── image_encoder.py
│   │       │   │   └── utils.py
│   │       │   ├── memory_attention.py
│   │       │   ├── memory_encoder.py
│   │       │   ├── position_encoding.py
│   │       │   ├── sam/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── mask_decoder.py
│   │       │   │   ├── prompt_encoder.py
│   │       │   │   └── transformer.py
│   │       │   ├── sam2_base.py
│   │       │   └── sam2_utils.py
│   │       ├── sam2_hiera_b+.yaml
│   │       ├── sam2_hiera_l.yaml
│   │       ├── sam2_hiera_s.yaml
│   │       ├── sam2_hiera_t.yaml
│   │       ├── sam2_image_predictor.py
│   │       ├── sam2_video_predictor.py
│   │       └── utils/
│   │           ├── __init__.py
│   │           ├── amg.py
│   │           ├── misc.py
│   │           └── transforms.py
│   └── tools/
│       ├── XVerseBench_multi.json
│       ├── XVerseBench_multi_DSG.json
│       ├── XVerseBench_single.json
│       ├── XVerseBench_single_DSG.json
│       ├── dino.py
│       ├── dpg_score.py
│       ├── face_id.py
│       ├── face_utils/
│       │   ├── face.py
│       │   └── face_recg.py
│       ├── florence_sam.py
│       ├── idip_aes_score.py
│       ├── idip_dpg_score.py
│       ├── idip_face_score.py
│       ├── idip_gen_split_idip.py
│       ├── idip_sam-dino_score.py
│       └── log_scores.py
├── inference_single_sample.py
├── requirements.txt
├── run_demo.sh
├── run_gradio.py
├── src/
│   ├── adapters/
│   │   ├── __init__.py
│   │   └── mod_adapters.py
│   ├── flux/
│   │   ├── block.py
│   │   ├── condition.py
│   │   ├── generate.py
│   │   ├── lora_controller.py
│   │   ├── pipeline_tools.py
│   │   └── transformer.py
│   └── utils/
│       ├── data_utils.py
│       ├── gpu_momory_utils.py
│       └── modulation_utils.py
└── train/
    └── config/
        ├── XVerse_config_INF.yaml
        └── XVerse_config_demo.yaml
SYMBOL INDEX (574 symbols across 48 files)

FILE: assets/crop_faces.py
  function detect_and_crop_faces (line 6) | def detect_and_crop_faces(input_dir, output_dir):

FILE: assets/segmentation.py
  function merge_instances (line 8) | def merge_instances(orig_img, indices, ins_bboxes, ins_images):
  function det_seg_img (line 28) | def det_seg_img(image, label):
  function segment_images_in_folder (line 36) | def segment_images_in_folder(input_folder, output_folder):

FILE: eval/grounded_sam/florence2/configuration_florence2.py
  class Florence2VisionConfig (line 25) | class Florence2VisionConfig(PretrainedConfig):
    method __init__ (line 83) | def __init__(
  class Florence2LanguageConfig (line 122) | class Florence2LanguageConfig(PretrainedConfig):
    method __init__ (line 202) | def __init__(
  class Florence2Config (line 272) | class Florence2Config(PretrainedConfig):
    method __init__ (line 317) | def __init__(

FILE: eval/grounded_sam/florence2/modeling_florence2.py
  class LearnedAbsolutePositionEmbedding2D (line 70) | class LearnedAbsolutePositionEmbedding2D(nn.Module):
    method __init__ (line 75) | def __init__(self, embedding_dim=256, num_pos=50):
    method forward (line 80) | def forward(self, pixel_values):
  class PositionalEmbeddingCosine1D (line 103) | class PositionalEmbeddingCosine1D(nn.Module):
    method __init__ (line 114) | def __init__(
    method forward (line 137) | def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
  class LearnedAbsolutePositionEmbedding1D (line 161) | class LearnedAbsolutePositionEmbedding1D(nn.Module):
    method __init__ (line 169) | def __init__(
    method forward (line 177) | def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
  class MySequential (line 203) | class MySequential(nn.Sequential):
    method forward (line 204) | def forward(self, *inputs):
  class PreNorm (line 213) | class PreNorm(nn.Module):
    method __init__ (line 214) | def __init__(self, norm, fn, drop_path=None):
    method forward (line 220) | def forward(self, x, *args, **kwargs):
  class Mlp (line 235) | class Mlp(nn.Module):
    method __init__ (line 236) | def __init__(
    method forward (line 252) | def forward(self, x, size):
  class DepthWiseConv2d (line 256) | class DepthWiseConv2d(nn.Module):
    method __init__ (line 257) | def __init__(
    method forward (line 275) | def forward(self, x, size):
  class ConvEmbed (line 286) | class ConvEmbed(nn.Module):
    method __init__ (line 290) | def __init__(
    method forward (line 315) | def forward(self, x, size):
  class ChannelAttention (line 335) | class ChannelAttention(nn.Module):
    method __init__ (line 337) | def __init__(self, dim, groups=8, qkv_bias=True):
    method forward (line 344) | def forward(self, x, size):
  class ChannelBlock (line 359) | class ChannelBlock(nn.Module):
    method __init__ (line 361) | def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True,
    method forward (line 381) | def forward(self, x, size):
  function window_partition (line 393) | def window_partition(x, window_size: int):
  function window_reverse (line 400) | def window_reverse(windows, batch_size: int, window_size: int, H: int, W...
  class WindowAttention (line 409) | class WindowAttention(nn.Module):
    method __init__ (line 410) | def __init__(self, dim, num_heads, window_size, qkv_bias=True):
    method forward (line 424) | def forward(self, x, size):
  class SpatialBlock (line 469) | class SpatialBlock(nn.Module):
    method __init__ (line 471) | def __init__(self, dim, num_heads, window_size,
    method forward (line 491) | def forward(self, x, size):
  class DaViT (line 502) | class DaViT(nn.Module):
    method __init__ (line 525) | def __init__(
    method dim_out (line 616) | def dim_out(self):
    method _init_weights (line 619) | def _init_weights(self, m):
    method forward_features_unpool (line 636) | def forward_features_unpool(self, x):
    method forward_features (line 651) | def forward_features(self, x):
    method forward (line 662) | def forward(self, x):
    method from_config (line 668) | def from_config(cls, config):
  function _get_unpad_data (line 690) | def _get_unpad_data(attention_mask):
  function shift_tokens_right (line 702) | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decod...
  class Florence2LearnedPositionalEmbedding (line 718) | class Florence2LearnedPositionalEmbedding(nn.Embedding):
    method __init__ (line 723) | def __init__(self, num_embeddings: int, embedding_dim: int):
    method forward (line 729) | def forward(self, input_ids: torch.Tensor, past_key_values_length: int...
  class Florence2ScaledWordEmbedding (line 740) | class Florence2ScaledWordEmbedding(nn.Embedding):
    method __init__ (line 745) | def __init__(self, num_embeddings: int, embedding_dim: int, padding_id...
    method forward (line 749) | def forward(self, input_ids: torch.Tensor):
  class Florence2Attention (line 753) | class Florence2Attention(nn.Module):
    method __init__ (line 756) | def __init__(
    method _shape (line 787) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 790) | def forward(
  class Florence2FlashAttention2 (line 911) | class Florence2FlashAttention2(Florence2Attention):
    method __init__ (line 919) | def __init__(self, *args, **kwargs):
    method _reshape (line 927) | def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 930) | def forward(
    method _flash_attention_forward (line 1031) | def _flash_attention_forward(
    method _upad_input (line 1091) | def _upad_input(self, query_layer, key_layer, value_layer, attention_m...
  class Florence2SdpaAttention (line 1130) | class Florence2SdpaAttention(Florence2Attention):
    method forward (line 1131) | def forward(
  class Florence2EncoderLayer (line 1243) | class Florence2EncoderLayer(nn.Module):
    method __init__ (line 1244) | def __init__(self, config: Florence2LanguageConfig):
    method forward (line 1262) | def forward(
  class Florence2DecoderLayer (line 1313) | class Florence2DecoderLayer(nn.Module):
    method __init__ (line 1314) | def __init__(self, config: Florence2LanguageConfig):
    method forward (line 1343) | def forward(
  class Florence2LanguagePreTrainedModel (line 1434) | class Florence2LanguagePreTrainedModel(PreTrainedModel):
    method _init_weights (line 1444) | def _init_weights(self, module):
    method dummy_inputs (line 1456) | def dummy_inputs(self):
  class Florence2Encoder (line 1466) | class Florence2Encoder(Florence2LanguagePreTrainedModel):
    method __init__ (line 1476) | def __init__(self, config: Florence2LanguageConfig, embed_tokens: Opti...
    method get_input_embeddings (line 1507) | def get_input_embeddings(self):
    method set_input_embeddings (line 1510) | def set_input_embeddings(self, value):
    method forward (line 1513) | def forward(
  class Florence2Decoder (line 1654) | class Florence2Decoder(Florence2LanguagePreTrainedModel):
    method __init__ (line 1663) | def __init__(self, config: Florence2LanguageConfig, embed_tokens: Opti...
    method get_input_embeddings (line 1692) | def get_input_embeddings(self):
    method set_input_embeddings (line 1695) | def set_input_embeddings(self, value):
    method forward (line 1698) | def forward(
  class Florence2LanguageModel (line 1941) | class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
    method __init__ (line 1944) | def __init__(self, config: Florence2LanguageConfig):
    method _tie_weights (line 1956) | def _tie_weights(self):
    method get_input_embeddings (line 1961) | def get_input_embeddings(self):
    method set_input_embeddings (line 1964) | def set_input_embeddings(self, value):
    method get_encoder (line 1969) | def get_encoder(self):
    method get_decoder (line 1972) | def get_decoder(self):
    method forward (line 1975) | def forward(
  class Florence2LanguageForConditionalGeneration (line 2063) | class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrai...
    method __init__ (line 2068) | def __init__(self, config: Florence2LanguageConfig):
    method get_encoder (line 2077) | def get_encoder(self):
    method get_decoder (line 2080) | def get_decoder(self):
    method resize_token_embeddings (line 2083) | def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple...
    method _resize_final_logits_bias (line 2088) | def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
    method get_output_embeddings (line 2097) | def get_output_embeddings(self):
    method set_output_embeddings (line 2100) | def set_output_embeddings(self, new_embeddings):
    method forward (line 2103) | def forward(
    method prepare_inputs_for_generation (line 2184) | def prepare_inputs_for_generation(
    method prepare_decoder_input_ids_from_labels (line 2223) | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    method _reorder_cache (line 2227) | def _reorder_cache(past_key_values, beam_idx):
  class Florence2Seq2SeqLMOutput (line 2238) | class Florence2Seq2SeqLMOutput(ModelOutput):
  class Florence2PreTrainedModel (line 2330) | class Florence2PreTrainedModel(PreTrainedModel):
    method _supports_flash_attn_2 (line 2337) | def _supports_flash_attn_2(self):
    method _supports_sdpa (line 2345) | def _supports_sdpa(self):
  class Florence2VisionModel (line 2422) | class Florence2VisionModel(Florence2PreTrainedModel):
    method __init__ (line 2423) | def __init__(self, config: Florence2VisionConfig):
    method forward (line 2430) | def forward(self, pixel_values):
  class Florence2VisionModelWithProjection (line 2442) | class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
    method __init__ (line 2443) | def __init__(self, config: Florence2VisionConfig):
    method _build_image_projection_layers (line 2452) | def _build_image_projection_layers(self, config):
    method forward (line 2480) | def forward(self, pixel_values):
  class Florence2ForConditionalGeneration (line 2533) | class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
    method __init__ (line 2534) | def __init__(self, config: Florence2Config):
    method _build_image_projection_layers (line 2555) | def _build_image_projection_layers(self, config):
    method get_encoder (line 2583) | def get_encoder(self):
    method get_decoder (line 2586) | def get_decoder(self):
    method get_input_embeddings (line 2589) | def get_input_embeddings(self):
    method resize_token_embeddings (line 2592) | def resize_token_embeddings(self, new_num_tokens: Optional[int] = None...
    method _encode_image (line 2600) | def _encode_image(self, pixel_values):
    method _merge_input_ids_with_image_features (line 2646) | def _merge_input_ids_with_image_features(
    method forward (line 2673) | def forward(
    method generate (line 2780) | def generate(
    method prepare_inputs_for_generation (line 2803) | def prepare_inputs_for_generation(
    method prepare_decoder_input_ids_from_labels (line 2844) | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    method _reorder_cache (line 2847) | def _reorder_cache(self, *args, **kwargs):

FILE: eval/grounded_sam/florence2/processing_florence2.py
  function is_url (line 43) | def is_url(val) -> bool:
  function is_image_or_image_url (line 47) | def is_image_or_image_url(elem):
  function _is_str_or_image (line 51) | def _is_str_or_image(elem):
  class Florence2Processor (line 55) | class Florence2Processor(ProcessorMixin):
    method __init__ (line 73) | def __init__(
    method _construct_prompts (line 140) | def _construct_prompts(self, text):
    method __call__ (line 158) | def __call__(
    method batch_decode (line 287) | def batch_decode(self, *args, **kwargs):
    method decode (line 295) | def decode(self, *args, **kwargs):
    method model_input_names (line 304) | def model_input_names(self):
    method post_process_generation (line 309) | def post_process_generation(self, text=None, sequence=None, transition...
  class BoxQuantizer (line 384) | class BoxQuantizer(object):
    method __init__ (line 385) | def __init__(self, mode, bins):
    method quantize (line 389) | def quantize(self, boxes: torch.Tensor, size):
    method dequantize (line 418) | def dequantize(self, boxes: torch.Tensor, size):
  class CoordinatesQuantizer (line 446) | class CoordinatesQuantizer(object):
    method __init__ (line 451) | def __init__(self, mode, bins):
    method quantize (line 455) | def quantize(self, coordinates: torch.Tensor, size):
    method dequantize (line 479) | def dequantize(self, coordinates: torch.Tensor, size):
  class Florence2PostProcesser (line 505) | class Florence2PostProcesser(object):
    method __init__ (line 532) | def __init__(
    method _create_black_list_of_phrase_grounding (line 554) | def _create_black_list_of_phrase_grounding(self):
    method _create_default_config (line 590) | def _create_default_config(self):
    method init_quantizers (line 637) | def init_quantizers(self):
    method decode_with_spans (line 655) | def decode_with_spans(self, tokenizer, token_ids):
    method parse_od_from_text_and_spans (line 678) | def parse_od_from_text_and_spans(
    method parse_ocr_from_text_and_spans (line 709) | def parse_ocr_from_text_and_spans(self,
    method parse_phrase_grounding_from_text_and_spans (line 750) | def parse_phrase_grounding_from_text_and_spans(self, text, pattern, im...
    method parse_description_with_bboxes_from_text_and_spans (line 814) | def parse_description_with_bboxes_from_text_and_spans(
    method parse_description_with_polygons_from_text_and_spans (line 922) | def parse_description_with_polygons_from_text_and_spans(self, text, pa...
    method __call__ (line 1033) | def __call__(
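
The `BoxQuantizer` and `CoordinatesQuantizer` classes indexed above map pixel coordinates onto the 1000 `<loc_0>` ... `<loc_999>` special tokens registered by the processor. A hedged sketch of a floor-style round trip (the actual classes support multiple modes and operate on torch tensors; the bin math here is an illustrative assumption):

```python
BINS = 1000  # matches the <loc_0> ... <loc_999> special tokens

def quantize(x: float, size: int, bins: int = BINS) -> int:
    """Map a pixel coordinate to a bin index, floor-style, clamped to the last bin."""
    per_bin = size / bins
    return min(int(x / per_bin), bins - 1)

def dequantize(idx: int, size: int, bins: int = BINS) -> float:
    """Map a bin index back to the center of its pixel range."""
    per_bin = size / bins
    return (idx + 0.5) * per_bin

tok = quantize(384.0, size=768)
print(f"<loc_{tok}>")             # the coordinate as a location token
print(dequantize(tok, size=768))  # approximate pixel coordinate recovered
```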

FILE: eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py
  class FlorenceSAM (line 17) | class FlorenceSAM:
    method __init__ (line 39) | def __init__(self, device):
    method __str__ (line 76) | def __str__(self):
    method run_florence2 (line 81) | def run_florence2(self, task_prompt, text_input, image):
    method caption (line 120) | def caption(self, image, caption_task_prompt='<CAPTION>'):
    method segmentation (line 128) | def segmentation(self, image, input_boxes, seg_model="sam"):
    method post_process_results (line 148) | def post_process_results(self, image, caption, labels, detections, out...
    method caption_phrase_grounding_and_segmentation (line 205) | def caption_phrase_grounding_and_segmentation(
    method od_grounding_and_segmentation (line 245) | def od_grounding_and_segmentation(
    method od_grounding (line 277) | def od_grounding(
    method phrase_grounding_and_segmentation (line 302) | def phrase_grounding_and_segmentation(

FILE: eval/grounded_sam/sam2/automatic_mask_generator.py
  class SAM2AutomaticMaskGenerator (line 36) | class SAM2AutomaticMaskGenerator:
    method __init__ (line 37) | def __init__(
    method from_pretrained (line 153) | def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMas...
    method generate (line 170) | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
    method _generate_masks (line 224) | def _generate_masks(self, image: np.ndarray) -> MaskData:
    method _process_crop (line 251) | def _process_crop(
    method _process_batch (line 294) | def _process_batch(
    method postprocess_small_regions (line 387) | def postprocess_small_regions(
    method refine_with_m2m (line 437) | def refine_with_m2m(self, points, point_labels, low_res_masks, points_...

FILE: eval/grounded_sam/sam2/build_sam.py
  function build_sam2 (line 76) | def build_sam2(
  function build_sam2_video_predictor (line 105) | def build_sam2_video_predictor(
  function _hf_download (line 142) | def _hf_download(model_id):
  function build_sam2_hf (line 150) | def build_sam2_hf(model_id, **kwargs):
  function build_sam2_video_predictor_hf (line 155) | def build_sam2_video_predictor_hf(model_id, **kwargs):
  function _load_checkpoint (line 162) | def _load_checkpoint(model, ckpt_path):

FILE: eval/grounded_sam/sam2/modeling/backbones/hieradet.py
  function do_pool (line 25) | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) ->...
  class MultiScaleAttention (line 39) | class MultiScaleAttention(nn.Module):
    method __init__ (line 40) | def __init__(
    method forward (line 56) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class MultiScaleBlock (line 84) | class MultiScaleBlock(nn.Module):
    method __init__ (line 85) | def __init__(
    method forward (line 134) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Hiera (line 169) | class Hiera(nn.Module):
    method __init__ (line 174) | def __init__(
    method _get_pos_embed (line 273) | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
    method forward (line 283) | def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
    method get_layer_id (line 301) | def get_layer_id(self, layer_name):
    method get_num_layers (line 316) | def get_num_layers(self) -> int:

FILE: eval/grounded_sam/sam2/modeling/backbones/image_encoder.py
  class ImageEncoder (line 14) | class ImageEncoder(nn.Module):
    method __init__ (line 15) | def __init__(
    method forward (line 29) | def forward(self, sample: torch.Tensor):
  class FpnNeck (line 45) | class FpnNeck(nn.Module):
    method __init__ (line 52) | def __init__(
    method forward (line 102) | def forward(self, xs: List[torch.Tensor]):

FILE: eval/grounded_sam/sam2/modeling/backbones/utils.py
  function window_partition (line 16) | def window_partition(x, window_size):
  function window_unpartition (line 41) | def window_unpartition(windows, window_size, pad_hw, hw):
  class PatchEmbed (line 65) | class PatchEmbed(nn.Module):
    method __init__ (line 70) | def __init__(
    method forward (line 91) | def forward(self, x: torch.Tensor) -> torch.Tensor:
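`window_partition` / `window_unpartition` in backbones/utils.py tile a feature map into fixed-size windows for windowed attention, padding the bottom/right edges so each side divides evenly by the window size. A simplified 2-D sketch of the idea (the real functions act on [B, H, W, C] tensors; batch and channel dimensions are dropped here, and zero-padding is an assumption carried over from the ViTDet-style originals):

```python
import math

def window_partition_2d(grid, window_size):
    """Split an H x W grid (nested lists) into window_size x window_size tiles,
    zero-padding the bottom/right edges as needed.
    Returns (windows, (Hp, Wp)) where (Hp, Wp) is the padded size."""
    h, w = len(grid), len(grid[0])
    hp = math.ceil(h / window_size) * window_size
    wp = math.ceil(w / window_size) * window_size
    # pad with zeros outside the original extent
    padded = [[grid[r][c] if r < h and c < w else 0 for c in range(wp)]
              for r in range(hp)]
    windows = []
    for r0 in range(0, hp, window_size):
        for c0 in range(0, wp, window_size):
            windows.append([row[c0:c0 + window_size]
                            for row in padded[r0:r0 + window_size]])
    return windows, (hp, wp)
```

`window_unpartition` is the inverse: it reassembles the tiles into the padded grid and crops back to (H, W).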

FILE: eval/grounded_sam/sam2/modeling/memory_attention.py
  class MemoryAttentionLayer (line 17) | class MemoryAttentionLayer(nn.Module):
    method __init__ (line 19) | def __init__(
    method _forward_sa (line 58) | def _forward_sa(self, tgt, query_pos):
    method _forward_ca (line 66) | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
    method forward (line 83) | def forward(
  class MemoryAttention (line 102) | class MemoryAttention(nn.Module):
    method __init__ (line 103) | def __init__(
    method forward (line 119) | def forward(

FILE: eval/grounded_sam/sam2/modeling/memory_encoder.py
  class MaskDownSampler (line 17) | class MaskDownSampler(nn.Module):
    method __init__ (line 26) | def __init__(
    method forward (line 57) | def forward(self, x):
  class CXBlock (line 62) | class CXBlock(nn.Module):
    method __init__ (line 74) | def __init__(
    method forward (line 104) | def forward(self, x):
  class Fuser (line 120) | class Fuser(nn.Module):
    method __init__ (line 121) | def __init__(self, layer, num_layers, dim=None, input_projection=False):
    method forward (line 130) | def forward(self, x):
  class MemoryEncoder (line 138) | class MemoryEncoder(nn.Module):
    method __init__ (line 139) | def __init__(
    method forward (line 158) | def forward(

FILE: eval/grounded_sam/sam2/modeling/position_encoding.py
  class PositionEmbeddingSine (line 16) | class PositionEmbeddingSine(nn.Module):
    method __init__ (line 22) | def __init__(
    method _encode_xy (line 42) | def _encode_xy(self, x, y):
    method encode_boxes (line 62) | def encode_boxes(self, x, y, w, h):
    method encode_points (line 70) | def encode_points(self, x, y, labels):
    method forward (line 79) | def forward(self, x: torch.Tensor):
  class PositionEmbeddingRandom (line 115) | class PositionEmbeddingRandom(nn.Module):
    method __init__ (line 120) | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = N...
    method _pe_encoding (line 129) | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
    method forward (line 138) | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
    method forward_with_coords (line 151) | def forward_with_coords(
  function init_t_xy (line 167) | def init_t_xy(end_x: int, end_y: int):
  function compute_axial_cis (line 174) | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 1...
  function reshape_for_broadcast (line 186) | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  function apply_rotary_enc (line 194) | def apply_rotary_enc(

FILE: eval/grounded_sam/sam2/modeling/sam/mask_decoder.py
  class MaskDecoder (line 15) | class MaskDecoder(nn.Module):
    method __init__ (line 16) | def __init__(
    method forward (line 110) | def forward(
    method predict_masks (line 168) | def predict_masks(
    method _get_stability_scores (line 247) | def _get_stability_scores(self, mask_logits):
    method _dynamic_multimask_via_stability (line 259) | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_sc...

FILE: eval/grounded_sam/sam2/modeling/sam/prompt_encoder.py
  class PromptEncoder (line 17) | class PromptEncoder(nn.Module):
    method __init__ (line 18) | def __init__(
    method get_dense_pe (line 68) | def get_dense_pe(self) -> torch.Tensor:
    method _embed_points (line 79) | def _embed_points(
    method _embed_boxes (line 103) | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    method _embed_masks (line 114) | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    method _get_batch_size (line 119) | def _get_batch_size(
    method _get_device (line 137) | def _get_device(self) -> torch.device:
    method forward (line 140) | def forward(

FILE: eval/grounded_sam/sam2/modeling/sam/transformer.py
  function sdp_kernel_context (line 28) | def sdp_kernel_context(dropout_p):
  class TwoWayTransformer (line 44) | class TwoWayTransformer(nn.Module):
    method __init__ (line 45) | def __init__(
    method forward (line 90) | def forward(
  class TwoWayAttentionBlock (line 137) | class TwoWayAttentionBlock(nn.Module):
    method __init__ (line 138) | def __init__(
    method forward (line 181) | def forward(
  class Attention (line 215) | class Attention(nn.Module):
    method __init__ (line 221) | def __init__(
    method _separate_heads (line 245) | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
    method _recombine_heads (line 250) | def _recombine_heads(self, x: Tensor) -> Tensor:
    method forward (line 255) | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
  class RoPEAttention (line 289) | class RoPEAttention(Attention):
    method __init__ (line 292) | def __init__(
    method forward (line 311) | def forward(

FILE: eval/grounded_sam/sam2/modeling/sam2_base.py
  class SAM2Base (line 22) | class SAM2Base(torch.nn.Module):
    method __init__ (line 23) | def __init__(
    method device (line 198) | def device(self):
    method forward (line 201) | def forward(self, *args, **kwargs):
    method _build_sam_heads (line 207) | def _build_sam_heads(self):
    method _forward_sam_heads (line 257) | def _forward_sam_heads(
    method _use_mask_as_output (line 415) | def _use_mask_as_output(self, backbone_features, high_res_features, ma...
    method forward_image (line 467) | def forward_image(self, img_batch: torch.Tensor):
    method _prepare_backbone_features (line 482) | def _prepare_backbone_features(self, backbone_out):
    method _prepare_memory_conditioned_features (line 498) | def _prepare_memory_conditioned_features(
    method _encode_new_memory (line 677) | def _encode_new_memory(
    method _track_step (line 727) | def _track_step(
    method _encode_memory_in_output (line 788) | def _encode_memory_in_output(
    method track_step (line 813) | def track_step(
    method _use_multimask (line 880) | def _use_multimask(self, is_init_cond_frame, point_inputs):
    method _apply_non_overlapping_constraints (line 890) | def _apply_non_overlapping_constraints(self, pred_masks):
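`_apply_non_overlapping_constraints` resolves pixels claimed by several tracked objects at once. A heavily simplified sketch of the idea — keep only the highest-scoring object at each contested pixel — using nested lists and binary outputs (the real method operates on batched mask logits and suppresses non-max logits rather than binarizing; this version is illustrative only):

```python
def apply_non_overlapping_constraints(logits_per_obj):
    """Given per-object logit maps (list of H x W nested lists), assign each
    pixel to at most one object: the argmax object, and only if its logit
    is positive. Returns binary masks of the same shape."""
    num_obj = len(logits_per_obj)
    h, w = len(logits_per_obj[0]), len(logits_per_obj[0][0])
    out = [[[0] * w for _ in range(h)] for _ in range(num_obj)]
    for r in range(h):
        for c in range(w):
            best = max(range(num_obj), key=lambda i: logits_per_obj[i][r][c])
            if logits_per_obj[best][r][c] > 0:  # only keep confident pixels
                out[best][r][c] = 1
    return out
```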

FILE: eval/grounded_sam/sam2/modeling/sam2_utils.py
  function select_closest_cond_frames (line 19) | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_f...
  function get_1d_sine_pe (line 64) | def get_1d_sine_pe(pos_inds, dim, temperature=10000):
  function get_activation_fn (line 77) | def get_activation_fn(activation):
  function get_clones (line 88) | def get_clones(module, N):
  class DropPath (line 92) | class DropPath(nn.Module):
    method __init__ (line 94) | def __init__(self, drop_prob=0.0, scale_by_keep=True):
    method forward (line 99) | def forward(self, x):
  class MLP (line 112) | class MLP(nn.Module):
    method __init__ (line 113) | def __init__(
    method forward (line 131) | def forward(self, x):
  class LayerNorm2d (line 141) | class LayerNorm2d(nn.Module):
    method __init__ (line 142) | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
    method forward (line 148) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function sample_box_points (line 156) | def sample_box_points(
  function sample_random_points_from_errors (line 202) | def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
  function sample_one_point_from_error_center (line 252) | def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
  function get_next_point (line 317) | def get_next_point(gt_masks, pred_masks, method):
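`select_closest_cond_frames` picks a bounded number of conditioning frames nearest in time to the current frame. A simplified sketch of the selection rule (the real version additionally guarantees the nearest conditioning frame on each side of `frame_idx` is kept; here selection is purely by temporal distance, and the list-based signature is an assumption):

```python
def select_closest_cond_frames(frame_idx, cond_frame_indices, max_cond_frames):
    """Return up to max_cond_frames frame indices closest in time to
    frame_idx; -1 means no limit."""
    if max_cond_frames == -1 or len(cond_frame_indices) <= max_cond_frames:
        return sorted(cond_frame_indices)
    # rank conditioning frames by temporal distance to the current frame
    ranked = sorted(cond_frame_indices, key=lambda t: abs(t - frame_idx))
    return sorted(ranked[:max_cond_frames])
```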

FILE: eval/grounded_sam/sam2/sam2_image_predictor.py
  class SAM2ImagePredictor (line 19) | class SAM2ImagePredictor:
    method __init__ (line 20) | def __init__(
    method from_pretrained (line 68) | def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredict...
    method set_image (line 85) | def set_image(
    method set_image_batch (line 131) | def set_image_batch(
    method predict_batch (line 174) | def predict_batch(
    method predict (line 236) | def predict(
    method _prep_prompts (line 304) | def _prep_prompts(
    method _predict (line 336) | def _predict(
    method get_image_embedding (line 439) | def get_image_embedding(self) -> torch.Tensor:
    method device (line 455) | def device(self) -> torch.device:
    method reset_predictor (line 458) | def reset_predictor(self) -> None:

FILE: eval/grounded_sam/sam2/sam2_video_predictor.py
  class SAM2VideoPredictor (line 18) | class SAM2VideoPredictor(SAM2Base):
    method __init__ (line 21) | def __init__(
    method init_state (line 44) | def init_state(
    method from_pretrained (line 114) | def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredict...
    method _obj_id_to_idx (line 130) | def _obj_id_to_idx(self, inference_state, obj_id):
    method _obj_idx_to_id (line 164) | def _obj_idx_to_id(self, inference_state, obj_idx):
    method _get_obj_num (line 168) | def _get_obj_num(self, inference_state):
    method add_new_points_or_box (line 173) | def add_new_points_or_box(
    method add_new_points (line 316) | def add_new_points(self, *args, **kwargs):
    method add_new_mask (line 321) | def add_new_mask(
    method _get_orig_video_res_output (line 404) | def _get_orig_video_res_output(self, inference_state, any_res_masks):
    method _consolidate_temp_output_across_obj (line 426) | def _consolidate_temp_output_across_obj(
    method _get_empty_mask_ptr (line 556) | def _get_empty_mask_ptr(self, inference_state, frame_idx):
    method propagate_in_video_preflight (line 593) | def propagate_in_video_preflight(self, inference_state):
    method propagate_in_video (line 663) | def propagate_in_video(
    method _add_output_per_object (line 747) | def _add_output_per_object(
    method clear_all_prompts_in_frame (line 777) | def clear_all_prompts_in_frame(
    method reset_state (line 848) | def reset_state(self, inference_state):
    method _reset_tracking_results (line 860) | def _reset_tracking_results(self, inference_state):
    method _get_image_feature (line 879) | def _get_image_feature(self, inference_state, frame_idx, batch_size):
    method _run_single_frame_inference (line 912) | def _run_single_frame_inference(
    method _run_memory_encoder (line 980) | def _run_memory_encoder(
    method _get_maskmem_pos_enc (line 1016) | def _get_maskmem_pos_enc(self, inference_state, current_out):
    method remove_object (line 1042) | def remove_object(self, inference_state, obj_id, strict=False, need_ou...
    method _clear_non_cond_mem_around_input (line 1155) | def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):

FILE: eval/grounded_sam/sam2/utils/amg.py
  class MaskData (line 18) | class MaskData:
    method __init__ (line 24) | def __init__(self, **kwargs) -> None:
    method __setitem__ (line 31) | def __setitem__(self, key: str, item: Any) -> None:
    method __delitem__ (line 37) | def __delitem__(self, key: str) -> None:
    method __getitem__ (line 40) | def __getitem__(self, key: str) -> Any:
    method items (line 43) | def items(self) -> ItemsView[str, Any]:
    method filter (line 46) | def filter(self, keep: torch.Tensor) -> None:
    method cat (line 61) | def cat(self, new_stats: "MaskData") -> None:
    method to_numpy (line 74) | def to_numpy(self) -> None:
  function is_box_near_crop_edge (line 80) | def is_box_near_crop_edge(
  function box_xyxy_to_xywh (line 93) | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
  function batch_iterator (line 100) | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None,...
  function mask_to_rle_pytorch (line 109) | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
  function rle_to_mask (line 140) | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
  function area_from_rle (line 154) | def area_from_rle(rle: Dict[str, Any]) -> int:
  function calculate_stability_score (line 158) | def calculate_stability_score(
  function build_point_grid (line 181) | def build_point_grid(n_per_side: int) -> np.ndarray:
  function build_all_layer_point_grids (line 191) | def build_all_layer_point_grids(
  function generate_crop_boxes (line 202) | def generate_crop_boxes(
  function uncrop_boxes_xyxy (line 239) | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch...
  function uncrop_points (line 248) | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Te...
  function uncrop_masks (line 257) | def uncrop_masks(
  function remove_small_regions (line 269) | def remove_small_regions(
  function coco_encode_rle (line 296) | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
  function batched_mask_to_box (line 305) | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
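The RLE helpers in amg.py (`mask_to_rle_pytorch`, `rle_to_mask`, `area_from_rle`) share one uncompressed run-length format: the mask is flattened in column-major order, and `counts` alternates between 0-runs and 1-runs, always starting with the 0-run (which may be empty). A minimal pure-Python sketch of that scheme — the real code operates on batched torch tensors, and these function names are illustrative:

```python
def mask_to_rle(mask):
    """Encode a binary H x W mask (nested lists) as uncompressed RLE."""
    h, w = len(mask), len(mask[0])
    flat = [mask[r][c] for c in range(w) for r in range(h)]  # column-major
    counts = []
    cur, run = 0, 0  # counts always begin with the run of zeros
    for v in flat:
        if v == cur:
            run += 1
        else:
            counts.append(run)
            cur, run = v, 1
    counts.append(run)
    return {"size": [h, w], "counts": counts}

def rle_to_mask(rle):
    """Inverse of mask_to_rle: expand the runs, un-flatten column-major."""
    h, w = rle["size"]
    flat, val = [], 0
    for count in rle["counts"]:
        flat.extend([val] * count)
        val = 1 - val
    return [[flat[c * h + r] for c in range(w)] for r in range(h)]
```

`area_from_rle` then just sums every second entry of `counts` (the 1-runs).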

FILE: eval/grounded_sam/sam2/utils/misc.py
  function get_sdpa_settings (line 17) | def get_sdpa_settings():
  function get_connected_components (line 47) | def get_connected_components(mask):
  function mask_to_box (line 66) | def mask_to_box(masks: torch.Tensor):
  function _load_img_as_tensor (line 92) | def _load_img_as_tensor(img_path, image_size):
  class AsyncVideoFrameLoader (line 104) | class AsyncVideoFrameLoader:
    method __init__ (line 109) | def __init__(
    method __getitem__ (line 147) | def __getitem__(self, index):
    method __len__ (line 168) | def __len__(self):
  function load_video_frames (line 172) | def load_video_frames(
  function load_video_frames_from_jpg_images (line 213) | def load_video_frames_from_jpg_images(
  function load_video_frames_from_video_file (line 280) | def load_video_frames_from_video_file(
  function fill_holes_in_mask_scores (line 312) | def fill_holes_in_mask_scores(mask, max_area):
  function concat_points (line 341) | def concat_points(old_point_inputs, new_points, new_labels):
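`mask_to_box` in utils/misc.py reduces a binary mask to its bounding box. A pure-Python sketch of the computation (the real function is batched over torch tensors; the `(0, 0, 0, 0)` convention for an empty mask is an assumption made here for illustration):

```python
def mask_to_box(mask):
    """Bounding box (x0, y0, x1, y1) of the nonzero pixels in a 2-D 0/1 mask,
    with inclusive max coordinates; (0, 0, 0, 0) for an empty mask."""
    coords = [(c, r) for r, row in enumerate(mask)
              for c, v in enumerate(row) if v]
    if not coords:
        return (0, 0, 0, 0)
    xs = [c for c, _ in coords]
    ys = [r for _, r in coords]
    return (min(xs), min(ys), max(xs), max(ys))
```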

FILE: eval/grounded_sam/sam2/utils/transforms.py
  class SAM2Transforms (line 15) | class SAM2Transforms(nn.Module):
    method __init__ (line 16) | def __init__(
    method __call__ (line 37) | def __call__(self, x):
    method forward_batch (line 41) | def forward_batch(self, img_list):
    method transform_coords (line 46) | def transform_coords(
    method transform_boxes (line 66) | def transform_boxes(
    method postprocess_masks (line 76) | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Ten...
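`SAM2Transforms.transform_coords` maps prompt points from the original image frame into the model's square input resolution. A sketch of the coordinate arithmetic, assuming the image is resized directly (without aspect-ratio padding) to `resolution` x `resolution` — the standalone function below is illustrative, not the repo API:

```python
def transform_coords(coords, orig_hw, resolution):
    """Map (x, y) points from an (h, w) image onto a square model input:
    normalize by the original size, then scale to the target resolution."""
    h, w = orig_hw
    return [(x / w * resolution, y / h * resolution) for x, y in coords]
```

`transform_boxes` applies the same mapping to both corners of each box, and `postprocess_masks` performs the inverse resize back to `orig_hw`.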

FILE: eval/tools/dino.py
  class DINOScore (line 23) | class DINOScore:
    method __init__ (line 25) | def __init__(self, device, use_center_crop=True):
    method __call__ (line 49) | def __call__(self, image_x, image_y, similarity_type="class"):
    method avg_similairty (line 62) | def avg_similairty(self, x, y):
    method cls_similarity (line 65) | def cls_similarity(self, x, y):

FILE: eval/tools/dpg_score.py
  class MPLUG (line 22) | class MPLUG(torch.nn.Module):
    method __init__ (line 23) | def __init__(self, ckpt='damo/mplug_visual-question-answering_coco_lar...
    method vqa (line 29) | def vqa(self, image, question):
  class DPGScore (line 35) | class DPGScore:
    method __init__ (line 36) | def __init__(self, device):
    method __call__ (line 42) | def __call__(self, image, q_dict):
  function prepare_dpg_data (line 89) | def prepare_dpg_data(csv_path):
  function parse_args (line 136) | def parse_args():

FILE: eval/tools/face_id.py
  function expand_bounding_box (line 26) | def expand_bounding_box(x_min, y_min, x_max, y_max, factor=1.3):
  class FaceID (line 47) | class FaceID:
    method __init__ (line 48) | def __init__(self, device):
    method detect (line 61) | def detect(self, image, expand_scale=1.3):
    method __call__ (line 68) | def __call__(self, image_x, image_y, normalize=False):
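`expand_bounding_box` enlarges a detected face box by a scale factor around its center before identity embedding. A sketch matching the listed signature; the optional `img_w`/`img_h` clamping parameters are additions for illustration, not part of the repo function:

```python
def expand_bounding_box(x_min, y_min, x_max, y_max, factor=1.3,
                        img_w=None, img_h=None):
    """Scale a box by `factor` about its center, optionally clamped to the
    image bounds when img_w / img_h are given."""
    cx, cy = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
    half_w = (x_max - x_min) * factor / 2.0
    half_h = (y_max - y_min) * factor / 2.0
    nx_min, ny_min = cx - half_w, cy - half_h
    nx_max, ny_max = cx + half_w, cy + half_h
    if img_w is not None:
        nx_min, nx_max = max(0.0, nx_min), min(float(img_w), nx_max)
    if img_h is not None:
        ny_min, ny_max = max(0.0, ny_min), min(float(img_h), ny_max)
    return nx_min, ny_min, nx_max, ny_max
```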

FILE: eval/tools/face_utils/face.py
  function resize_image (line 29) | def resize_image(image, max_size=1024):
  function open_and_resize_image (line 41) | def open_and_resize_image(image_file, max_size=1024, return_type='numpy'):
  function loose_warp_face (line 75) | def loose_warp_face(input_image, face_detector, face_target_shape=(512, ...
  function tight_warp_face (line 207) | def tight_warp_face(input_image, face_detector, face_parser=None, device...

FILE: eval/tools/face_utils/face_recg.py
  class Flatten (line 22) | class Flatten(Module):
    method forward (line 23) | def forward(self, input):
  function l2_norm (line 26) | def l2_norm(input,axis=1):
  class SEModule (line 31) | class SEModule(Module):
    method __init__ (line 32) | def __init__(self, channels, reduction):
    method forward (line 42) | def forward(self, x):
  class bottleneck_IR (line 51) | class bottleneck_IR(Module):
    method __init__ (line 52) | def __init__(self, in_channel, depth, stride):
    method forward (line 64) | def forward(self, x):
  class bottleneck_IR_SE (line 69) | class bottleneck_IR_SE(Module):
    method __init__ (line 70) | def __init__(self, in_channel, depth, stride):
    method forward (line 86) | def forward(self,x):
  class Bottleneck (line 91) | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
  function get_block (line 94) | def get_block(in_channel, depth, num_units, stride = 2):
  function get_blocks (line 97) | def get_blocks(num_layers):
  class Backbone (line 121) | class Backbone(Module):
    method __init__ (line 122) | def __init__(self, num_layers, drop_ratio, mode='ir'):
    method forward (line 148) | def forward(self,x):

FILE: eval/tools/florence_sam.py
  class ObjectDetector (line 22) | class ObjectDetector:
    method __init__ (line 23) | def __init__(self, device):
    method get_instances (line 27) | def get_instances(self, gen_image, label, min_size=64):
    method get_multiple_instances (line 46) | def get_multiple_instances(self, gen_image, label, min_size=64):

FILE: eval/tools/idip_aes_score.py
  function parse_args (line 33) | def parse_args():
  function main (line 41) | def main():

FILE: eval/tools/idip_dpg_score.py
  function parse_args (line 36) | def parse_args():
  function main (line 44) | def main():

FILE: eval/tools/idip_face_score.py
  function parse_args (line 37) | def parse_args():
  function main (line 45) | def main():

FILE: eval/tools/idip_gen_split_idip.py
  function parse_args (line 33) | def parse_args():
  function main (line 44) | def main():

FILE: eval/tools/idip_sam-dino_score.py
  function parse_args (line 37) | def parse_args():
  function main (line 44) | def main():

FILE: eval/tools/log_scores.py
  function parse_args (line 25) | def parse_args():
  function read_txt_first_line (line 32) | def read_txt_first_line(file_path):
  function read_txt_second_line (line 36) | def read_txt_second_line(file_path):

FILE: inference_single_sample.py
  function generate_image (line 42) | def generate_image(model, prompt, cond_size, target_height, target_width...
  function main (line 168) | def main():

FILE: run_gradio.py
  function clear_images (line 122) | def clear_images():
  function det_seg_img (line 125) | def det_seg_img(image, label):
  function crop_face_img (line 133) | def crop_face_img(image):
  function vlm_img_caption (line 145) | def vlm_img_caption(image):
  function generate_random_string (line 166) | def generate_random_string(length=4):
  function resize_keep_aspect_ratio (line 171) | def resize_keep_aspect_ratio(pil_image, target_size=1024):
  function open_accordion_on_example_selection (line 190) | def open_accordion_on_example_selection(*args):
  function generate_image (line 204) | def generate_image(
  function create_image_input (line 367) | def create_image_input(index, open=True, indexs_state=None):
  function merge_instances (line 391) | def merge_instances(orig_img, indices, ins_bboxes, ins_images):
  function change_accordion (line 409) | def change_accordion(at: bool, index: int, state: list):
  function update_inputs (line 424) | def update_inputs(is_open, index, state: list):
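`resize_keep_aspect_ratio` scales an input image so its longer side matches `target_size`. The dimension arithmetic can be sketched as below (the repo version operates on a PIL image and may additionally snap dimensions to a multiple required by the VAE; that detail is not shown, and the function name here is illustrative):

```python
def keep_aspect_ratio_size(width, height, target_size=1024):
    """New (width, height) with the longer side scaled to target_size and
    the aspect ratio preserved."""
    scale = target_size / max(width, height)
    return round(width * scale), round(height * scale)
```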

FILE: src/adapters/mod_adapters.py
  class SquaredReLU (line 36) | class SquaredReLU(nn.Module):
    method forward (line 37) | def forward(self, x: torch.Tensor):
  class AdaLayerNorm (line 40) | class AdaLayerNorm(nn.Module):
    method __init__ (line 41) | def __init__(self, embedding_dim: int, time_embedding_dim: Optional[in...
    method forward (line 54) | def forward(
  class PerceiverAttentionBlock (line 62) | class PerceiverAttentionBlock(nn.Module):
    method __init__ (line 63) | def __init__(
    method attention (line 86) | def attention(self, q: torch.Tensor, kv: torch.Tensor, attn_mask: torc...
    method forward (line 90) | def forward(
  class CLIPModAdapter (line 117) | class CLIPModAdapter(ModelMixin, ConfigMixin):
    method __init__ (line 119) | def __init__(
    method enable_gradient_checkpointing (line 153) | def enable_gradient_checkpointing(self):
    method forward (line 159) | def forward(self, t_emb, llm_hidden_states, clip_outputs):
  class TextImageResampler (line 176) | class TextImageResampler(nn.Module):
    method __init__ (line 177) | def __init__(
    method enable_gradient_checkpointing (line 211) | def enable_gradient_checkpointing(self):
    method forward (line 215) | def forward(

FILE: src/flux/block.py
  function scaled_dot_product_attention (line 25) | def scaled_dot_product_attention(query, key, value, attn_mask=None, drop...
  function attn_forward (line 49) | def attn_forward(
  function set_delta_by_start_end (line 363) | def set_delta_by_start_end(
  function norm1_context_forward (line 376) | def norm1_context_forward(
  function norm1_forward (line 415) | def norm1_forward(
  function block_forward (line 452) | def block_forward(
  function single_norm_forward (line 670) | def single_norm_forward(
  function single_block_forward (line 707) | def single_block_forward(

FILE: src/flux/condition.py
  class Condition (line 37) | class Condition(object):
    method __init__ (line 38) | def __init__(
    method get_condition (line 56) | def get_condition(
    method type_id (line 94) | def type_id(self) -> int:
    method get_type_id (line 101) | def get_type_id(cls, condition_type: str) -> int:
    method encode (line 107) | def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tens...

FILE: src/flux/generate.py
  function get_config (line 40) | def get_config(config_path: str = None):
  function prepare_params (line 49) | def prepare_params(
  function seed_everything (line 94) | def seed_everything(seed: int = 42):
  function generate (line 101) | def generate(
  function generate_from_test_sample (line 519) | def generate_from_test_sample(

FILE: src/flux/lora_controller.py
  class enable_lora (line 20) | class enable_lora:
    method __init__ (line 21) | def __init__(self, lora_modules: List[BaseTunerLayer], dit_activated: ...
    method __enter__ (line 40) | def __enter__(self) -> None:
    method __exit__ (line 57) | def __exit__(
  class set_lora_scale (line 69) | class set_lora_scale:
    method __init__ (line 70) | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -...
    method __enter__ (line 83) | def __enter__(self) -> None:
    method __exit__ (line 89) | def __exit__(
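`enable_lora` and `set_lora_scale` are context managers that temporarily flip state on a list of LoRA modules inside `__enter__` and restore it in `__exit__`. The underlying pattern can be sketched generically (`toggle_flag` and its arguments are illustrative, not repo API):

```python
from contextlib import contextmanager

@contextmanager
def toggle_flag(objs, attr, value):
    """Temporarily set `attr` to `value` on each object, restoring the
    original values on exit even if the body raises."""
    saved = [getattr(o, attr) for o in objs]
    try:
        for o in objs:
            setattr(o, attr, value)
        yield
    finally:
        for o, s in zip(objs, saved):
            setattr(o, attr, s)
```

Restoring state in a `finally` block (equivalently, in `__exit__`) is what makes such wrappers safe to nest around a forward pass that may fail.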

FILE: src/flux/pipeline_tools.py
  function encode_vae_images (line 57) | def encode_vae_images(pipeline: FluxPipeline, images: Tensor):
  function decode_vae_images (line 82) | def decode_vae_images(pipeline: FluxPipeline, latents: Tensor, height, w...
  function _get_clip_prompt_embeds (line 89) | def _get_clip_prompt_embeds(
  function encode_prompt_with_clip_t5 (line 127) | def encode_prompt_with_clip_t5(
  function prepare_text_input (line 213) | def prepare_text_input(
  function prepare_t5_input (line 239) | def prepare_t5_input(
  function tokenize_t5_prompt (line 265) | def tokenize_t5_prompt(pipe, input_prompt, max_length, **kargs):
  function clear_attn_maps (line 277) | def clear_attn_maps(transformer):
  function gather_attn_maps (line 286) | def gather_attn_maps(transformer, clear=False):
  function process_token (line 315) | def process_token(token, startofword):
  function save_attention_image (line 331) | def save_attention_image(attn_map, tokens, batch_dir, to_pil):
  function save_attention_maps (line 342) | def save_attention_maps(attn_maps, pipe, prompts, base_dir='attn_maps'):
  function gather_cond2latents (line 389) | def gather_cond2latents(transformer, clear=False):
  function save_cond2latent_image (line 415) | def save_cond2latent_image(attn_map, batch_dir, to_pil):
  function save_cond2latent (line 421) | def save_cond2latent(attn_maps, base_dir='attn_maps'):
  function quantization (line 463) | def quantization(pipe, qtype, t5_only=False):
  class CustomFluxPipeline (line 519) | class CustomFluxPipeline:
    method __init__ (line 520) | def __init__(
    method add_modulation_adapter (line 583) | def add_modulation_adapter(self, modulation_adapter):
    method clear_modulation_adapters (line 587) | def clear_modulation_adapters(self):
  function load_clip (line 592) | def load_clip(pipeline, config, torch_dtype, device, ckpt_dir=None, is_t...
  function load_dit_lora (line 599) | def load_dit_lora(pipeline, pipe, config, torch_dtype, device, ckpt_dir=...
  function load_modulation_adapter (line 670) | def load_modulation_adapter(pipeline, config, torch_dtype, device, ckpt_...
  function load_ckpt (line 722) | def load_ckpt(pipeline, ckpt_dir, is_training=False):

FILE: src/flux/transformer.py
  function prepare_params (line 38) | def prepare_params(
  function transformer_forward (line 67) | def transformer_forward(

FILE: src/utils/data_utils.py
  function get_rank_and_worldsize (line 29) | def get_rank_and_worldsize():
  function get_train_config (line 40) | def get_train_config(config_path=None):
  function calculate_aspect_ratios (line 48) | def calculate_aspect_ratios(resolution):
  function get_closest_ratio (line 81) | def get_closest_ratio(height: float, width: float, ratios: dict):
  function _aspect_ratio_batched (line 87) | def _aspect_ratio_batched(
  function apply_aspect_ratio_batched (line 154) | def apply_aspect_ratio_batched(dataset, batchsize, aspect_ratios, batch_...
  function get_aspect_ratios (line 165) | def get_aspect_ratios(enable_aspect_ratio, resolution):
  function bbox_to_grid (line 187) | def bbox_to_grid(bbox, image_size, output_size=(224, 224)):
  function random_crop_instance (line 214) | def random_crop_instance(instance, min_crop_ratio):
  function compute_psnr (line 236) | def compute_psnr(x, y):
  function replace_first_occurrence (line 243) | def replace_first_occurrence(sentence, word_or_phrase, replace_with):
  function decode_base64_to_image (line 264) | def decode_base64_to_image(base64_str):
  function jpeg_compression (line 273) | def jpeg_compression(pil_image, quality):
  function pad_to_square (line 278) | def pad_to_square(pil_image):
  function pad_to_target (line 286) | def pad_to_target(pil_image, target_size):
  function image_grid (line 310) | def image_grid(imgs, rows, cols):
  function split_grid (line 323) | def split_grid(image):
  function add_border (line 340) | def add_border(img, border_color, border_thickness):
  function merge_bboxes (line 357) | def merge_bboxes(bboxes):
  function flip_bbox_left_right (line 377) | def flip_bbox_left_right(bbox, image_width):
  function json_load (line 394) | def json_load(path, encoding='ascii'):
  function json_dump (line 398) | def json_dump(obj, path, encoding='ascii', indent=4, create_dir=True, ve...
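`compute_psnr` and `merge_bboxes` in data_utils.py are small numeric helpers whose formulas are standard: PSNR is 10·log10(max² / MSE), and merging boxes takes the coordinate-wise union. A sketch under those assumptions (the `max_val` parameter and tuple conventions are illustrative):

```python
import math

def compute_psnr(x, y, max_val=1.0):
    """PSNR between two same-sized 2-D arrays of pixel values in [0, max_val]."""
    flat_x = [v for row in x for v in row]
    flat_y = [v for row in y for v in row]
    mse = sum((a - b) ** 2 for a, b in zip(flat_x, flat_y)) / len(flat_x)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

def merge_bboxes(bboxes):
    """Smallest (x0, y0, x1, y1) box covering every input box."""
    return (min(b[0] for b in bboxes), min(b[1] for b in bboxes),
            max(b[2] for b in bboxes), max(b[3] for b in bboxes))
```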

FILE: src/utils/gpu_momory_utils.py
  class ForwardHookManager (line 5) | class ForwardHookManager:
    method __init__ (line 6) | def __init__(self, threshold_mem = 8 * 1024 * 1024 * 1024, use_lower_v...
    method _get_available_memory (line 13) | def _get_available_memory(self):
    method _free_up_memory (line 18) | def _free_up_memory(self, required_mem, cache_model = None):
    method model_to_cuda (line 37) | def model_to_cuda(self, model):
    method _register (line 88) | def _register(self, model):
    method replace_module_children (line 147) | def replace_module_children(self, model, deep = 0):
    method register (line 161) | def register(self, model):
    method revert (line 172) | def revert(self):

FILE: src/utils/modulation_utils.py
  function unpad_input_ids (line 19) | def unpad_input_ids(input_ids, attention_mask):
  function get_word_index (line 22) | def get_word_index(pipe, prompt, input_ids, word, word_count=1, max_leng...

Condensed preview — 90 files, each entry showing path, character count, and a content snippet (full structured content: 4,390K chars).
[
  {
    "path": ".gitignore",
    "chars": 103,
    "preview": "assets/XVerseBench/animal/*\nassets/XVerseBench/object/*\n__pycache__\ncheckpoints/*\ngenerated_*\ntmp\n*.png"
  },
  {
    "path": ".gradio/certificate.pem",
    "chars": 1939,
    "preview": "-----BEGIN CERTIFICATE-----\nMIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw\nTzELMAkGA1UEBhMCVVMxKTAnBgN"
  },
  {
    "path": "LICENCE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 11540,
    "preview": "# XVerse: Consistent Multi-Subject Control of Identity and Semantic Attributes via DiT Modulation\n\n<p align=\"center\">\n  "
  },
  {
    "path": "assets/ReadMe.md",
    "chars": 3579,
    "preview": "# Install of XVerseBench\n\nExisting controlled image generation benchmarks often focus on either maintaining identity or "
  },
  {
    "path": "assets/crop_faces.py",
    "chars": 2841,
    "preview": "import os\nimport face_recognition\nfrom PIL import Image, ImageOps\nimport numpy as np\n\ndef detect_and_crop_faces(input_di"
  },
  {
    "path": "assets/rename.py",
    "chars": 2677,
    "preview": "import os\nimport shutil\n\nsplit = [(\"live_subject/animal\", \"animal\"), (\"object\", \"object\")]\n\n# 定义目录路径\ncaption_dir_base = "
  },
  {
    "path": "assets/segmentation.py",
    "chars": 3147,
    "preview": "from src.utils.data_utils import get_train_config, image_grid, pil2tensor, json_dump, pad_to_square, cv2pil, merge_bboxe"
  },
  {
    "path": "eval/eval_scripts/run_eval_multi.sh",
    "chars": 1371,
    "preview": "export config_path=\"./train/config/XVerse_config_INF.yaml\"\nexport model_checkpoint=\"./checkpoints/XVerse\"\nexport target_"
  },
  {
    "path": "eval/eval_scripts/run_eval_single.sh",
    "chars": 1377,
    "preview": "export config_path=\"./train/config/XVerse_config_INF.yaml\"\nexport model_checkpoint=\"./checkpoints/XVerse\"\nexport target_"
  },
  {
    "path": "eval/grounded_sam/florence2/config.json",
    "chars": 2445,
    "preview": "{\n  \"_name_or_path\": \"florence2\",\n  \"architectures\": [\n    \"Florence2ForConditionalGeneration\"\n  ],\n  \"auto_map\": {\n    "
  },
  {
    "path": "eval/grounded_sam/florence2/configuration_florence2.py",
    "chars": 15125,
    "preview": "# coding=utf-8\n# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.\n# Licensed under the Apach"
  },
  {
    "path": "eval/grounded_sam/florence2/generation_config.json",
    "chars": 51,
    "preview": "{\n    \"num_beams\": 3,\n    \"early_stopping\": false\n}"
  },
  {
    "path": "eval/grounded_sam/florence2/modeling_florence2.py",
    "chars": 127294,
    "preview": "# coding=utf-8\n# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apa"
  },
  {
    "path": "eval/grounded_sam/florence2/preprocessor_config.json",
    "chars": 806,
    "preview": "{\n  \"auto_map\": {\n    \"AutoProcessor\": \"processing_florence2.Florence2Processor\"\n   },\n  \"_valid_processor_keys\": [\n    "
  },
  {
    "path": "eval/grounded_sam/florence2/processing_florence2.py",
    "chars": 48674,
    "preview": "# coding=utf-8\n# Copyright 2024 Microsoft and The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version "
  },
  {
    "path": "eval/grounded_sam/florence2/tokenizer.json",
    "chars": 1284209,
    "preview": "{\"version\":\"1.0\",\"truncation\":null,\"padding\":null,\"added_tokens\":[{\"id\":0,\"special\":true,\"content\":\"<s>\",\"single_word\":f"
  },
  {
    "path": "eval/grounded_sam/florence2/tokenizer_config.json",
    "chars": 34,
    "preview": "{\n    \"model_max_length\": 1024\n}\n\n"
  },
  {
    "path": "eval/grounded_sam/florence2/vocab.json",
    "chars": 1063976,
    "preview": "{\n    \"<s>\": 0,\n    \"<pad>\": 1,\n    \"</s>\": 2,\n    \"<unk>\": 3,\n    \".\": 4,\n    \"Ġthe\": 5,\n    \",\": 6,\n    \"Ġto\": 7,\n    "
  },
  {
    "path": "eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py",
    "chars": 14399,
    "preview": "import os\nimport cv2\nimport torch\nimport argparse\nimport numpy as np\nimport supervision as sv\nfrom PIL import Image\nimpo"
  },
  {
    "path": "eval/grounded_sam/sam2/__init__.py",
    "chars": 395,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/automatic_mask_generator.py",
    "chars": 18461,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/build_sam.py",
    "chars": 6355,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2/sam2_hiera_b+.yaml",
    "chars": 3548,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2/sam2_hiera_l.yaml",
    "chars": 3696,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2/sam2_hiera_s.yaml",
    "chars": 3659,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2/sam2_hiera_t.yaml",
    "chars": 3753,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml",
    "chars": 3650,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_l.yaml",
    "chars": 3798,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_s.yaml",
    "chars": 3761,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2.1/sam2.1_hiera_t.yaml",
    "chars": 3855,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml",
    "chars": 11256,
    "preview": "# @package _global_\n\nscratch:\n  resolution: 1024\n  train_batch_size: 1\n  num_train_workers: 10\n  num_frames: 8\n  max_num"
  },
  {
    "path": "eval/grounded_sam/sam2/csrc/connected_components.cu",
    "chars": 7808,
    "preview": "// Copyright (c) Meta Platforms, Inc. and affiliates.\n// All rights reserved.\n\n// This source code is licensed under the"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/__init__.py",
    "chars": 197,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/backbones/__init__.py",
    "chars": 197,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/backbones/hieradet.py",
    "chars": 10003,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/backbones/image_encoder.py",
    "chars": 4706,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/backbones/utils.py",
    "chars": 3053,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/memory_attention.py",
    "chars": 5509,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/memory_encoder.py",
    "chars": 5657,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/position_encoding.py",
    "chars": 8361,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/sam/__init__.py",
    "chars": 197,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/sam/mask_decoder.py",
    "chars": 12657,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/sam/prompt_encoder.py",
    "chars": 7016,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/sam/transformer.py",
    "chars": 12870,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/sam2_base.py",
    "chars": 47012,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/modeling/sam2_utils.py",
    "chars": 13173,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/sam2_hiera_b+.yaml",
    "chars": 3548,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/sam2_hiera_l.yaml",
    "chars": 3696,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/sam2_hiera_s.yaml",
    "chars": 3659,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/sam2_hiera_t.yaml",
    "chars": 3753,
    "preview": "# @package _global_\n\n# Model\nmodel:\n  _target_: sam2.modeling.sam2_base.SAM2Base\n  image_encoder:\n    _target_: sam2.mod"
  },
  {
    "path": "eval/grounded_sam/sam2/sam2_image_predictor.py",
    "chars": 19936,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/sam2_video_predictor.py",
    "chars": 58791,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/utils/__init__.py",
    "chars": 197,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/utils/amg.py",
    "chars": 12842,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/utils/misc.py",
    "chars": 13118,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/grounded_sam/sam2/utils/transforms.py",
    "chars": 4962,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "eval/tools/XVerseBench_multi.json",
    "chars": 247061,
    "preview": "[\n    {\n        \"index\": 0,\n        \"input_images\": [],\n        \"position_delta\": [\n            0,\n            -32\n     "
  },
  {
    "path": "eval/tools/XVerseBench_multi_DSG.json",
    "chars": 256462,
    "preview": "[\n    {\n        \"prompt\": \"A man is riding a motorcycle.\",\n        \"qid2tuple\": {\n            \"1\": \"entity - whole\",\n   "
  },
  {
    "path": "eval/tools/XVerseBench_single.json",
    "chars": 62821,
    "preview": "\n[\n    {\n        \"index\": 0,\n        \"input_images\": [],\n        \"position_delta\": [\n            0,\n            -32\n    "
  },
  {
    "path": "eval/tools/XVerseBench_single_DSG.json",
    "chars": 70060,
    "preview": "[\n    {\n        \"prompt\": \"a polar bear standing on iceberg\",\n        \"qid2tuple\": {\n            \"1\": \"entity - whole\",\n"
  },
  {
    "path": "eval/tools/dino.py",
    "chars": 3474,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# License"
  },
  {
    "path": "eval/tools/dpg_score.py",
    "chars": 8321,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/face_id.py",
    "chars": 5144,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/face_utils/face.py",
    "chars": 8758,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/face_utils/face_recg.py",
    "chars": 5989,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/florence_sam.py",
    "chars": 2897,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/idip_aes_score.py",
    "chars": 5557,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/idip_dpg_score.py",
    "chars": 5176,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/idip_face_score.py",
    "chars": 6194,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/idip_gen_split_idip.py",
    "chars": 4744,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/idip_sam-dino_score.py",
    "chars": 6104,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "eval/tools/log_scores.py",
    "chars": 2883,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "inference_single_sample.py",
    "chars": 12222,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "requirements.txt",
    "chars": 536,
    "preview": "aesthetic_predictor_v2_5==2024.12.18.1\ndecord==0.6.0\ndiffusers==0.32.2\neinops==0.8.1\ngradio==5.33.1\nhuggingface_hub==0.2"
  },
  {
    "path": "run_demo.sh",
    "chars": 657,
    "preview": "export FLORENCE2_MODEL_PATH=\"./checkpoints/Florence-2-large\"\nexport SAM2_MODEL_PATH=\"./checkpoints/sam2.1_hiera_large.pt"
  },
  {
    "path": "run_gradio.py",
    "chars": 24425,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/adapters/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/adapters/mod_adapters.py",
    "chars": 8617,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/flux/block.py",
    "chars": 39367,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/flux/condition.py",
    "chars": 4505,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/flux/generate.py",
    "chars": 36665,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/flux/lora_controller.py",
    "chars": 4309,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/flux/pipeline_tools.py",
    "chars": 31188,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/flux/transformer.py",
    "chars": 14443,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright 2024 Black Forest Labs and The HuggingFace Team. A"
  },
  {
    "path": "src/utils/data_utils.py",
    "chars": 14421,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n#\n# Licensed under the Apache License, Version 2.0 (the \"Licen"
  },
  {
    "path": "src/utils/gpu_momory_utils.py",
    "chars": 8846,
    "preview": "import torch\nfrom diffusers import AutoencoderKL, FluxTransformer2DModel\nimport time\n\nclass ForwardHookManager:\n    def "
  },
  {
    "path": "src/utils/modulation_utils.py",
    "chars": 2822,
    "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates\n# Copyright (c) Facebook, Inc. All rights reserved.\n#\n# Licens"
  },
  {
    "path": "train/config/XVerse_config_INF.yaml",
    "chars": 2850,
    "preview": "dtype: \"bfloat16\"\n\nmodel:\n  text_cond_attn: false\n  add_cond_attn: false\n\n  union_cond_attn: true\n  double_use_condition"
  },
  {
    "path": "train/config/XVerse_config_demo.yaml",
    "chars": 2582,
    "preview": "dtype: \"bfloat16\"\n\nmodel:\n  text_cond_attn: false\n  add_cond_attn: false\n\n  union_cond_attn: true\n  double_use_condition"
  }
]

About this extraction

This page contains the full source code of the bytedance/XVerse GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 90 files (3.7 MB), approximately 963.7k tokens, and a symbol index with 574 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
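Since each file in the manifest above is described by the same three-field schema (`path`, `chars`, and a truncated `preview`), the listing can be queried programmatically. The sketch below is illustrative only: the two inline entries are copied from the manifest above, and in practice you would `json.load` the full array from the downloaded text instead.

```python
import json

# Two sample entries mirroring the manifest schema shown above.
# In real use, load the full array: manifest = json.load(open("manifest.json"))
manifest = [
    {"path": "requirements.txt", "chars": 536,
     "preview": "aesthetic_predictor_v2_5==2024.12.18.1\ndecord==0.6.0"},
    {"path": "src/flux/generate.py", "chars": 36665,
     "preview": "# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates"},
]

# Aggregate size across all listed files
total_chars = sum(entry["chars"] for entry in manifest)

# The single largest file by character count
largest = max(manifest, key=lambda e: e["chars"])

# Filter to Python source files only
python_files = [e["path"] for e in manifest if e["path"].endswith(".py")]

print(total_chars)      # 37201
print(largest["path"])  # src/flux/generate.py
```

Filtering on `chars` this way is a quick method to decide which files fit in a model's context window before pasting them into a prompt.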

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
