Full Code of rongyaofang/GoT for AI

Repository: rongyaofang/GoT
Branch: main
Commit: 091220916124
Files: 15
Total size: 19.6 MB

Directory structure:
GoT/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── clm_models/
│   │   ├── agent_got.yaml
│   │   └── llm_qwen25_vl_3b_lora.yaml
│   └── tokenizer/
│       └── qwen25_vl_tokenizer_token64.yaml
├── got/
│   ├── __init__.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── got_model.py
│   │   ├── peft_models.py
│   │   ├── projector.py
│   │   └── utils.py
│   └── processer/
│       └── qwen25_vl_processor.py
├── inference.ipynb
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
pretrained/
.DS_Store
.idea/

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2025 Rongyao Fang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing
<div align="center">
<a href="https://github.com/rongyaofang/GoT"><img src="https://img.shields.io/badge/Project-Homepage-green" alt="Home"></a>
<a href="https://arxiv.org/abs/2503.10639"><img src="https://img.shields.io/badge/ArXiv-2503.10639-red"></a>

[Rongyao Fang](https://scholar.google.com/citations?user=FtH3CW4AAAAJ&hl=en)<sup>1\*</sup>, [Chengqi Duan](https://scholar.google.com/citations?user=r9qb4ZwAAAAJ&hl=zh-CN)<sup>2\*</sup>, [Kun Wang]()<sup>3</sup>, [Linjiang Huang](https://leonhlj.github.io/)<sup>6</sup>, [Hao Li](https://scholar.google.com/citations?user=qHqQsY4AAAAJ&hl=zh-CN)<sup>1,4</sup>, [Shilin Yan](https://scholar.google.com/citations?user=2VhjOykAAAAJ&hl=zh-CN), [Hao Tian]()<sup>3</sup>, [Xingyu Zeng]()<sup>3</sup>, [Rui Zhao]()<sup>3</sup>, [Jifeng Dai](https://jifengdai.org/)<sup>4,5</sup>, [Xihui Liu](https://xh-liu.github.io/)<sup>2 :envelope:</sup>, [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<sup>1 :envelope:</sup>

<sup>1</sup>CUHK MMLab, <sup>2</sup>HKU MMLab, <sup>3</sup>SenseTime, <sup>4</sup>Shanghai AI Laboratory, <sup>5</sup>Tsinghua University, <sup>6</sup>Beihang University

*Equal contribution, :envelope:Corresponding authors
</div>

<div align="center">
  <img src="figures/teaser.jpg" width="100%" alt="GoT Framework" />
</div>
<hr>
<div align="center" style="line-height: 1.2;">
  <a href="https://arxiv.org/abs/2503.10639" target="_blank"><b>Paper</b></a> •
  <a href="#introduction">Introduction</a> •
  <a href="#released-datasets">Datasets</a> •
  <a href="#released-model-got-framework">Model</a> •
  <a href="#results">Results</a> •
  <a href="https://huggingface.co/LucasFang/GoT-6B" target="_blank">🤗 Hugging Face</a> •
  <a href="#license">License</a>
</div>

## 🔥 News

- **[2025-9-19]** 📝 Our GoT paper has been accepted by **NeurIPS 2025**!
- **[2025-9-12]** 🎉 We open-sourced our latest work, the **FLUX-Reason-6M** dataset! This high-quality text-to-image reasoning dataset was constructed using 15,000 A100 GPU days of FLUX generation. Check it out at [FLUX-Reason-6M](https://github.com/rongyaofang/prism-bench)!

## Introduction

We present **Generation Chain-of-Thought (GoT)**, a novel paradigm that enables generation and editing through an explicit language reasoning process before outputting images. This approach transforms conventional text-to-image generation and editing into a reasoning-guided framework that analyzes semantic relationships and spatial arrangements.

GoT pioneers a new direction for reasoning-driven visual generation and editing, producing images that better align with human intent through:

- **Semantic-Spatial Reasoning**: Integrates both semantic understanding and explicit spatial coordinates
- **Unified Framework**: Handles both image generation and editing with the same architecture

## Released Datasets

| Dataset | Link | Amount |
|---------|------|--------|
| **Laion-Aesthetics-High-Resolution-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/Laion-Aesthetics-High-Resolution-GoT) | 3.77M  |
| **JourneyDB-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/JourneyDB-GoT) | 4.09M  |
| **OmniEdit-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/OmniEdit-GoT) | 736K   |
| **FLUX-Reason-6M** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/FLUX-Reason-6M) | 6M     |

## Dataset Features

### Laion-Aesthetics-High-Resolution-GoT
- 3.77 million high-quality images from Laion-Aesthetics, filtered to sizes larger than 512 pixels
- Prompts and GoT descriptions from Qwen2-VL
- Prompts averaging 110.81 characters
- GoT descriptions averaging 811.56 characters
- 3.78 bounding boxes per image on average

### JourneyDB-GoT
- 4.09 million high-quality AI-generated images
- Prompts and GoT descriptions from Qwen2-VL
- Prompts averaging 149.78 characters
- GoT descriptions averaging 906.01 characters
- 4.09 bounding boxes per image on average
- Please download the images from [JourneyDB dataset](https://opendatalab.com/OpenDataLab/JourneyDB/tree/main/raw/JourneyDB/train/imgs)

### OmniEdit-GoT
- 736K high-quality image editing samples from OmniEdit
- Diverse editing operations (addition, removal, swap, attribute changes, style transfer)
- Detailed reasoning chains with step-by-step editing processes
- Precise spatial coordinate annotations for editing regions
- Please download the images from [OmniEdit dataset](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M)

### FLUX-Reason-6M
- 6 million high-quality text-to-image reasoning samples generated entirely with FLUX
- Built using 15,000 A100 GPU days for superior quality and reasoning capabilities
- Comprehensive reasoning chains for complex visual generation tasks
- Designed to enhance multimodal reasoning in visual generation models

## Released Model: GoT Framework

| Model      | Link | Architecture         |
|------------|------|----------------------|
| **GoT-6B** | [🤗 HuggingFace](https://huggingface.co/LucasFang/GoT-6B) | Qwen2.5-VL-3B + SDXL |

## Model Features

<div align="center">
  <img src="figures/architecture.jpg" width="100%" alt="GoT Architecture" />
</div>

Our GoT framework consists of two key components:

1. **Semantic-Spatial MLLM**: Generates detailed reasoning chains with spatial information using Qwen2.5-VL as the backbone
2. **SSGM Diffusion Module**: Leverages the semantic guidance, spatial layouts, and reference images to create high-quality visual outputs

The Semantic-Spatial Guidance Module (SSGM) combines three guidance pathways:
- **Semantic Guidance**: Captures relationships and attributes
- **Spatial Guidance**: Controls precise object placement
- **Reference Guidance**: Provides context for editing tasks
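
These three signals are fused at every denoising step by a four-way classifier-free guidance rule implemented in `got/models/got_model.py`. The sketch below is a minimal stand-alone rendition of only that combination step; the function name and tensor shapes are illustrative, and the scales are the defaults of `GenCot.generate()`.

```python
import torch

def combine_guidance(noise_pred: torch.Tensor,
                     guidance_scale: float = 7.5,           # semantic strength
                     image_guidance_scale: float = 1.0,      # reference strength
                     cond_image_guidance_scale: float = 4.0  # spatial strength
                     ) -> torch.Tensor:
    """noise_pred stacks four UNet predictions along the batch dimension:
    full conditioning, semantic+reference, reference only, unconditional."""
    full, semantic, reference, uncond = noise_pred.chunk(4, dim=0)
    return (uncond
            + guidance_scale * (semantic - reference)        # semantic guidance
            + cond_image_guidance_scale * (full - semantic)  # spatial guidance
            + image_guidance_scale * (reference - uncond))   # reference guidance

# e.g. combine_guidance(torch.randn(4, 4, 128, 128)) -> tensor of shape (1, 4, 128, 128)
```

Each difference term isolates one guidance pathway, so the three scales can be tuned independently for semantic fidelity, spatial layout, and faithfulness to the reference image.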

## Results

### Text-to-Image Generation

GoT achieves state-of-the-art performance on the GenEval benchmark, particularly excelling in composition tasks:

<div align="center">

| Method | Architecture | Overall | Single Obj. | Two Obj. | Counting | Colors | Position | Attr. Binding |
|--------|--------------|---------|-------------|----------|----------|--------|----------|---------------|
| SD-XL | Unet+CLIP | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 |
| SD3 | MMDIT+CLIP+T5 | 0.62 | 0.98 | 0.74 | 0.63 | 0.67 | 0.34 | 0.36 |
| Emu3-Gen | Autoregressive | 0.54 | 0.98 | 0.71 | 0.34 | 0.81 | 0.17 | 0.21 |
| Janus | Autoregressive | 0.61 | 0.97 | 0.68 | 0.30 | 0.84 | 0.46 | 0.42 |
| JanusFlow | Autoregressive | 0.63 | 0.97 | 0.59 | 0.45 | 0.83 | 0.53 | 0.42 |
| **GoT Framework** | Unet+Qwen2.5-VL | **0.64** | **0.99** | 0.69 | **0.67** | **0.85** | 0.34 | 0.27 |

</div>

### Image Editing

Our approach also demonstrates superior performance on image editing benchmarks:

<div align="center">

| Method | Emu-Edit CLIP-I | Emu-Edit CLIP-T | ImagenHub GPT-4o Eval. | Reason-Edit GPT-4o Eval. |
|--------|-----------------|-----------------|------------------------|--------------------------|
| IP2P | 0.834 | 0.219 | 0.308 | 0.286 |
| MagicBrush | 0.838 | 0.222 | 0.513 | 0.334 |
| SEED-X | 0.825 | 0.272 | 0.166 | 0.239 |
| CosXL-Edit | 0.860 | 0.274 | 0.464 | 0.325 |
| **GoT Framework** | **0.864** | **0.276** | **0.533** | **0.561** |

</div>

## Usage

### Dependencies
- Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
- [PyTorch >=2.0.1](https://pytorch.org/)
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)

### Installation
Clone the repo and install the required packages:

  ```bash
  git clone git@github.com:rongyaofang/GoT.git
  cd GoT
  pip install -r requirements.txt
  ```

### Model Weights
Place the required model weights in the `./pretrained` directory as follows:

1. GoT-6B model weights
2. [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
3. [Stable Diffusion XL Base 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)

Your directory structure should match the following:

```
GoT
├── pretrained
│   ├── GoT-6B
│   ├── Qwen2.5-VL-3B-Instruct
│   └── stable-diffusion-xl-base-1.0
├── ...
```

### Inference
Follow the instructions in the [inference notebook](https://github.com/rongyaofang/GoT/blob/main/inference.ipynb); a minimal end-to-end sketch is shown below for reference.
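
The sketch assumes the Hydra configs under `./configs` and the weights under `./pretrained` described above; the checkpoint filename passed to `pretrained_model_path` is an assumption, and the notebook remains the authoritative reference.

```python
import hydra
import torch
from omegaconf import OmegaConf

# Instantiate the LoRA-wrapped Qwen2.5-VL backbone and the GoT agent from the Hydra configs
llm_cfg = OmegaConf.load("configs/clm_models/llm_qwen25_vl_3b_lora.yaml")
agent_cfg = OmegaConf.load("configs/clm_models/agent_got.yaml")

mllm = hydra.utils.instantiate(llm_cfg)
model = hydra.utils.instantiate(
    agent_cfg,
    mllm=mllm,
    pretrained_model_path="pretrained/GoT-6B/pytorch_model.bin",  # hypothetical filename
)
model = model.to("cuda", dtype=torch.bfloat16).eval()

# Text-to-image generation with reasoning; use prompt_type='edit' and image=<PIL.Image> for editing
out = model.generate(
    text_input="A red cube stacked on top of a blue sphere in a sunlit room",
    prompt_type="t2i",
    height=1024,
    width=1024,
)
print(out["got_text"])               # the generated GoT reasoning chain
out["images"][0].save("result.png")  # the generated image
```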

## License

This code is released under the MIT License.

## Citation

If you find this work helpful, please consider citing:

```
@article{fang2025got,
  title={GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing},
  author={Fang, Rongyao and Duan, Chengqi and Wang, Kun and Huang, Linjiang and Li, Hao and Yan, Shilin and Tian, Hao and Zeng, Xingyu and Zhao, Rui and Dai, Jifeng and Liu, Xihui and Li, Hongsheng},
  journal={arXiv preprint arXiv:2503.10639},
  year={2025}
}
```

## Contact

If you have any questions, please raise an issue or contact us at [rongyaofang@gmail.com](mailto:rongyaofang@gmail.com).


================================================
FILE: configs/clm_models/agent_got.yaml
================================================
_target_: got.models.got_model.GenCot.from_pretrained
output_projector:
  _target_: got.models.projector.LinearProjector
  in_hidden_size: 2048
  out_hidden_size: 2048

output_projector_add:
  _target_: got.models.projector.LinearProjector
  in_hidden_size: 2048
  out_hidden_size: 1280

scheduler:
  _target_: diffusers.DDPMScheduler.from_pretrained
  pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0
  subfolder: scheduler

vae:
  _target_: diffusers.AutoencoderKL.from_pretrained
  pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0
  subfolder: vae

unet:
  _target_: diffusers.UNet2DConditionModel.from_pretrained
  pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0
  subfolder: unet

processor:
  _target_: got.processer.qwen25_vl_processor.get_processor
  model_name: pretrained/Qwen2.5-VL-3B-Instruct
  add_gen_token_num: 64

num_img_out_tokens: 64
img_gen_start_id: 151667
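# img_gen_start_id is the tokenizer id of the first <|im_gen_XXXX|> placeholder: ids in
# [img_gen_start_id, img_gen_start_id + num_img_out_tokens) are selected as the image tokens
# whose hidden states are projected into SDXL conditioning (see got/models/got_model.py).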


================================================
FILE: configs/clm_models/llm_qwen25_vl_3b_lora.yaml
================================================
_target_: got.models.peft_models.get_peft_model_without_resize_embedding
model:
  _target_: transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained
  pretrained_model_name_or_path: pretrained/Qwen2.5-VL-3B-Instruct
peft_config:
  _target_: peft.LoraConfig
  _convert_: object
  r: 32
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules:
    - q_proj
    - v_proj
    - k_proj
    - o_proj
    - gate_proj
    - down_proj
    - up_proj
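  # The following modules are fully fine-tuned and saved rather than LoRA-adapted,
  # likely because the extended tokenizer adds new generation tokens that the
  # embeddings and LM head must learn.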
  modules_to_save:
    - embed_tokens
    - lm_head
    - input_layernorm
    - post_attention_layernorm
  task_type: CAUSAL_LM


================================================
FILE: configs/tokenizer/qwen25_vl_tokenizer_token64.yaml
================================================
_target_: got.processer.qwen25_vl_processor.get_processor
model_name: pretrained/Qwen2.5-VL-3B-Instruct
add_gen_token_num: 64

================================================
FILE: got/__init__.py
================================================


================================================
FILE: got/models/__init__.py
================================================


================================================
FILE: got/models/got_model.py
================================================
import os
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import StoppingCriteriaList
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm
from .utils import (
    IMG_TOKEN, BOI_TOKEN, EOI_TOKEN, EOS_TOKEN, BOV_TOKEN, EOV_TOKEN, IMG_PAD_TOKEN,
    parse_coordinates_colors, StopOnToken
)


class GenCot(nn.Module):
    def __init__(self, mllm, output_projector, output_projector_add, scheduler, vae, unet, processor,
                 num_img_out_tokens=64, img_gen_start_id=151667, box_start_id=151648, box_end_id=151649) -> None:
        super().__init__()
        self.mllm = mllm  # qwen25-vl model
        self.output_projector = output_projector
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.output_projector_add = output_projector_add

        # The UNet is conditioned on extra image latents: its first conv layer takes
        # 12 input channels (4 noisy latents + 4 source-image latents + 4 spatial-mask
        # latents) instead of SDXL's default 4.
        in_channels = 12
        self.unet.register_to_config(in_channels=in_channels)

        with torch.no_grad():
            conv = torch.nn.Conv2d(in_channels, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size,
                                   self.unet.conv_in.stride, self.unet.conv_in.padding)
            conv.weight.zero_()
            conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
            self.unet.conv_in = conv
        self.vae.requires_grad_(False)
        self.vae_batch = 1

        if is_xformers_available():
            import xformers
            unet.enable_xformers_memory_efficient_attention()

        self.img_gen_start_id = img_gen_start_id
        self.num_img_out_tokens = num_img_out_tokens
        self.box_start_id = box_start_id
        self.box_end_id = box_end_id
        self.diffusion_transform = None
        self.source_transform = None
        self.processor = processor

    def _get_add_time_ids(
            self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
                self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @torch.no_grad()
    def generate(self,
                 text_input,
                 image=None,
                 max_new_tokens=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 image_guidance_scale=1.0,
                 cond_image_guidance_scale=4.0,
                 height=1024,
                 width=1024,
                 input_token_num=256,
                 do_classifier_free_guidance=True,
                 crops_coords_top_left=(0, 0),
                 prompt_type='t2i',
                 random_seed=42,
                 got_input=None,
                 only_return_got=False,
                 **generate_kwargs
                 ):
        """
        Generate text and optional images from the model.

        Args:
            text_input (str): The input text prompt.
            image (PIL.Image.Image, optional): A single image for Qwen2.5-VL context or editing.
            max_new_tokens (int): Maximum number of tokens to generate.
            num_inference_steps (int): Diffusion steps for stable diffusion.
            guidance_scale (float): CFG scale for stable diffusion.
            image_guidance_scale (float): Image guidance scale for stable diffusion.
            cond_image_guidance_scale (float): Conditional image guidance scale for stable diffusion.
            height (int): Height of the output image.
            width (int): Width of the output image.
            input_token_num (int): Number of image tokens in the input.
            do_classifier_free_guidance (bool): Whether to use classifier-free guidance during inference.
            crops_coords_top_left (Tuple[int, int]): The top-left coordinates of the crops.
            prompt_type (str): The prompt type to use.
            random_seed (int): Random seed for torch.random.
            got_input (str): Custom GoT content to use instead of generating it. For interactive generation only.
            only_return_got (bool): Whether to return only the GoT text (for interactive generation).
            generate_kwargs: Additional kwargs for self.mllm.generate().

        Returns:
            A dict with:
                'text': str, the generated text.
                'images': List[PIL.Image.Image], the generated images if any.
        """
        device = next(self.parameters()).device
        vae_dtype = next(self.vae.parameters()).dtype

        if self.diffusion_transform is None:
            self.diffusion_transform = transforms.Compose([
                transforms.Resize((height, width), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])
        if self.source_transform is None:
            self.source_transform = transforms.Resize((448, 448), interpolation=transforms.InterpolationMode.BICUBIC)

        # Generate image tokens
        img_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_TOKEN.format(i)) for i in
                         range(self.num_img_out_tokens)]
        img_token_ids = torch.tensor(img_token_ids, device=device).unsqueeze(0)  # [1, num_img_out_tokens]

        # input image tokens
        input_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_PAD_TOKEN) for _ in
                           range(input_token_num)]
        input_token_ids = torch.tensor(input_token_ids, device=device).unsqueeze(0)  # [1, input_token_num]

        # Convert BOI_TOKEN to ID
        boi_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOI_TOKEN)
        eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN)
        bov_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOV_TOKEN)

        # Define stopping criteria to stop at BOI_TOKEN
        stopping_criteria = StoppingCriteriaList([
            StopOnToken(boi_token_id), StopOnToken(bov_token_id), StopOnToken(eos_token_id)
        ])
        ori_w, ori_h = image.size if image is not None else (width, height)
        input_images = [self.source_transform(image)] if image is not None else []
        original_images = [image] if image is not None else []
        generated_images = []
        output_text = ''

        if prompt_type == 't2i':
            prompt = f"Follow the caption to generate an image through a chain of thought process: {text_input}"
        elif prompt_type == 'edit':
            prompt = f"Follow the instruction to edit the given image through a chain of thought process: {text_input}"
        else:
            raise ValueError(f"Unknown prompt type {prompt_type}")

        # Prepare the conversation structure for Qwen2.5-VL
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

        # If image is provided, add it to messages
        if image is not None:
            # Insert the image into the content
            messages[0]["content"].insert(0, {"type": "image"})

        # Apply chat template to form the prompt as Qwen2.5-VL expects
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(
            text=[text],
            images=None if not input_images else input_images,
            padding=False,
            return_tensors="pt"
        ).to(device)
        input_ids = inputs.input_ids  # shape: [1, seq_len]

        # if the last token is not EOS_TOKEN, continue generating
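        # Each iteration stops on one of three tokens: EOS (finished), BOI (render an image
        # with the diffusion module), or BOV (re-ingest the latest generated image as context).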
        while input_ids[0, -1] != eos_token_id:
            input_length = input_ids.shape[1]
            image_inputs = None if not input_images \
                else self.processor.image_processor(images=input_images, return_tensors="pt").to(device)

            if got_input is None:
                partial_generation = self.mllm.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None,
                    image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None,
                    max_new_tokens=max_new_tokens,
                    return_dict_in_generate=True,
                    output_hidden_states=False,  # No need yet, we will do a second pass
                    stopping_criteria=stopping_criteria,
                    **generate_kwargs
                )

                input_ids = partial_generation['sequences']  # shape: [1, seq_len]
            else:
                input_ids = self.processor.tokenizer.encode(got_input)
                input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
                got_input = None

            if only_return_got:
                return {"got_text": self.processor.tokenizer.decode(input_ids[0])}

            # Decode the newly generated text
            cur_decoded_text = self.processor.tokenizer.decode(input_ids[0, input_length:], skip_special_tokens=False)
            output_text += cur_decoded_text\
                .replace(EOS_TOKEN, '').replace(EOI_TOKEN, '').replace(BOV_TOKEN, '').replace(EOV_TOKEN, '')

            # the MLLM emitted BOI_TOKEN: generate an image with the diffusion module
            if input_ids[0, -1] == boi_token_id:
                input_ids = torch.cat([input_ids, img_token_ids], dim=1)  # now includes BOI_TOKEN + image tokens

                second_out = self.mllm(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None,
                    image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None,
                    output_hidden_states=True,
                    return_dict=True
                )
                last_hidden_states = second_out['hidden_states'][-1]  # [batch_size, seq_len, hidden_size]

                img_gen_mask = torch.logical_and(
                    self.img_gen_start_id <= input_ids, input_ids < self.img_gen_start_id + self.num_img_out_tokens)

                gen_hidden_states = last_hidden_states[img_gen_mask].view(-1, self.num_img_out_tokens,
                                                                          last_hidden_states.shape[-1])
                gen_hidden_states = gen_hidden_states[-1:]  # keep only the most recent group of image-token hidden states
                gen_hidden_states = gen_hidden_states.to(self.output_projector.projector.weight.dtype)

                gen_conditioning = self.output_projector(gen_hidden_states)
                gen_conditioning_add = self.output_projector_add(gen_hidden_states)  # [bz, gen_num, dim]
                null_conditioning = self.output_projector(torch.zeros_like(gen_hidden_states))
                gen_conditioning_pooled = torch.mean(gen_conditioning_add, dim=1)

                self.scheduler.set_timesteps(num_inference_steps, device=device)
                timesteps = self.scheduler.timesteps

                # Prepare stable diffusion latents
                generator = torch.Generator(device=device).manual_seed(random_seed)

                latents = randn_tensor(
                    shape=(1, self.vae.config.latent_channels, height // 8, width // 8),
                    generator=generator,
                    device=device,
                    dtype=vae_dtype
                )
                latents = latents * self.scheduler.init_noise_sigma

                # The first 4 latent channels are the noisy latents; the next 4 are the
                # original-image latents (for editing). In the text-to-image scenario, a
                # black placeholder image is encoded in place of the original image.
                original_image = original_images[-1] if original_images \
                    else Image.new('RGB', (width, height), (0, 0, 0))

                original_image_tensor = self.diffusion_transform(original_image).unsqueeze(0).to(device).to(vae_dtype)
                image_latents = self.vae.encode(original_image_tensor).latent_dist.mode()

                positions_colors = parse_coordinates_colors(cur_decoded_text)
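                # Rasterize each parsed box as a solid colored rectangle on a black canvas;
                # the canvases are VAE-encoded and averaged into a single spatial-guidance
                # latent that fills the extra UNet input channels.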
                mask_num = max(len(positions_colors), 1)

                cond_images = [Image.new('RGB', (width, height), (0, 0, 0)) for _ in range(mask_num)]

                for i in range(len(positions_colors)):
                    p_c = positions_colors[i]
                    draw = ImageDraw.Draw(cond_images[i])
                    position = p_c['position']
                    color = p_c['color']
                    draw.rectangle(((position[0][0] / 1000 * width, position[0][1] / 1000 * height),
                                    (position[1][0] / 1000 * width, position[1][1] / 1000 * height)), fill=color)
                    del draw

                cond_images_tensor = []
                for c_image in cond_images:
                    c_image_tensor = self.diffusion_transform(c_image)
                    cond_images_tensor.append(c_image_tensor)

                # (1, mask_num, 3, height, width)
                cond_mask = torch.stack(cond_images_tensor, dim=0).unsqueeze(0)
                B, N, C, H, W = cond_mask.shape
                cond_mask = cond_mask.view(B * N, C, H, W)

                unet_cond_embeds = []
                for i in range(0, cond_mask.shape[0], self.vae_batch):
                    sub_batch = cond_mask[i: i + self.vae_batch]
                    embeds = self.vae.encode(sub_batch.to(device, dtype=vae_dtype)).latent_dist.mode()
                    embeds = embeds.to(device)
                    unet_cond_embeds.append(embeds)
                unet_cond_embeds = torch.cat(unet_cond_embeds, dim=0)
                unet_cond_embed = unet_cond_embeds.mean(dim=0, keepdim=True)

                if do_classifier_free_guidance:
                    uncond_image_latents = torch.zeros_like(image_latents)
                    image_latents = torch.cat([image_latents, image_latents, image_latents, uncond_image_latents],
                                              dim=0)

                    uncond_cond_image_latents = torch.zeros_like(unet_cond_embed)
                    unet_cond_embed = torch.cat([unet_cond_embed, uncond_cond_image_latents,
                                                 uncond_cond_image_latents, uncond_cond_image_latents], dim=0)

                combined_prompt_embeds = torch.cat(
                    [gen_conditioning, gen_conditioning, null_conditioning, null_conditioning],
                    dim=0) if do_classifier_free_guidance else gen_conditioning

                text_encoder_projection_dim = int(gen_conditioning_pooled.shape[-1])

                original_size = (height, width)
                target_size = (height, width)

                add_time_ids = self._get_add_time_ids(
                    original_size,
                    crops_coords_top_left,
                    target_size,
                    dtype=combined_prompt_embeds.dtype,
                    text_encoder_projection_dim=text_encoder_projection_dim,
                )

                added_cond_kwargs = {"text_embeds": gen_conditioning_pooled.to(device),
                                     "time_ids": add_time_ids.to(device)}

                for i, t in enumerate(tqdm(timesteps)):
                    latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
                    scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents, unet_cond_embed],
                                                          dim=1)

                    noise_pred = self.unet(
                        scaled_latent_model_input,
                        t,
                        encoder_hidden_states=combined_prompt_embeds,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False
                    )[0]

                    if do_classifier_free_guidance:
                        noise_pred_cond, noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(4,
                                                                                                                 dim=0)
                        noise_pred = (
                                noise_pred_uncond
                                + guidance_scale * (noise_pred_text - noise_pred_image)
                                + cond_image_guidance_scale * (noise_pred_cond - noise_pred_text)
                                + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                        )

                    # step through scheduler
                    latents = self.scheduler.step(noise_pred, t, latents, generator=generator, return_dict=False)[0]

                final_latents = latents / self.vae.config.scaling_factor
                image_tensor = self.vae.decode(final_latents, generator=generator).sample
                image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1)
                pil_image = Image.fromarray(
                    (image_tensor[0].permute(1, 2, 0).cpu().float().numpy() * 255).astype("uint8"))

                generated_images.append(pil_image)
                original_images.append(pil_image)
            elif input_ids[0, -1] == bov_token_id:
                input_images.append(self.source_transform(generated_images[-1]))
                input_ids = torch.cat([input_ids, input_token_ids], dim=1)

        # resize generated images back to the input aspect ratio, keeping the shorter side at the generation size (1024 by default)
        if ori_w < ori_h:
            target_size = (width, int(height * ori_h / ori_w))
        else:
            target_size = (int(width * ori_w / ori_h), height)
        generated_images = [img.resize(target_size) for img in generated_images]

        return {"got_text": output_text, "images": generated_images}

    @classmethod
    def from_pretrained(cls, mllm, output_projector, scheduler, vae, unet, pretrained_model_path=None, **kwargs):
        model = cls(mllm=mllm, output_projector=output_projector, scheduler=scheduler, vae=vae, unet=unet, **kwargs)
        if os.environ.get('DEBUG_FLAG', 'False') == 'True':
            return model

        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            logs = model.load_state_dict(ckpt, strict=False)
            print(logs)
        return model


================================================
FILE: got/models/peft_models.py
================================================
import torch
from omegaconf import DictConfig
import hydra
from peft import (
    LoraConfig,
    PeftModel,
    LoraModel,
    PeftModelForCausalLM,
    get_peft_model,
)


def get_peft_model_without_resize_embedding(model, peft_config=None, torch_dtype='bf16'):
    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
        torch_dtype = torch.bfloat16
    elif torch_dtype == 'fp16' or torch_dtype == 'float16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    if isinstance(model, DictConfig):
        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)

    print('peft config: ', peft_config)
    if isinstance(peft_config, DictConfig):
        peft_config = hydra.utils.instantiate(peft_config)
    peft_model = get_peft_model(model=model, peft_config=peft_config)

    # peft_model.print_trainable_parameters()

    return peft_model


================================================
FILE: got/models/projector.py
================================================
import torch.nn as nn


class LinearProjector(nn.Module):
    def __init__(self, in_hidden_size, out_hidden_size, bias=True):
        super().__init__()
        self.projector = nn.Linear(in_hidden_size, out_hidden_size, bias=bias)

    def forward(self, feature):
        return self.projector(feature)


================================================
FILE: got/models/utils.py
================================================
import re
import torch
from transformers import StoppingCriteria


BOI_TOKEN = '<|im_gen_start|>'
EOI_TOKEN = '<|im_gen_end|>'
IMG_TOKEN = '<|im_gen_{:04d}|>'
EOS_TOKEN = '<|endoftext|>'
BOV_TOKEN = '<|vision_start|>'
EOV_TOKEN = '<|vision_end|>'
IMG_PAD_TOKEN = '<|image_pad|>'


def remove_mismatched_weights(model, pretrained_state_dict):
    own_state = model.state_dict()
    mismatch_keys = []

    for name in list(pretrained_state_dict.keys()):
        if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape:
            mismatch_keys.append(name)
            pretrained_state_dict.pop(name)

    return pretrained_state_dict, mismatch_keys


def parse_coordinates_colors(cot_text):
    """
    Parse bounding box coordinates and their colors from the CoT text.

    Args:
        cot_text (str): Chain of Thought text containing bounding box information.

    Returns:
        list: A list of dictionaries with keys 'position' ([[x1, y1], [x2, y2]]) and 'color'.
    """
    # Regular expression to match bounding box and color patterns
    pattern = r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|> \((\w+)\)"

    # Parse all matches
    matches = re.findall(pattern, cot_text)

    # Extract bounding box coordinates and colors
    parsed_data = []
    for match in matches:
        x1, y1, x2, y2, color = match
        parsed_data.append({
            'position': [[int(x1), int(y1)], [int(x2), int(y2)]],
            'color': color
        })

    return parsed_data


class StopOnToken(StoppingCriteria):
    def __init__(self, token_id):
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the last generated token matches the configured stop-token id
        return input_ids[0, -1] == self.token_id


================================================
FILE: got/processer/qwen25_vl_processor.py
================================================
from transformers import AutoProcessor


BOI_TOKEN = '<|im_gen_start|>'
EOI_TOKEN = '<|im_gen_end|>'
IMG_TOKEN = '<|im_gen_{:04d}|>'


def get_processor(model_name, add_gen_token_num=64):
    processor = AutoProcessor.from_pretrained(model_name)
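    # Register the generation-control vocabulary: <|im_gen_start|>, <|im_gen_end|>, plus
    # `add_gen_token_num` placeholder tokens <|im_gen_0000|>, <|im_gen_0001|>, ... whose
    # hidden states are later projected into SDXL conditioning.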
    add_token_list = [BOI_TOKEN, EOI_TOKEN]
    for i in range(add_gen_token_num):
        add_token_list.append(IMG_TOKEN.format(i))
    processor.tokenizer.add_tokens(add_token_list, special_tokens=True)
    return processor


================================================
FILE: inference.ipynb
================================================
[File too large to display: 19.6 MB]

================================================
FILE: requirements.txt
================================================
torch==2.0.1
torchvision==0.15.2
hydra-core
omegaconf
transformers==4.49.0
diffusers==0.29.0
sentencepiece
opencv-python
peft==0.13.2
pyrootutils
xformers==0.0.22
accelerate==1.3.0
transformers_stream_generator
tqdm
notebook
numpy==1.21.2
huggingface_hub==0.29.3

================================================
SYMBOL INDEX (15 symbols across 5 files)
================================================

FILE: got/models/got_model.py
  class GenCot (line 16) | class GenCot(nn.Module):
    method __init__ (line 17) | def __init__(self, mllm, output_projector, output_projector_add, sched...
    method _get_add_time_ids (line 53) | def _get_add_time_ids(
    method generate (line 72) | def generate(self,
    method from_pretrained (line 373) | def from_pretrained(cls, mllm, output_projector, scheduler, vae, unet,...

FILE: got/models/peft_models.py
  function get_peft_model_without_resize_embedding (line 13) | def get_peft_model_without_resize_embedding(model, peft_config=None, tor...

FILE: got/models/projector.py
  class LinearProjector (line 4) | class LinearProjector(nn.Module):
    method __init__ (line 5) | def __init__(self, in_hidden_size, out_hidden_size, bias=True):
    method forward (line 9) | def forward(self, feature):

FILE: got/models/utils.py
  function remove_mismatched_weights (line 15) | def remove_mismatched_weights(model, pretrained_state_dict):
  function parse_coordinates_colors (line 27) | def parse_coordinates_colors(cot_text):
  class StopOnToken (line 55) | class StopOnToken(StoppingCriteria):
    method __init__ (line 56) | def __init__(self, token_id):
    method __call__ (line 59) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...

FILE: got/processer/qwen25_vl_processor.py
  function get_processor (line 9) | def get_processor(model_name, add_gen_token_num=64):

About this extraction

This page contains the full source code of the rongyaofang/GoT GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 15 files (19.6 MB), approximately 9.0k tokens, and a symbol index with 15 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
