[
  {
    "path": ".gitignore",
    "content": "pretrained/\n.DS_Store\n.idea/"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 Rongyao Fang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing\n<div align=\"center\">\n<a href=\"https://github.com/rongyaofang/GoT\"><img src=\"https://img.shields.io/badge/Project-Homepage-green\" alt=\"Home\"></a>\n<a href=\"https://arxiv.org/abs/2503.10639\"><img src=\"https://img.shields.io/badge/ArXiv-2503.10639-red\"></a>\n\n[Rongyao Fang](https://scholar.google.com/citations?user=FtH3CW4AAAAJ&hl=en)<sup>1\\*</sup>, [Chengqi Duan](https://scholar.google.com/citations?user=r9qb4ZwAAAAJ&hl=zh-CN)<sup>2\\*</sup>, [Kun Wang]()<sup>3</sup>, [Linjiang Huang](https://leonhlj.github.io/)<sup>6</sup>, [Hao Li](https://scholar.google.com/citations?user=qHqQsY4AAAAJ&hl=zh-CN)<sup>1,4</sup>, [Shilin Yan](https://scholar.google.com/citations?user=2VhjOykAAAAJ&hl=zh-CN), [Hao Tian]()<sup>3</sup>, [Xingyu Zeng]()<sup>3</sup>, [Rui Zhao]()<sup>3</sup>, [Jifeng Dai](https://jifengdai.org/)<sup>4,5</sup>, [Xihui Liu](https://xh-liu.github.io/)<sup>2 :envelope:</sup>, [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<sup>1 :envelope:</sup>\n\n<sup>1</sup>CUHK MMLab, <sup>2</sup>HKU MMLab, <sup>3</sup>SenseTime, <sup>4</sup>Shanghai AI Laboratory, <sup>5</sup>Tsinghua University, <sup>6</sup>Beihang University\n\n*Equal contribution, :envelope:Corresponding authors\n</div>\n\n<div align=\"center\">\n  <img src=\"figures/teaser.jpg\" width=\"100%\" alt=\"GoT Framework\" />\n</div>\n<hr>\n<div align=\"center\" style=\"line-height: 1.2;\">\n  <a href=\"https://arxiv.org/abs/2503.10639\" target=\"_blank\"><b>Paper</b></a> •\n  <a href=\"#introduction\">Introduction</a> •\n  <a href=\"#released-datasets\">Datasets</a> •\n  <a href=\"#released-model-got-framework\">Model</a> •\n  <a href=\"#results\">Results</a> •\n  <a href=\"https://huggingface.co/LucasFang/GoT-6B\" target=\"_blank\">🤗 Hugging Face</a> •\n  <a href=\"#license\">License</a>\n</div>\n\n## 🔥 News\n\n- **[2025-9-19]** 📝 Our GoT paper has been accepted by **NeurIPS 2025**!\n- **[2025-9-12]** 🎉 We open-sourced our latest work **FLUX-Reason-6M** dataset! This high-quality text-to-image reasoning dataset was constructed using 15,000 A100 GPU days with FLUX generation. Check it out at [FLUX-Reason-6M](https://github.com/rongyaofang/prism-bench)!\n\n## Introduction\n\nWe present **Generation Chain-of-Thought (GoT)**, a novel paradigm that enables generation and editing through an explicit language reasoning process before outputting images. 
This approach transforms conventional text-to-image generation and editing into a reasoning-guided framework that analyzes semantic relationships and spatial arrangements.\n\nGoT pioneers a new direction for reasoning-driven visual generation and editing, producing images that better align with human intent through:\n\n- **Semantic-Spatial Reasoning**: Integrates both semantic understanding and explicit spatial coordinates\n- **Unified Framework**: Handles both image generation and editing with the same architecture\n\n## Released Datasets\n\n| Dataset | Link | Amount |\n|---------|------|--------|\n| **Laion-Aesthetics-High-Resolution-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/Laion-Aesthetics-High-Resolution-GoT) | 3.77M  |\n| **JourneyDB-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/JourneyDB-GoT) | 4.09M  |\n| **OmniEdit-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/OmniEdit-GoT) | 736K   |\n| **FLUX-Reason-6M** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/FLUX-Reason-6M) | 6M     |\n\n## Dataset Features\n\n### Laion-Aesthetics-High-Resolution-GoT\n- 3.77 million high-quality images from Laion-Aesthetics, filtered for resolutions above 512 pixels\n- Prompts and GoT descriptions from Qwen2-VL\n- Prompts averaging 110.81 characters\n- GoT descriptions averaging 811.56 characters\n- 3.78 bounding boxes per image on average\n\n### JourneyDB-GoT\n- 4.09 million high-quality AI-generated images\n- Prompts and GoT descriptions from Qwen2-VL\n- Prompts averaging 149.78 characters\n- GoT descriptions averaging 906.01 characters\n- 4.09 bounding boxes per image on average\n- Please download the images from the [JourneyDB dataset](https://opendatalab.com/OpenDataLab/JourneyDB/tree/main/raw/JourneyDB/train/imgs)\n\n### OmniEdit-GoT\n- 736K high-quality image editing samples from OmniEdit\n- Diverse editing operations (addition, removal, swap, attribute changes, style transfer)\n- Detailed reasoning chains with step-by-step editing processes\n- Precise spatial coordinate annotations for editing regions\n- Please download the images from the [OmniEdit dataset](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M)\n\n### FLUX-Reason-6M\n- 6 million high-quality text-to-image reasoning samples generated entirely with FLUX\n- Built using 15,000 A100 GPU days for superior quality and reasoning capabilities\n- Comprehensive reasoning chains for complex visual generation tasks\n- Designed to enhance multimodal reasoning in visual generation models\n\n## Released Model: GoT Framework\n\n| Model      | Link | Architecture         |\n|------------|------|----------------------|\n| **GoT-6B** | [🤗 HuggingFace](https://huggingface.co/LucasFang/GoT-6B) | Qwen2.5-VL-3B + SDXL |\n\n## Model Features\n\n<div align=\"center\">\n  <img src=\"figures/architecture.jpg\" width=\"100%\" alt=\"GoT Architecture\" />\n</div>\n\nOur GoT framework consists of two key components:\n\n1. **Semantic-Spatial MLLM**: Generates detailed reasoning chains with spatial information using Qwen2.5-VL as the backbone\n
2. **SSGM Diffusion Module**: Leverages the semantic guidance, spatial layouts, and reference images to create high-quality visual outputs\n\nThe Semantic-Spatial Guidance Module (SSGM) combines three guidance pathways:\n- **Semantic Guidance**: Captures relationships and attributes\n- **Spatial Guidance**: Controls precise object placement\n- **Reference Guidance**: Provides context for editing tasks\n\n## Results\n\n### Text-to-Image Generation\n\nGoT achieves state-of-the-art performance on the GenEval benchmark, particularly excelling in composition tasks:\n\n<div align=\"center\">\n\n| Method | Architecture | Overall | Single Obj. | Two Obj. | Counting | Colors | Position | Attr. Binding |\n|--------|--------------|---------|-------------|----------|----------|--------|----------|---------------|\n| SD-XL | Unet+CLIP | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 |\n| SD3 | MMDIT+CLIP+T5 | 0.62 | 0.98 | 0.74 | 0.63 | 0.67 | 0.34 | 0.36 |\n| Emu3-Gen | Autoregressive | 0.54 | 0.98 | 0.71 | 0.34 | 0.81 | 0.17 | 0.21 |\n| Janus | Autoregressive | 0.61 | 0.97 | 0.68 | 0.30 | 0.84 | 0.46 | 0.42 |\n| JanusFlow | Autoregressive | 0.63 | 0.97 | 0.59 | 0.45 | 0.83 | 0.53 | 0.42 |\n| **GoT Framework** | Unet+Qwen2.5-VL | **0.64** | **0.99** | 0.69 | **0.67** | **0.85** | 0.34 | 0.27 |\n\n</div>\n\n### Image Editing\n\nOur approach also demonstrates superior performance on image editing benchmarks:\n\n<div align=\"center\">\n\n| Method | Emu-Edit (CLIP-I) | Emu-Edit (CLIP-T) | ImagenHub (GPT-4o Eval.) | Reason-Edit (GPT-4o Eval.) |\n|--------|-------------------|-------------------|--------------------------|----------------------------|\n| IP2P | 0.834 | 0.219 | 0.308 | 0.286 |\n| MagicBrush | 0.838 | 0.222 | 0.513 | 0.334 |\n| SEED-X | 0.825 | 0.272 | 0.166 | 0.239 |\n| CosXL-Edit | 0.860 | 0.274 | 0.464 | 0.325 |\n| **GoT Framework** | **0.864** | **0.276** | **0.533** | 0.561 |\n\n</div>\n\n## Usage\n\n### Dependencies\n- Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))\n- [PyTorch >= 2.0.1](https://pytorch.org/)\n- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)\n\n### Installation\nClone the repo and install the required packages:\n\n  ```bash\n  git clone git@github.com:rongyaofang/GoT.git\n  cd GoT\n  pip install -r requirements.txt\n  ```\n\n### Model Weights\nPlace the required model weights in the `./pretrained` directory as follows:\n\n1. GoT-6B model weights\n2. [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)\n3. 
[Stable Diffusion XL Base 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n\nYour directory structure should match the following:\n\n```\nGoT\n├── pretrained\n│   ├── GoT-6B\n│   ├── Qwen2.5-VL-3B-Instruct\n│   └── stable-diffusion-xl-base-1.0\n├── ...\n```\n\n### Inference\nFollow the instructions in the [inference notebook](https://github.com/rongyaofang/GoT/blob/main/inference.ipynb)\n\n## License\n\nThis code is released under the MIT License.\n\n## Citation\n\nIf you find this work helpful, please consider citing:\n\n```\n@article{fang2025got,\n  title={GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing},\n  author={Fang, Rongyao and Duan, Chengqi and Wang, Kun and Huang, Linjiang and Li, Hao and Yan, Shilin and Tian, Hao and Zeng, Xingyu and Zhao, Rui and Dai, Jifeng and Liu, Xihui and Li, Hongsheng},\n  journal={arXiv preprint arXiv:2503.10639},\n  year={2025}\n}\n```\n\n## Contact\n\nIf you have any questions, please raise an issue or contact us at [rongyaofang@gmail.com](mailto:rongyaofang@gmail.com).\n"
  },
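  {
    "path": "examples/inference_example.py",
    "content": "\"\"\"Illustrative GoT inference sketch (added editorially; not the official pipeline).\n\nThis file only sketches how the released Hydra configs could be wired together to run\ntext-to-image generation with GoT. The supported workflow is the inference notebook\n(inference.ipynb); the checkpoint filename, the bfloat16 cast, and this file path are\nassumptions, so adjust them to your local setup.\n\"\"\"\nimport hydra\nimport torch\nfrom omegaconf import OmegaConf\n\ndevice = 'cuda'\n\n# Instantiate the Qwen2.5-VL LoRA backbone from its config.\nllm_cfg = OmegaConf.load('configs/clm_models/llm_qwen25_vl_3b_lora.yaml')\nmllm = hydra.utils.instantiate(llm_cfg)\n\n# Instantiate the full GoT agent (projectors, SDXL VAE/UNet, scheduler, processor).\n# 'pretrained/GoT-6B/pytorch_model.bin' is a hypothetical checkpoint path.\nagent_cfg = OmegaConf.load('configs/clm_models/agent_got.yaml')\nagent = hydra.utils.instantiate(agent_cfg, mllm=mllm,\n                                pretrained_model_path='pretrained/GoT-6B/pytorch_model.bin')\nagent = agent.to(device, dtype=torch.bfloat16).eval()\n\n# Text-to-image generation with an explicit GoT reasoning chain.\nout = agent.generate(\n    text_input='A red apple on a wooden table next to a glass of water',\n    prompt_type='t2i',\n    height=1024,\n    width=1024,\n)\nprint(out['got_text'])  # the generated GoT reasoning text\nif out['images']:\n    out['images'][0].save('got_t2i_sample.png')\n"
  },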
  {
    "path": "configs/clm_models/agent_got.yaml",
    "content": "_target_: got.models.got_model.GenCot.from_pretrained\noutput_projector:\n  _target_: got.models.projector.LinearProjector\n  in_hidden_size: 2048\n  out_hidden_size: 2048\n\noutput_projector_add:\n  _target_: got.models.projector.LinearProjector\n  in_hidden_size: 2048\n  out_hidden_size: 1280\n\nscheduler:\n  _target_: diffusers.DDPMScheduler.from_pretrained\n  pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0\n  subfolder: scheduler\n\nvae:\n  _target_: diffusers.AutoencoderKL.from_pretrained\n  pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0\n  subfolder: vae\n\nunet:\n  _target_: diffusers.UNet2DConditionModel.from_pretrained\n  pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0\n  subfolder: unet\n\nprocessor:\n  _target_: got.processer.qwen25_vl_processor.get_processor\n  model_name: pretrained/Qwen2.5-VL-3B-Instruct\n  add_gen_token_num: 64\n\nnum_img_out_tokens: 64\nimg_gen_start_id: 151667\n"
  },
  {
    "path": "configs/clm_models/llm_qwen25_vl_3b_lora.yaml",
    "content": "_target_: got.models.peft_models.get_peft_model_without_resize_embedding\nmodel:\n  _target_: transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained\n  pretrained_model_name_or_path: pretrained/Qwen2.5-VL-3B-Instruct\npeft_config:\n  _target_: peft.LoraConfig\n  _convert_: object\n  r: 32\n  lora_alpha: 32\n  lora_dropout: 0.05\n  target_modules:\n    - q_proj\n    - v_proj\n    - k_proj\n    - o_proj\n    - gate_proj\n    - down_proj\n    - up_proj\n  modules_to_save:\n    - embed_tokens\n    - lm_head\n    - input_layernorm\n    - post_attention_layernorm\n  task_type: CAUSAL_LM\n"
  },
  {
    "path": "configs/tokenizer/qwen25_vl_tokenizer_token64.yaml",
    "content": "_target_: got.processer.qwen25_vl_processor.get_processor\nmodel_name: pretrained/Qwen2.5-VL-3B-Instruct\nadd_gen_token_num: 64"
  },
  {
    "path": "got/__init__.py",
    "content": ""
  },
  {
    "path": "got/models/__init__.py",
    "content": ""
  },
  {
    "path": "got/models/got_model.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nfrom PIL import Image, ImageDraw\nfrom torchvision import transforms\nfrom transformers import StoppingCriteriaList\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom tqdm import tqdm\nfrom .utils import (\n    IMG_TOKEN, BOI_TOKEN, EOI_TOKEN, EOS_TOKEN, BOV_TOKEN, EOV_TOKEN, IMG_PAD_TOKEN,\n    parse_coordinates_colors, StopOnToken\n)\n\n\nclass GenCot(nn.Module):\n    def __init__(self, mllm, output_projector, output_projector_add, scheduler, vae, unet, processor,\n                 num_img_out_tokens=64, img_gen_start_id=151667, box_start_id=151648, box_end_id=151649) -> None:\n        super().__init__()\n        self.mllm = mllm  # qwen25-vl model\n        self.output_projector = output_projector\n        self.vae = vae\n        self.unet = unet\n        self.scheduler = scheduler\n        self.output_projector_add = output_projector_add\n\n        # uses an additional image for conditioning.\n        # it uses 12 channels (instead of 4) in the first (conv) layer of the UNet.\n        in_channels = 12\n        self.unet.register_to_config(in_channels=in_channels)\n\n        with torch.no_grad():\n            conv = torch.nn.Conv2d(in_channels, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size,\n                                   self.unet.conv_in.stride, self.unet.conv_in.padding)\n            conv.weight.zero_()\n            conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)\n            self.unet.conv_in = conv\n        self.vae.requires_grad_(False)\n        self.vae_batch = 1\n\n        if is_xformers_available():\n            import xformers\n            unet.enable_xformers_memory_efficient_attention()\n\n        self.img_gen_start_id = img_gen_start_id\n        self.num_img_out_tokens = num_img_out_tokens\n        self.box_start_id = box_start_id\n        self.box_end_id = box_end_id\n        self.diffusion_transform = None\n        self.source_transform = None\n        self.processor = processor\n\n    def _get_add_time_ids(\n            self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None\n    ):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n                self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    @torch.no_grad()\n    def generate(self,\n                 text_input,\n                 image=None,\n                 max_new_tokens=1024,\n                 num_inference_steps=50,\n                 guidance_scale=7.5,\n                 image_guidance_scale=1.0,\n                 cond_image_guidance_scale=4.0,\n                 height=1024,\n                 width=1024,\n                 input_token_num=256,\n                 do_classifier_free_guidance=True,\n                 crops_coords_top_left=(0, 0),\n                 prompt_type='t2i',\n                 random_seed=42,\n                 got_input=None,\n                 only_return_got=False,\n                 **generate_kwargs\n                 ):\n        \"\"\"\n        Generate text and optional images from the model.\n\n        Args:\n            text_input (str): The input text prompt.\n            image (PIL.Image.Image, optional): A single image for Qwen2.5-VL context or editing.\n            max_new_tokens (int): Maximum number of tokens to generate.\n            num_inference_steps (int): Diffusion steps for stable diffusion.\n            guidance_scale (float): CFG scale for stable diffusion.\n            image_guidance_scale (float): Image guidance scale for stable diffusion.\n            cond_image_guidance_scale (float): Conditional image guidance scale for stable diffusion.\n            height (int): Height of the output image.\n            width (int): Width of the output image.\n            input_token_num (int): Number of image tokens in the input.\n            do_classifier_free_guidance (bool): Whether to use classifier-free guidance during inference.\n            crops_coords_top_left (Tuple[int, int]): The top-left coordinates of the crops.\n            prompt_type (str): The prompt type to use: 't2i' for text-to-image generation or 'edit' for image editing.\n            random_seed (int): Random seed for torch.random.\n
            got_input (str): A custom GoT reasoning string used in place of the generated one. For interactive generation only.\n            only_return_got (bool): Whether to return only the GoT text (for interactive generation).\n            generate_kwargs: Additional kwargs for self.mllm.generate().\n\n        Returns:\n            A dict with:\n                'got_text': str, the generated GoT reasoning text.\n                'images': List[PIL.Image.Image], the generated images if any.\n        \"\"\"\n        device = next(self.parameters()).device\n        vae_dtype = next(self.vae.parameters()).dtype\n\n        if self.diffusion_transform is None:\n            self.diffusion_transform = transforms.Compose([\n                transforms.Resize((height, width), interpolation=transforms.InterpolationMode.BICUBIC),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5])\n            ])\n        if self.source_transform is None:\n            self.source_transform = transforms.Resize((448, 448), interpolation=transforms.InterpolationMode.BICUBIC)\n\n        # IDs of the output image tokens\n        img_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_TOKEN.format(i)) for i in\n                         range(self.num_img_out_tokens)]\n        img_token_ids = torch.tensor(img_token_ids, device=device).unsqueeze(0)  # [1, num_img_out_tokens]\n\n        # IDs of the input image pad tokens\n        input_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_PAD_TOKEN) for _ in\n                           range(input_token_num)]\n        input_token_ids = torch.tensor(input_token_ids, device=device).unsqueeze(0)  # [1, input_token_num]\n\n        # Convert BOI_TOKEN, EOS_TOKEN, and BOV_TOKEN to IDs\n        boi_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOI_TOKEN)\n        eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN)\n        bov_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOV_TOKEN)\n\n        # Define stopping criteria to stop at BOI_TOKEN, BOV_TOKEN, or EOS_TOKEN\n        stopping_criteria = StoppingCriteriaList([\n            StopOnToken(boi_token_id), StopOnToken(bov_token_id), StopOnToken(eos_token_id)\n        ])\n        ori_w, ori_h = image.size if image is not None else (width, height)\n        input_images = [self.source_transform(image)] if image is not None else []\n        original_images = [image] if image is not None else []\n        generated_images = []\n        output_text = ''\n\n        if prompt_type == 't2i':\n            prompt = f\"Follow the caption to generate an image through a chain of thought process: {text_input}\"\n        elif prompt_type == 'edit':\n            prompt = f\"Follow the instruction to edit the given image through a chain of thought process: {text_input}\"\n        else:\n            raise ValueError(f\"Unknown prompt type {prompt_type}\")\n\n        # Prepare the conversation structure for Qwen2.5-VL\n        messages = [{\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": prompt}]}]\n\n        # If image is provided, add it to messages\n        if image is not None:\n            # Insert the image into the content\n            messages[0][\"content\"].insert(0, {\"type\": \"image\"})\n\n        # Apply chat template to form the prompt as Qwen2.5-VL expects\n        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n        inputs = self.processor(\n            text=[text],\n            images=None if not input_images else input_images,\n            padding=False,\n            return_tensors=\"pt\"\n        ).to(device)\n        input_ids = inputs.input_ids  # 
shape: [1, seq_len]\n\n        # if the last token is not EOS_TOKEN, continue generating\n        while input_ids[0, -1] != eos_token_id:\n            input_length = input_ids.shape[1]\n            image_inputs = None if not input_images \\\n                else self.processor.image_processor(images=input_images, return_tensors=\"pt\").to(device)\n\n            if got_input is None:\n                partial_generation = self.mllm.generate(\n                    input_ids=input_ids,\n                    attention_mask=torch.ones_like(input_ids),\n                    pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None,\n                    image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None,\n                    max_new_tokens=max_new_tokens,\n                    return_dict_in_generate=True,\n                    output_hidden_states=False,  # No need yet, we will do a second pass\n                    stopping_criteria=stopping_criteria,\n                    **generate_kwargs\n                )\n\n                input_ids = partial_generation['sequences']  # shape: [1, seq_len]\n            else:\n                input_ids = self.processor.tokenizer.encode(got_input)\n                input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)\n                got_input = None\n\n            if only_return_got:\n                return {\"got_text\": self.processor.tokenizer.decode(input_ids[0])}\n\n            # Decode the newly generated text\n            cur_decoded_text = self.processor.tokenizer.decode(input_ids[0, input_length:], skip_special_tokens=False)\n            output_text += cur_decoded_text\\\n                .replace(EOS_TOKEN, '').replace(EOI_TOKEN, '').replace(BOV_TOKEN, '').replace(EOV_TOKEN, '')\n\n            # generate an image\n            if input_ids[0, -1] == boi_token_id:\n                input_ids = torch.cat([input_ids, img_token_ids], dim=1)  # now includes BOI_TOKEN + image tokens\n\n                second_out = self.mllm(\n                    input_ids=input_ids,\n                    attention_mask=torch.ones_like(input_ids),\n                    pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None,\n                    image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None,\n                    output_hidden_states=True,\n                    return_dict=True\n                )\n                last_hidden_states = second_out['hidden_states'][-1]  # [batch_size, seq_len, hidden_size]\n\n                img_gen_mask = torch.logical_and(\n                    self.img_gen_start_id <= input_ids, input_ids < self.img_gen_start_id + self.num_img_out_tokens)\n\n                gen_hidden_states = last_hidden_states[img_gen_mask].view(-1, self.num_img_out_tokens,\n                                                                          last_hidden_states.shape[-1])\n                gen_hidden_states = gen_hidden_states[-1:]  # only take the last group of 64 image generation tokens\n                gen_hidden_states = gen_hidden_states.to(self.output_projector.projector.weight.dtype)\n\n                gen_conditioning = self.output_projector(gen_hidden_states)\n                gen_conditioning_add = self.output_projector_add(gen_hidden_states)  # [bz, gen_num, dim]\n                null_conditioning = self.output_projector(torch.zeros_like(gen_hidden_states))\n                gen_conditioning_pooled = 
torch.mean(gen_conditioning_add, dim=1)\n\n                self.scheduler.set_timesteps(num_inference_steps, device=device)\n                timesteps = self.scheduler.timesteps\n\n                # Prepare stable diffusion latents\n                generator = torch.Generator(device=device).manual_seed(random_seed)\n\n                latents = randn_tensor(\n                    shape=(1, self.vae.config.latent_channels, height // 8, width // 8),\n                    generator=generator,\n                    device=device,\n                    dtype=vae_dtype\n                )\n                latents = latents * self.scheduler.init_noise_sigma\n\n                # The first 4 channels are the noisy latents, the next 4 are the original image latents (for editing),\n                # and the last 4 are the spatial-guidance latents.\n                # In the text-to-image generation scenario, we just provide zeros for original_image.\n                original_image = original_images[-1] if original_images \\\n                    else Image.new('RGB', (width, height), (0, 0, 0))\n\n                original_image_tensor = self.diffusion_transform(original_image).unsqueeze(0).to(device).to(vae_dtype)\n                image_latents = self.vae.encode(original_image_tensor).latent_dist.mode()\n\n                positions_colors = parse_coordinates_colors(cur_decoded_text)\n                mask_num = max(len(positions_colors), 1)\n\n                cond_images = [Image.new('RGB', (width, height), (0, 0, 0)) for _ in range(mask_num)]\n\n                for i in range(len(positions_colors)):\n                    p_c = positions_colors[i]\n                    draw = ImageDraw.Draw(cond_images[i])\n                    position = p_c['position']\n                    color = p_c['color']\n                    draw.rectangle(((position[0][0] / 1000 * width, position[0][1] / 1000 * height),\n                                    (position[1][0] / 1000 * width, position[1][1] / 1000 * height)), fill=color)\n                    del draw\n\n                cond_images_tensor = []\n                for c_image in cond_images:\n                    c_image_tensor = self.diffusion_transform(c_image)\n                    cond_images_tensor.append(c_image_tensor)\n\n                # (1, mask_num, 3, height, width)\n                cond_mask = torch.stack(cond_images_tensor, dim=0).unsqueeze(0)\n                B, N, C, H, W = cond_mask.shape\n                cond_mask = cond_mask.view(B * N, C, H, W)\n\n                unet_cond_embeds = []\n                for i in range(0, cond_mask.shape[0], self.vae_batch):\n                    sub_batch = cond_mask[i: i + self.vae_batch]\n                    embeds = self.vae.encode(sub_batch.to(device, dtype=vae_dtype)).latent_dist.mode()\n                    embeds = embeds.to(device)\n                    unet_cond_embeds.append(embeds)\n                unet_cond_embeds = torch.cat(unet_cond_embeds, dim=0)\n                unet_cond_embed = unet_cond_embeds.mean(dim=0, keepdim=True)\n\n                if do_classifier_free_guidance:\n                    uncond_image_latents = torch.zeros_like(image_latents)\n                    image_latents = torch.cat([image_latents, image_latents, image_latents, uncond_image_latents],\n                                              dim=0)\n\n                    uncond_cond_image_latents = torch.zeros_like(unet_cond_embed)\n                    unet_cond_embed = torch.cat([unet_cond_embed, uncond_cond_image_latents,\n                                                 uncond_cond_image_latents, 
uncond_cond_image_latents], dim=0)\n\n                combined_prompt_embeds = torch.cat(\n                    [gen_conditioning, gen_conditioning, null_conditioning, null_conditioning],\n                    dim=0) if do_classifier_free_guidance else gen_conditioning\n\n                text_encoder_projection_dim = int(gen_conditioning_pooled.shape[-1])\n\n                original_size = (height, width)\n                target_size = (height, width)\n\n                add_time_ids = self._get_add_time_ids(\n                    original_size,\n                    crops_coords_top_left,\n                    target_size,\n                    dtype=combined_prompt_embeds.dtype,\n                    text_encoder_projection_dim=text_encoder_projection_dim,\n                )\n\n                added_cond_kwargs = {\"text_embeds\": gen_conditioning_pooled.to(device),\n                                     \"time_ids\": add_time_ids.to(device)}\n\n                for i, t in enumerate(tqdm(timesteps)):\n                    latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents\n                    scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                    scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents, unet_cond_embed],\n                                                          dim=1)\n\n                    noise_pred = self.unet(\n                        scaled_latent_model_input,\n                        t,\n                        encoder_hidden_states=combined_prompt_embeds,\n                        added_cond_kwargs=added_cond_kwargs,\n                        return_dict=False\n                    )[0]\n\n                    if do_classifier_free_guidance:\n                        noise_pred_cond, noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(4,\n                                                                                                                 dim=0)\n                        noise_pred = (\n                                noise_pred_uncond\n                                + guidance_scale * (noise_pred_text - noise_pred_image)\n                                + cond_image_guidance_scale * (noise_pred_cond - noise_pred_text)\n                                + image_guidance_scale * (noise_pred_image - noise_pred_uncond)\n                        )\n\n                    # step through scheduler\n                    latents = self.scheduler.step(noise_pred, t, latents, generator=generator, return_dict=False)[0]\n\n                final_latents = latents / self.vae.config.scaling_factor\n                image_tensor = self.vae.decode(final_latents, generator=generator).sample\n                image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1)\n                pil_image = Image.fromarray(\n                    (image_tensor[0].permute(1, 2, 0).cpu().float().numpy() * 255).astype(\"uint8\"))\n\n                generated_images.append(pil_image)\n                original_images.append(pil_image)\n            elif input_ids[0, -1] == bov_token_id:\n                input_images.append(self.source_transform(generated_images[-1]))\n                input_ids = torch.cat([input_ids, input_token_ids], dim=1)\n\n        # resize generated images with ori_w, and ori_h, with the shortest side being 1024\n        if ori_w < ori_h:\n            target_size = (width, int(height * ori_h / ori_w))\n        else:\n            target_size = (int(width * ori_w / 
ori_h), height)\n        generated_images = [img.resize(target_size) for img in generated_images]\n\n        return {\"got_text\": output_text, \"images\": generated_images}\n\n    @classmethod\n    def from_pretrained(cls, mllm, output_projector, scheduler, vae, unet, pretrained_model_path=None, **kwargs):\n        model = cls(mllm=mllm, output_projector=output_projector, scheduler=scheduler, vae=vae, unet=unet, **kwargs)\n        if os.environ.get('DEBUG_FLAG', 'False') == 'True':\n            return model\n\n        if pretrained_model_path is not None:\n            ckpt = torch.load(pretrained_model_path, map_location='cpu')\n            logs = model.load_state_dict(ckpt, strict=False)\n            print(logs)\n        return model\n"
  },
  {
    "path": "got/models/peft_models.py",
    "content": "import torch\nfrom omegaconf import DictConfig\nimport hydra\nfrom peft import (\n    LoraConfig,\n    PeftModel,\n    LoraModel,\n    PeftModelForCausalLM,\n    get_peft_model,\n)\n\n\ndef get_peft_model_without_resize_embedding(model, peft_config=None, torch_dtype='bf16'):\n    if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':\n        torch_dtype = torch.bfloat16\n    elif torch_dtype == 'fp16' or torch_dtype == 'float16':\n        torch_dtype = torch.float16\n    else:\n        torch_dtype = torch.float32\n\n    if isinstance(model, DictConfig):\n        model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)\n\n    print('peft config: ', peft_config)\n    if isinstance(peft_config, DictConfig):\n        peft_config = hydra.utils.instantiate(peft_config)\n    peft_model = get_peft_model(model=model, peft_config=peft_config)\n\n    # peft_model.print_trainable_parameters()\n\n    return peft_model\n"
  },
  {
    "path": "got/models/projector.py",
    "content": "import torch.nn as nn\n\n\nclass LinearProjector(nn.Module):\n    def __init__(self, in_hidden_size, out_hidden_size, bias=True):\n        super().__init__()\n        self.projector = nn.Linear(in_hidden_size, out_hidden_size, bias=bias)\n\n    def forward(self, feature):\n        return self.projector(feature)\n"
  },
  {
    "path": "got/models/utils.py",
    "content": "import re\nimport torch\nfrom transformers import StoppingCriteria\n\n\nBOI_TOKEN = '<|im_gen_start|>'\nEOI_TOKEN = '<|im_gen_end|>'\nIMG_TOKEN = '<|im_gen_{:04d}|>'\nEOS_TOKEN = '<|endoftext|>'\nBOV_TOKEN = '<|vision_start|>'\nEOV_TOKEN = '<|vision_end|>'\nIMG_PAD_TOKEN = '<|image_pad|>'\n\n\ndef remove_mismatched_weights(model, pretrained_state_dict):\n    own_state = model.state_dict()\n    mismatch_keys = []\n\n    for name in list(pretrained_state_dict.keys()):\n        if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape:\n            mismatch_keys.append(name)\n            pretrained_state_dict.pop(name)\n\n    return pretrained_state_dict, mismatch_keys\n\n\ndef parse_coordinates_colors(cot_text):\n    \"\"\"\n    Parse bounding box coordinates and their colors from the CoT text.\n\n    Args:\n        cot_text (str): Chain of Thought text containing bounding box information.\n\n    Returns:\n        list: A list of dictionaries with keys 'position' ([[x1, y1], [x2, y2]]) and 'color'.\n    \"\"\"\n    # Regular expression to match bounding box and color patterns\n    pattern = r\"<\\|box_start\\|>\\((\\d+),(\\d+)\\),\\((\\d+),(\\d+)\\)<\\|box_end\\|> \\((\\w+)\\)\"\n\n    # Parse all matches\n    matches = re.findall(pattern, cot_text)\n\n    # Extract bounding box coordinates and colors\n    parsed_data = []\n    for match in matches:\n        x1, y1, x2, y2, color = match\n        parsed_data.append({\n            'position': [[int(x1), int(y1)], [int(x2), int(y2)]],\n            'color': color\n        })\n\n    return parsed_data\n\n\nclass StopOnToken(StoppingCriteria):\n    def __init__(self, token_id):\n        self.token_id = token_id\n\n    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n        # Stop when the last generated token matches the configured token id\n        return input_ids[0, -1] == self.token_id\n"
  },
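  {
    "path": "examples/parse_got_boxes_example.py",
    "content": "\"\"\"Illustrative example (added editorially) of the GoT bounding-box annotation format.\n\nparse_coordinates_colors() in got/models/utils.py extracts grounding boxes from generated\nGoT text. The snippet below is a hypothetical reasoning fragment written to match the regex\nused there: coordinates lie on a 0-1000 grid, and each box is followed by the color used to\ndraw its spatial-guidance mask.\n\"\"\"\nfrom got.models.utils import parse_coordinates_colors\n\ngot_text = (\n    \"A red apple <|box_start|>(312,401),(568,702)<|box_end|> (red) sits on a wooden table, \"\n    \"with a glass of water <|box_start|>(620,305),(790,698)<|box_end|> (blue) to its right.\"\n)\n\nboxes = parse_coordinates_colors(got_text)\nfor box in boxes:\n    print(box)\n\n# Expected output:\n# {'position': [[312, 401], [568, 702]], 'color': 'red'}\n# {'position': [[620, 305], [790, 698]], 'color': 'blue'}\n"
  },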
  {
    "path": "got/processer/qwen25_vl_processor.py",
    "content": "from transformers import AutoProcessor\n\n\nBOI_TOKEN = '<|im_gen_start|>'\nEOI_TOKEN = '<|im_gen_end|>'\nIMG_TOKEN = '<|im_gen_{:04d}|>'\n\n\ndef get_processor(model_name, add_gen_token_num=64):\n    processor = AutoProcessor.from_pretrained(model_name)\n    add_token_list = [BOI_TOKEN, EOI_TOKEN]\n    for i in range(add_gen_token_num):\n        add_token_list.append(IMG_TOKEN.format(i))\n    processor.tokenizer.add_tokens(add_token_list, special_tokens=True)\n    return processor\n"
  },
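  {
    "path": "examples/check_gen_token_ids.py",
    "content": "\"\"\"Illustrative sanity check (added editorially) for the extra image-generation tokens.\n\nget_processor() extends the Qwen2.5-VL tokenizer with <|im_gen_start|>, <|im_gen_end|>, and\nadd_gen_token_num placeholder tokens <|im_gen_0000|> onward. The id of <|im_gen_0000|>\nshould line up with img_gen_start_id in configs/clm_models/agent_got.yaml (151667 in the\nreleased config), and the generation-token ids must form a contiguous range, because the\nagent masks hidden states for ids in [img_gen_start_id, img_gen_start_id + num_img_out_tokens).\nThis script assumes the Qwen2.5-VL-3B-Instruct weights are already under ./pretrained.\n\"\"\"\nfrom got.processer.qwen25_vl_processor import IMG_TOKEN, get_processor\n\nprocessor = get_processor('pretrained/Qwen2.5-VL-3B-Instruct', add_gen_token_num=64)\n\nfirst_gen_token = IMG_TOKEN.format(0)  # '<|im_gen_0000|>'\nfirst_gen_id = processor.tokenizer.convert_tokens_to_ids(first_gen_token)\nprint(first_gen_token, '->', first_gen_id)  # expected to match img_gen_start_id (151667)\n\ngen_ids = [processor.tokenizer.convert_tokens_to_ids(IMG_TOKEN.format(i)) for i in range(64)]\nassert gen_ids == list(range(first_gen_id, first_gen_id + 64)), 'generation token ids are not contiguous'\n"
  },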
  {
    "path": "requirements.txt",
    "content": "torch==2.0.1\ntorchvision==0.15.2\nhydra-core\nomegaconf\ntransformers==4.49.0\ndiffusers==0.29.0\nsentencepiece\nopencv-python\npeft==0.13.2\npyrootutils\nxformers==0.0.22\naccelerate==1.3.0\ntransformers_stream_generator\ntqdm\nnotebook\nnumpy==1.21.2\nhuggingface_hub==0.29.3"
  }
]